/*------------------------------------------------------------------------
 * Vulkan Conformance Tests
 * ------------------------
 *
 * Copyright (c) 2019 The Khronos Group Inc.
 * Copyright (c) 2018-2019 NVIDIA Corporation
 * Copyright (c) 2023 LunarG, Inc.
 * Copyright (c) 2023 Nintendo
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *	  http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 *//*!
 * \file
 * \brief Vulkan Cooperative Matrix tests
 *//*--------------------------------------------------------------------*/

#include "vktComputeCooperativeMatrixTests.hpp"

#include "vkBufferWithMemory.hpp"
#include "vkImageWithMemory.hpp"
#include "vkQueryUtil.hpp"
#include "vkBuilderUtil.hpp"
#include "vkCmdUtil.hpp"
#include "vkTypeUtil.hpp"
#include "vkObjUtil.hpp"

#include "vktTestGroupUtil.hpp"
#include "vktTestCase.hpp"

#include "deDefs.h"
#include "deFloat16.h"
#include "deMath.h"
#include "deRandom.h"
#include "deSharedPtr.hpp"
#include "deString.h"

#include "tcuTestCase.hpp"
#include "tcuTestLog.hpp"

#include <string>
#include <sstream>
#include <set>
#include <limits>
#include <algorithm>

namespace vkt
{
namespace compute
{
namespace
{
using namespace vk;
using namespace std;

//#define COOPERATIVE_MATRIX_EXTENDED_DEBUG 1

DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_FLOAT16_KHR == (uint32_t)VK_COMPONENT_TYPE_FLOAT16_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_FLOAT32_KHR == (uint32_t)VK_COMPONENT_TYPE_FLOAT32_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_FLOAT64_KHR == (uint32_t)VK_COMPONENT_TYPE_FLOAT64_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT8_KHR   == (uint32_t)VK_COMPONENT_TYPE_SINT8_NV  );
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT16_KHR  == (uint32_t)VK_COMPONENT_TYPE_SINT16_NV );
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT32_KHR  == (uint32_t)VK_COMPONENT_TYPE_SINT32_NV );
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT64_KHR  == (uint32_t)VK_COMPONENT_TYPE_SINT64_NV );
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT8_KHR   == (uint32_t)VK_COMPONENT_TYPE_UINT8_NV  );
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT16_KHR  == (uint32_t)VK_COMPONENT_TYPE_UINT16_NV );
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT32_KHR  == (uint32_t)VK_COMPONENT_TYPE_UINT32_NV );
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT64_KHR  == (uint32_t)VK_COMPONENT_TYPE_UINT64_NV );

DE_STATIC_ASSERT((uint32_t)VK_SCOPE_DEVICE_KHR       == (uint32_t)VK_SCOPE_DEVICE_NV);
DE_STATIC_ASSERT((uint32_t)VK_SCOPE_WORKGROUP_KHR    == (uint32_t)VK_SCOPE_WORKGROUP_NV);
DE_STATIC_ASSERT((uint32_t)VK_SCOPE_SUBGROUP_KHR     == (uint32_t)VK_SCOPE_SUBGROUP_NV);
DE_STATIC_ASSERT((uint32_t)VK_SCOPE_QUEUE_FAMILY_KHR == (uint32_t)VK_SCOPE_QUEUE_FAMILY_NV);

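// Whether the test targets the NV extension or the KHR extension; for the
// KHR path, which matrix use (A, B or Accumulator) the non-MulAdd tests declare.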
typedef enum
{
	UT_NV = 0,
	UT_KHR_A,
	UT_KHR_B,
	UT_KHR_Result,
} UseType;

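// The operation exercised by the generated shader, from loads and
// component-wise arithmetic up to the coopMatMulAdd variants.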
typedef enum
{
	TT_LENGTH = 0,
	TT_CONSTANT,
	TT_CONVERT,
	TT_COMPOSITE,
	TT_COMPOSITE_RVALUE,
	TT_ADD,
	TT_SUB,
	TT_DIV,
	TT_MUL,
	TT_NEGATE,
	TT_MATRIXTIMESSCALAR,
	TT_FUNC,
	TT_MATRIXMULADD,
	TT_COMPOSITE_ARRAY,
	TT_MATRIXMULADD_ARRAY,
	TT_MATRIXMULADD_SATURATED,
	TT_MATRIXMULADD_WRAPPING,
	TT_MATRIXMULADD_STRIDE0,
} TestType;

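// The storage class the matrices are loaded from and stored to.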
typedef enum
{
	SC_BUFFER = 0,
	SC_WORKGROUP,
	SC_WORKGROUP_VARIABLE_POINTERS,
	SC_BUFFER_VARIABLE_POINTERS,
	SC_PHYSICAL_STORAGE_BUFFER,
} StorageClass;

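// Optional required subgroup size for the compute pipeline.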
enum SubgroupSizeMode
{
	SUBGROUP_SIZE_NONE = 0,
	SUBGROUP_SIZE_MIN = 1,
	SUBGROUP_SIZE_MAX = 2,
};

const VkFlags allShaderStages = VK_SHADER_STAGE_COMPUTE_BIT;

struct CaseDef
{
	TestType							testType;
	deUint32							subgroupsPerWorkgroupX;
	deUint32							subgroupsPerWorkgroupY;
	deUint32							workgroupsX;
	deUint32							workgroupsY;
	VkComponentTypeKHR					inputType;
	VkComponentTypeKHR					outputType;
	bool								colMajor;
	StorageClass						storageClass;
	UseType								useType;
	SubgroupSizeMode					subgroupSizeMode;
	vk::ComputePipelineConstructionType	computePipelineConstructionType;
};

bool isKhr (UseType useType)
{
	return useType != UT_NV;
}

bool isMatrixMulAddOp (TestType testType)
{
	return testType == TT_MATRIXMULADD || testType == TT_MATRIXMULADD_ARRAY || testType == TT_MATRIXMULADD_SATURATED || testType == TT_MATRIXMULADD_WRAPPING || testType == TT_MATRIXMULADD_STRIDE0;
}

template<typename T>
VkResult getCooperativeMatrixProperties (const InstanceInterface&, VkPhysicalDevice, uint32_t*, T*)
{
	TCU_THROW(InternalError, "Not implemented");
}

VkResult getCooperativeMatrixProperties (const InstanceInterface& vki, VkPhysicalDevice physicalDevice, uint32_t* pPropertyCount, VkCooperativeMatrixPropertiesKHR* pProperties)
{
	return vki.getPhysicalDeviceCooperativeMatrixPropertiesKHR(physicalDevice, pPropertyCount, pProperties);
}

VkResult getCooperativeMatrixProperties (const InstanceInterface& vki, VkPhysicalDevice physicalDevice, uint32_t* pPropertyCount, VkCooperativeMatrixPropertiesNV* pProperties)
{
	return vki.getPhysicalDeviceCooperativeMatrixPropertiesNV(physicalDevice, pPropertyCount, pProperties);
}

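// Map the NV properties structure onto the KHR layout so the rest of the
// code deals only with VkCooperativeMatrixPropertiesKHR. The NV structure
// has no saturatingAccumulation field, so it is reported as VK_FALSE.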
VkCooperativeMatrixPropertiesKHR convertCooperativeMatrixProperties (const VkCooperativeMatrixPropertiesNV& properties)
{
	VkCooperativeMatrixPropertiesKHR result = initVulkanStructure();

	result.sType					= (VkStructureType)		properties.sType;
	result.pNext					= (void*)				properties.pNext;
	result.MSize					= (uint32_t)			properties.MSize;
	result.NSize					= (uint32_t)			properties.NSize;
	result.KSize					= (uint32_t)			properties.KSize;
	result.AType					= (VkComponentTypeKHR)	properties.AType;
	result.BType					= (VkComponentTypeKHR)	properties.BType;
	result.CType					= (VkComponentTypeKHR)	properties.CType;
	result.ResultType				= (VkComponentTypeKHR)	properties.DType;
	result.saturatingAccumulation	= (VkBool32)			VK_FALSE;
	result.scope					= (VkScopeKHR)			properties.scope;

	return result;
}

std::vector<VkCooperativeMatrixPropertiesKHR> convertCooperativeMatrixProperties (const std::vector<VkCooperativeMatrixPropertiesNV>& properties)
{
	std::vector<VkCooperativeMatrixPropertiesKHR> result(properties.size());

	for (size_t i = 0; i < properties.size(); ++i)
		result[i] = convertCooperativeMatrixProperties(properties[i]);

	return result;
}

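// Standard two-call enumeration: query the property count first, then fill
// the vector with a second call.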
template<typename T>
void getCooperativeMatrixPropertiesAll (Context& context, std::vector<T>& properties)
{
	deUint32	propertyCount	= 0;

	VK_CHECK(getCooperativeMatrixProperties(context.getInstanceInterface(), context.getPhysicalDevice(), &propertyCount, (T*)DE_NULL));

	if (propertyCount > 0)
	{
		const T sample = initVulkanStructureConst();

		properties.resize(propertyCount, sample);

		VK_CHECK(getCooperativeMatrixProperties(context.getInstanceInterface(), context.getPhysicalDevice(), &propertyCount, properties.data()));
	}
	else
	{
		properties.clear();
	}
}

std::vector<VkCooperativeMatrixPropertiesKHR> getCooperativeMatrixPropertiesConverted (Context& context, const bool khr)
{
	std::vector<VkCooperativeMatrixPropertiesKHR> properties;

	if (khr)
	{
		getCooperativeMatrixPropertiesAll(context, properties);
	}
	else
	{
		std::vector<VkCooperativeMatrixPropertiesNV> propertiesNV;

		getCooperativeMatrixPropertiesAll(context, propertiesNV);

		properties = convertCooperativeMatrixProperties(propertiesNV);
	}

	return properties;
}

deUint32 getSubgroupSizeFromMode (Context&					context,
								  const SubgroupSizeMode	subgroupSizeMode)
{
#ifndef CTS_USES_VULKANSC
	const VkPhysicalDeviceSubgroupSizeControlProperties&	subgroupSizeControlProperties = context.getSubgroupSizeControlProperties();
#else
	const VkPhysicalDeviceSubgroupSizeControlPropertiesEXT&	subgroupSizeControlProperties = context.getSubgroupSizeControlPropertiesEXT();
#endif // CTS_USES_VULKANSC

	switch (subgroupSizeMode)
	{
		case SUBGROUP_SIZE_MAX:		return subgroupSizeControlProperties.maxSubgroupSize;
		case SUBGROUP_SIZE_MIN:		return subgroupSizeControlProperties.minSubgroupSize;
		case SUBGROUP_SIZE_NONE:	return context.getSubgroupProperties().subgroupSize;
		default:					TCU_THROW(NotSupportedError, "Unsupported Subgroup size");
	}
}


class CooperativeMatrixTestInstance : public TestInstance
{
public:
						CooperativeMatrixTestInstance	(Context& context, const CaseDef& data);
						~CooperativeMatrixTestInstance	(void);
	tcu::TestStatus		iterate							(void);
private:
	CaseDef			m_data;
};

CooperativeMatrixTestInstance::CooperativeMatrixTestInstance (Context& context, const CaseDef& data)
	: vkt::TestInstance		(context)
	, m_data				(data)
{
}

CooperativeMatrixTestInstance::~CooperativeMatrixTestInstance (void)
{
}

class CooperativeMatrixTestCase : public TestCase
{
	public:
								CooperativeMatrixTestCase	(tcu::TestContext& context, const char* name, const CaseDef data);
								~CooperativeMatrixTestCase	(void);
	virtual	void				initPrograms		(SourceCollections& programCollection) const;
	virtual TestInstance*		createInstance		(Context& context) const;
	virtual void				checkSupport		(Context& context) const;

private:
	CaseDef					m_data;
};

CooperativeMatrixTestCase::CooperativeMatrixTestCase (tcu::TestContext& context, const char* name, const CaseDef data)
	: vkt::TestCase	(context, name)
	, m_data		(data)
{
}

CooperativeMatrixTestCase::~CooperativeMatrixTestCase (void)
{
}

void CooperativeMatrixTestCase::checkSupport (Context& context) const
{
	if (!context.contextSupports(vk::ApiVersion(0, 1, 1, 0)))
	{
		TCU_THROW(NotSupportedError, "Vulkan 1.1 not supported");
	}

	if (isKhr(m_data.useType))
	{
		if (!context.getCooperativeMatrixFeatures().cooperativeMatrix)
		{
			TCU_THROW(NotSupportedError, "VkPhysicalDeviceCooperativeMatrixFeaturesKHR::cooperativeMatrix not supported");
		}
	}
	else
	{
		if (!context.getCooperativeMatrixFeaturesNV().cooperativeMatrix)
		{
			TCU_THROW(NotSupportedError, "VkPhysicalDeviceCooperativeMatrixFeaturesNV::cooperativeMatrix not supported");
		}
	}

	if (!context.getVulkanMemoryModelFeatures().vulkanMemoryModel)
	{
		TCU_THROW(NotSupportedError, "vulkanMemoryModel not supported");
	}

	if ((m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS || m_data.storageClass == SC_BUFFER_VARIABLE_POINTERS) &&
		!context.getVariablePointersFeatures().variablePointers)
	{
		TCU_THROW(NotSupportedError, "variable pointers not supported");
	}

	if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER && !context.isBufferDeviceAddressSupported())
	{
		TCU_THROW(NotSupportedError, "buffer device address not supported");
	}

	if (!context.getShaderFloat16Int8Features().shaderFloat16 &&
		(m_data.inputType == VK_COMPONENT_TYPE_FLOAT16_KHR || m_data.outputType == VK_COMPONENT_TYPE_FLOAT16_KHR))
	{
		TCU_THROW(NotSupportedError, "shaderFloat16 not supported");
	}

	std::vector<VkCooperativeMatrixPropertiesKHR>	properties		= getCooperativeMatrixPropertiesConverted(context, isKhr(m_data.useType));
	bool											supported[2]	= { false, false };
	const auto										isMMA			= isMatrixMulAddOp(m_data.testType);
	const auto										isMMASat		= m_data.testType == TT_MATRIXMULADD_SATURATED;

	for (size_t i = 0; i < properties.size(); ++i)
	{
		const VkCooperativeMatrixPropertiesKHR*	p	= &properties[i];

		if (p->scope != VK_SCOPE_SUBGROUP_KHR)
			continue;

		if (isMMA && isMMASat != static_cast<bool>(p->saturatingAccumulation))
			continue;

		if (isMMA)
		{
			if (p->AType == m_data.inputType &&
				p->BType == m_data.inputType &&
				p->CType == m_data.outputType &&
				p->ResultType == m_data.outputType)
			{
				supported[0] = supported[1] = true;
			}
		}
		else
		{
			const VkComponentTypeKHR types[2] = { m_data.inputType, m_data.outputType };

			for (deUint32 j = 0; j < 2; ++j)
			{
				switch (m_data.useType)
				{
					case UT_NV:
					{
						if (p->AType == types[j] || p->BType == types[j] || p->CType == types[j] || p->ResultType == types[j])
							supported[j] = true;

						break;
					}
					case UT_KHR_A:
					{
						if (p->AType == types[j])
							supported[j] = true;

						break;
					}
					case UT_KHR_B:
					{
						if (p->BType == types[j])
							supported[j] = true;

						break;
					}
					case UT_KHR_Result:
					{
						if (p->ResultType == types[j])
							supported[j] = true;

						break;
					}
					default:
						TCU_THROW(InternalError, "Unsupported use type");
				}
			}
		}
	}

	if (!supported[0] || !supported[1])
		TCU_THROW(NotSupportedError, "cooperative matrix combination not supported");

	checkShaderObjectRequirements(context.getInstanceInterface(), context.getPhysicalDevice(), m_data.computePipelineConstructionType);
}

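// GLSL type names and properties for each VkComponentTypeKHR value; the
// array is indexed directly by the enum value (float16..float64, then the
// signed and unsigned integer types). The coopmat type names are used by
// the NV shader path.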
struct {
	const char *typeName;
	const char *coopmatTypeName;
	deUint32 bits;
	bool isSigned;
} componentTypeInfo[] =
{
	{ "float16_t",	"fcoopmatNV",	16, true },
	{ "float32_t",	"fcoopmatNV",	32, true },
	{ "float64_t",	"fcoopmatNV",	64, true },
	{ "int8_t",		"icoopmatNV",	8, true },
	{ "int16_t",	"icoopmatNV",	16, true },
	{ "int32_t",	"icoopmatNV",	32, true },
	{ "int64_t",	"icoopmatNV",	64, true },
	{ "uint8_t",	"ucoopmatNV",	8, false },
	{ "uint16_t",	"ucoopmatNV",	16, false },
	{ "uint32_t",	"ucoopmatNV",	32, false },
	{ "uint64_t",	"ucoopmatNV",	64, false },
};

bool isFloatType (VkComponentTypeKHR t)
{
	switch (t)
	{
		case VK_COMPONENT_TYPE_FLOAT16_KHR:
		case VK_COMPONENT_TYPE_FLOAT32_KHR:
		case VK_COMPONENT_TYPE_FLOAT64_KHR:
			return true;
		default:
			return false;
	}
}

bool isSIntType (VkComponentTypeKHR t)
{
	switch (t)
	{
		case VK_COMPONENT_TYPE_SINT8_KHR:
		case VK_COMPONENT_TYPE_SINT16_KHR:
		case VK_COMPONENT_TYPE_SINT32_KHR:
		case VK_COMPONENT_TYPE_SINT64_KHR:
			return true;
		default:
			return false;
	}
}

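// Generate the GLSL compute shader for this case: load matA/matB/matC
// (optionally staging them through shared memory), perform the tested
// operation, and store matO back to buffer memory.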
void CooperativeMatrixTestCase::initPrograms (SourceCollections& programCollection) const
{
	const char*			suffix	= (isKhr(m_data.useType) ? "" : "NV");
	const char*			ext		= isKhr(m_data.useType)
								? "#extension GL_KHR_cooperative_matrix : enable\n"
								: "#extension GL_NV_cooperative_matrix : enable\n"
								  "#extension GL_NV_integer_cooperative_matrix : enable\n";
	const char*			sat		= (m_data.testType == TT_MATRIXMULADD_SATURATED) ? ", gl_MatrixOperandsSaturatingAccumulation" : "";
	std::stringstream	css;
	css << "#version 450 core\n";
	css << "#pragma use_vulkan_memory_model\n";
	css <<
		"#extension GL_KHR_shader_subgroup_basic : enable\n"
		"#extension GL_KHR_memory_scope_semantics : enable\n"
		<< ext <<
		"#extension GL_EXT_shader_explicit_arithmetic_types : enable\n"
		"#extension GL_EXT_buffer_reference : enable\n"
		"// strides overridden by spec constants\n"
		"layout(constant_id = 2) const int AStride = 1;\n"
		"layout(constant_id = 3) const int BStride = 1;\n"
		"layout(constant_id = 4) const int CStride = 1;\n"
		"layout(constant_id = 5) const int OStride = 1;\n"
		"layout(constant_id = 6) const int M = 1;\n"
		"layout(constant_id = 7) const int N = 1;\n"
		"layout(constant_id = 8) const int K = 1;\n"
		"layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;\n";

	if (m_data.storageClass == SC_BUFFER_VARIABLE_POINTERS || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
		css << "#pragma use_variable_pointers\n";

	struct
	{
		string rows, cols;
	} dims[4];

	if (isMatrixMulAddOp(m_data.testType))
	{
		dims[0].rows = "M";
		dims[0].cols = "K";
		dims[1].rows = "K";
		dims[1].cols = "N";
		dims[2].rows = "M";
		dims[2].cols = "N";
		dims[3].rows = "M";
		dims[3].cols = "N";
	}
	else
	{
		dims[0].rows = "M";
		dims[0].cols = "N";
		dims[1].rows = "M";
		dims[1].cols = "N";
		dims[2].rows = "M";
		dims[2].cols = "N";
		dims[3].rows = "M";
		dims[3].cols = "N";
	}

	const char *typeStrA = componentTypeInfo[m_data.inputType].typeName;
	const char *typeStrB = componentTypeInfo[m_data.inputType].typeName;
	const char *typeStrC = componentTypeInfo[m_data.outputType].typeName;
	const char *typeStrO = componentTypeInfo[m_data.outputType].typeName;

	css << "const int workgroupsX = " << m_data.workgroupsX << ";\n";
	css << "const uvec2 subgroupsPerWG = uvec2(" << m_data.subgroupsPerWorkgroupX << ", " << m_data.subgroupsPerWorkgroupY << ");\n";

	if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
	{
		css << "layout(buffer_reference) buffer InputA { " << typeStrA << " x[]; };\n";
		css << "layout(buffer_reference) buffer InputB { " << typeStrB << " x[]; };\n";
		css << "layout(buffer_reference) buffer InputC { " << typeStrC << " x[]; };\n";
		css << "layout(buffer_reference) buffer Output { " << typeStrO << " x[]; };\n";
		css << "layout(set=0, binding=4) buffer Params { InputA inputA; InputB inputB; InputC inputC; Output outputO; } params;\n";
	}
	else
	{
		css << "layout(set=0, binding=0) coherent buffer InputA { " << typeStrA << " x[]; } inputA;\n";
		css << "layout(set=0, binding=1) coherent buffer InputB { " << typeStrB << " x[]; } inputB;\n";
		css << "layout(set=0, binding=2) coherent buffer InputC { " << typeStrC << " x[]; } inputC;\n";
		css << "layout(set=0, binding=3) coherent buffer Output { " << typeStrO << " x[]; } outputO;\n";
	}

	if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
	{
		css << "shared " << typeStrA << " sharedA[" << dims[0].rows << " * " << dims[0].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
		css << "shared " << typeStrB << " sharedB[" << dims[1].rows << " * " << dims[1].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
		css << "shared " << typeStrC << " sharedC[" << dims[2].rows << " * " << dims[2].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
		css << "shared " << typeStrO << " sharedO[" << dims[3].rows << " * " << dims[3].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
	}

	std::stringstream matAType, matBType, matCType, outputMatType;

	if (isKhr(m_data.useType))
	{
		const bool	useSame		= !isMatrixMulAddOp(m_data.testType);
		const char*	sameType	= m_data.useType == UT_KHR_A ? "gl_MatrixUseA"
								: m_data.useType == UT_KHR_B ? "gl_MatrixUseB"
								: m_data.useType == UT_KHR_Result ? "gl_MatrixUseAccumulator"
								: "Invalid use";
		const char*	atype		= useSame ? sameType : "gl_MatrixUseA";
		const char*	btype		= useSame ? sameType : "gl_MatrixUseB";
		const char*	ctype		= useSame ? sameType : "gl_MatrixUseAccumulator";
		const char*	rtype		= useSame ? sameType : "gl_MatrixUseAccumulator";

		matAType << "coopmat<" << componentTypeInfo[m_data.inputType].typeName << ", gl_ScopeSubgroup, " << dims[0].rows << ", " << dims[0].cols << ", " << atype << ">";
		matBType << "coopmat<" << componentTypeInfo[m_data.inputType].typeName << ", gl_ScopeSubgroup, " << dims[1].rows << ", " << dims[1].cols << ", " << btype << ">";
		matCType << "coopmat<" << componentTypeInfo[m_data.outputType].typeName << ", gl_ScopeSubgroup, " << dims[2].rows << ", " << dims[2].cols << ", " << ctype << ">";
		outputMatType << "coopmat<" << componentTypeInfo[m_data.outputType].typeName << ", gl_ScopeSubgroup, " << dims[3].rows << ", " << dims[3].cols << ", " << rtype << ">";
	}
	else
	{
		matAType << componentTypeInfo[m_data.inputType].coopmatTypeName << "<" << componentTypeInfo[m_data.inputType].bits << ", gl_ScopeSubgroup, " << dims[0].rows << ", " << dims[0].cols << ">";
		matBType << componentTypeInfo[m_data.inputType].coopmatTypeName << "<" << componentTypeInfo[m_data.inputType].bits << ", gl_ScopeSubgroup, " << dims[1].rows << ", " << dims[1].cols << ">";
		matCType << componentTypeInfo[m_data.outputType].coopmatTypeName << "<" << componentTypeInfo[m_data.outputType].bits << ", gl_ScopeSubgroup, " << dims[2].rows << ", " << dims[2].cols << ">";
		outputMatType << componentTypeInfo[m_data.outputType].coopmatTypeName << "<" << componentTypeInfo[m_data.outputType].bits << ", gl_ScopeSubgroup, " << dims[3].rows << ", " << dims[3].cols << ">";
	}

	css << matAType.str() << " matA;\n";
	css << matBType.str() << " matB;\n";
	css << matCType.str() << " matC;\n";
	css << outputMatType.str() << " matO;\n";

	if (m_data.testType == TT_CONSTANT)
		css << "const " << outputMatType.str() << " matConst = " << outputMatType.str() << "(1.0);\n";

	if (m_data.testType == TT_FUNC)
		css << matAType.str() << " f(" << matAType.str() << " m) { return -m; }\n";

	css <<
		"void main()\n"
		"{\n"
		// matrixID is the x,y index of the matrix owned by this subgroup.
		"   uvec2 subgroupXY = uvec2(gl_SubgroupID % subgroupsPerWG.x, gl_SubgroupID / subgroupsPerWG.x);\n"
		"   uvec2 matrixID = uvec2(gl_WorkGroupID.xy) * subgroupsPerWG + subgroupXY;\n";

	if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
	{
		css << "   InputA inputA = params.inputA;\n";
		css << "   InputB inputB = params.inputB;\n";
		css << "   InputC inputC = params.inputC;\n";
		css << "   Output outputO = params.outputO;\n";
	}

	string strides[4];
	for (deUint32 i = 0; i < 4; ++i)
	{
		strides[i] = (m_data.colMajor ? dims[i].rows : dims[i].cols) + string(" * ") + de::toString(m_data.subgroupsPerWorkgroupX * m_data.workgroupsX);
	}

	// element<i> is the starting element in buffer memory.
	// elementS<i> is the starting element in shared memory.
	css << "   uint element0 = " << strides[0] << " * " << (m_data.colMajor ? dims[0].cols : dims[0].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[0].rows : dims[0].cols) << " * matrixID.x;\n"
		   "   uint element1 = " << strides[1] << " * " << (m_data.colMajor ? dims[1].cols : dims[1].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[1].rows : dims[1].cols) << " * matrixID.x;\n"
		   "   uint element2 = " << strides[2] << " * " << (m_data.colMajor ? dims[2].cols : dims[2].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[2].rows : dims[2].cols) << " * matrixID.x;\n"
		   "   uint element3 = " << strides[3] << " * " << (m_data.colMajor ? dims[3].cols : dims[3].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[3].rows : dims[3].cols) << " * matrixID.x;\n"
		   "   uint elementS0, elementS1, elementS2, elementS3;\n";

	// For shared memory tests, copy the matrix from buffer memory into
	// workgroup memory. For simplicity, do it all on a single thread.
	if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
	{
		const char *name[] =
		{
			"sharedA",
			"sharedB",
			"sharedC",
		};
		const char *inputName[] =
		{
			"inputA",
			"inputB",
			"inputC",
		};
		for (deUint32 m = 0; m < 4; ++m)
		{
			string sharedStride = strides[m] + " / workgroupsX";
			css << "       elementS" << m << " = " << sharedStride << " * " << (m_data.colMajor ? dims[m].cols : dims[m].rows) << " * subgroupXY.y + " << (m_data.colMajor ? dims[m].rows : dims[m].cols) << " * subgroupXY.x;\n";
		}
		css << "   if (subgroupElect()) {\n";
		// copy all three input buffers.
		for (deUint32 m = 0; m < 3; ++m)
		{
			string sharedStride = strides[m] + " / workgroupsX";
			css <<  "       for (int i = 0; i < " << dims[m].rows << "; ++i) {\n"
					"       for (int j = 0; j < " << dims[m].cols << "; ++j) {\n"
					"           int localElementInput = " << strides[m] << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
					"           int localElementShared = " << sharedStride << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
					"           " << name[m] << "[elementS" << m << " + localElementShared] = " << inputName[m] << ".x[element" << m << " + localElementInput];\n"
					"       }\n"
					"       }\n";
			strides[m] = sharedStride;
		}
		css << "   }\n";
		css << "   controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);\n";
	}

	const char *colMajorNV  = (m_data.colMajor ? "true" : "false");
	const char* colMajorKHR = (m_data.colMajor ? "gl_CooperativeMatrixLayoutColumnMajor" : "gl_CooperativeMatrixLayoutRowMajor");
	const char* colMajor    = (isKhr(m_data.useType) ? colMajorKHR : colMajorNV);

	string loadStrides[3] = { strides[0], strides[1], strides[2] };
	// Load with a stride of 0
	if (m_data.testType == TT_MATRIXMULADD_STRIDE0)
		loadStrides[0] = loadStrides[1] = loadStrides[2] = "0";

	if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
	{
		css <<  "   coopMatLoad" << suffix << "(matA, sharedA, elementS0, " << loadStrides[0] << ", " << colMajor << ");\n"
				"   coopMatLoad" << suffix << "(matB, sharedB, elementS1, " << loadStrides[1] << ", " << colMajor << ");\n"
				"   coopMatLoad" << suffix << "(matC, sharedC, elementS2, " << loadStrides[2] << ", " << colMajor << ");\n";
	}
	else
	{
		css << "   coopMatLoad" << suffix << "(matA, inputA.x, element0, " << loadStrides[0] << ", " << colMajor << ");\n"
			   "   coopMatLoad" << suffix << "(matB, inputB.x, element1, " << loadStrides[1] << ", " << colMajor << ");\n"
			   "   coopMatLoad" << suffix << "(matC, inputC.x, element2, " << loadStrides[2] << ", " << colMajor << ");\n";
	}

	if (m_data.testType == TT_COMPOSITE_ARRAY ||
		m_data.testType == TT_MATRIXMULADD_ARRAY)
	{
		css << "   " << matAType.str() << " matAArr[2];\n    matAArr[1] = matA; matAArr[0] = " << matAType.str() << "(0.0);\n"
			   "   " << matBType.str() << " matBArr[2];\n    matBArr[1] = matB; matBArr[0] = " << matBType.str() << "(0.0);\n"
			   "   " << matCType.str() << " matCArr[2];\n    matCArr[1] = matC; matCArr[0] = " << matCType.str() << "(0.0);\n"
			   "   " << outputMatType.str() << " matOArr[2];\n";
	}

	switch (m_data.testType)
	{
	default:
		DE_ASSERT(0);
		// fall through
	case TT_LENGTH:
		css << "   matO = " << outputMatType.str() << "(matO.length());\n";
		break;
	case TT_CONSTANT:
		css << "   matO = matConst;\n";
		break;
	case TT_CONVERT:
		css << "   matO = " << outputMatType.str() << "(matA);\n";
		break;
	case TT_COMPOSITE:
		css << "   " << matAType.str() << " t = " << matAType.str() << "(matB[0]);\n"
			"   for (int i = 1; i < matA.length(); ++i) {\n"
			"       matO[i] = matA[i] + matB[i];\n"
			"   }\n"
			"   if (matA.length() > 0)\n"
			"       matO[0] = matA[0] + t[0];\n";
		break;
	case TT_COMPOSITE_RVALUE:
		css << "   for (int i = 1; i < matA.length(); ++i) {\n"
			   "       matO[i] = matA[i] + matB[i];\n"
			   "   }\n"
			   "   " << matAType.str() << " t = matA;\n"
			   "   if (matA.length() > 0) {\n"
			   "       matO[0] = (t += matB)[0];\n"
			   "   }\n";
		break;
	case TT_COMPOSITE_ARRAY:
		css << "   for (int i = 0; i < matA.length(); ++i) {\n"
			   "       matOArr[1][i] = matAArr[1][i] + matBArr[1][i];\n"
			   "   }\n";
		break;
	case TT_ADD:
		css << "   matO = matA + matB;\n";
		break;
	case TT_SUB:
		css << "   matO = matA - matB;\n";
		break;
	case TT_DIV:
		css << "   matO = matA / matB;\n";
		break;
	case TT_MUL:
		css << "   matO = matA * matB;\n";
		break;
	case TT_NEGATE:
		css << "   matO = -matA;\n";
		break;
	case TT_FUNC:
		css << "   matO = f(matA);\n";
		break;
	case TT_MATRIXTIMESSCALAR:
		css << "   matO = (" << typeStrA << "(2.0)*matA)*" << typeStrA << "(3.0);\n";
		break;
	case TT_MATRIXMULADD_STRIDE0:
	case TT_MATRIXMULADD_WRAPPING:
	case TT_MATRIXMULADD_SATURATED:
	case TT_MATRIXMULADD:
		css << "   matO = coopMatMulAdd" << suffix << "(matA, matB, matC" << sat << ");\n";
		break;
	case TT_MATRIXMULADD_ARRAY:
		css << "   matOArr[1] = coopMatMulAdd" << suffix << "(matAArr[1], matBArr[1], matCArr[1]);\n";
		break;
	}

	if (m_data.testType == TT_COMPOSITE_ARRAY ||
		m_data.testType == TT_MATRIXMULADD_ARRAY)
	{
		css << "   matOArr[0] = " << outputMatType.str() << "(0.0);\n";
		css << "   matO = matOArr[1];\n";
	}

	if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
	{
		string sharedStride = strides[3] + " / workgroupsX";
		css << "   coopMatStore" << suffix << "(matO, sharedO, elementS3, " << sharedStride << ", " << colMajor << ");\n";
		css << "   controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);\n";
		css << "   if (subgroupElect()) {\n";
		css << "       for (int i = 0; i < " << dims[3].rows << "; ++i) {\n"
			   "       for (int j = 0; j < " << dims[3].cols << "; ++j) {\n"
			   "           int localElementInput = " << strides[3] << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
			   "           int localElementShared = " << sharedStride << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
			   "           outputO.x[element3 + localElementInput] = sharedO[elementS3 + localElementShared];\n"
			   "       }\n"
			   "       }\n";
		css << "   }\n";
	}
	else
	{
		css << "   coopMatStore" << suffix << "(matO, outputO.x, element3, " << strides[3] << ", " << colMajor << ");\n";
	}

	css <<
		"}\n";

	const vk::ShaderBuildOptions	buildOptions	(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);

	programCollection.glslSources.add("test") << glu::ComputeSource(css.str()) << buildOptions;
}

TestInstance* CooperativeMatrixTestCase::createInstance (Context& context) const
{
	return new CooperativeMatrixTestInstance(context, m_data);
}

void setDataFloat (void *base, VkComponentTypeKHR dt, deUint32 i, float value)
{
	if (dt == VK_COMPONENT_TYPE_FLOAT32_KHR)
	{
		((float *)base)[i] = value;
	}
	else
	{
		DE_ASSERT(dt == VK_COMPONENT_TYPE_FLOAT16_KHR);
		((deFloat16 *)base)[i] = deFloat32To16(value);
	}
}

float getDataFloat (void *base, VkComponentTypeKHR dt, deUint32 i)
{
	if (dt == VK_COMPONENT_TYPE_FLOAT32_KHR)
	{
		return ((float *)base)[i];
	}
	else
	{
		DE_ASSERT(dt == VK_COMPONENT_TYPE_FLOAT16_KHR);
		return deFloat16To32(((deFloat16 *)base)[i]);
	}
}

void setDataInt (void *base, VkComponentTypeKHR dt, deUint32 i, deUint32 value)
{
	DE_ASSERT(componentTypeInfo[dt].bits <= 32);

	switch (dt)
	{
		case VK_COMPONENT_TYPE_UINT8_KHR:	((deUint8  *)base)[i] = (deUint8)value; break;
		case VK_COMPONENT_TYPE_UINT16_KHR:	((deUint16 *)base)[i] = (deUint16)value; break;
		case VK_COMPONENT_TYPE_UINT32_KHR:	((deUint32 *)base)[i] = (deUint32)value; break;
		case VK_COMPONENT_TYPE_SINT8_KHR:	((deInt8  *)base)[i] = (deInt8)value; break;
		case VK_COMPONENT_TYPE_SINT16_KHR:	((deInt16 *)base)[i] = (deInt16)value; break;
		case VK_COMPONENT_TYPE_SINT32_KHR:	((deInt32 *)base)[i] = (deInt32)value; break;
		default:							TCU_THROW(InternalError, "Unsupported type");
	}
}

deUint32 getDataInt (void *base, VkComponentTypeKHR dt, deUint32 i)
{
	DE_ASSERT(componentTypeInfo[dt].bits <= 32);

	switch (dt)
	{
		case VK_COMPONENT_TYPE_UINT8_KHR:	return ((deUint8*)base)[i];
		case VK_COMPONENT_TYPE_UINT16_KHR:	return ((deUint16*)base)[i];
		case VK_COMPONENT_TYPE_UINT32_KHR:	return ((deUint32*)base)[i];
		case VK_COMPONENT_TYPE_SINT8_KHR:	return ((deInt8*)base)[i];
		case VK_COMPONENT_TYPE_SINT16_KHR:	return ((deInt16*)base)[i];
		case VK_COMPONENT_TYPE_SINT32_KHR:	return ((deInt32 *)base)[i];
		default:							TCU_THROW(InternalError, "Unsupported type");
	}
}

template <typename T>
T getDataConvertedToT (void *base, VkComponentTypeKHR dt, deUint32 i)
{
	DE_ASSERT(componentTypeInfo[dt].bits <= 32);

	switch (dt)
	{
		case VK_COMPONENT_TYPE_UINT8_KHR:	return (T)((deUint8*)base)[i];
		case VK_COMPONENT_TYPE_UINT16_KHR:	return (T)((deUint16*)base)[i];
		case VK_COMPONENT_TYPE_UINT32_KHR:	return (T)((deUint32*)base)[i];
		case VK_COMPONENT_TYPE_SINT8_KHR:	return (T)((deInt8*)base)[i];
		case VK_COMPONENT_TYPE_SINT16_KHR:	return (T)((deInt16*)base)[i];
		case VK_COMPONENT_TYPE_SINT32_KHR:	return (T)((deInt32 *)base)[i];
		case VK_COMPONENT_TYPE_FLOAT32_KHR:
		{
			float temp = ((float *)base)[i];
			if (std::numeric_limits<T>::min() == 0)
				temp = std::max(temp, 0.0f);
			return (T)temp;
		}
		case VK_COMPONENT_TYPE_FLOAT16_KHR:
		{
			float temp = deFloat16To32(((deFloat16 *)base)[i]);
			if (std::numeric_limits<T>::min() == 0)
				temp = std::max(temp, 0.0f);
			return (T)temp;
		}
		default:
			TCU_THROW(InternalError, "Unsupported type");
	}
}

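// Saturating signed addition without relying on undefined signed-overflow
// behaviour: clamp to the numeric limits instead of wrapping.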
template<typename T>
T satAdd(T a, T b)
{
	if (a > 0)
	{
		if (b > std::numeric_limits<T>::max() - a)
			return std::numeric_limits<T>::max();
	}
	else if (b < std::numeric_limits<T>::min() - a)
	{
		return std::numeric_limits<T>::min();
	}

	return (T)(a + b);
}

deUint32 satAddData (VkComponentTypeKHR dt, deUint32 a, deUint32 b)
{
	DE_ASSERT(componentTypeInfo[dt].bits <= 32);

	switch (dt)
	{
		case VK_COMPONENT_TYPE_UINT8_KHR:	return deMinu32(a + b, std::numeric_limits<deUint8>::max());
		case VK_COMPONENT_TYPE_UINT16_KHR:	return deMinu32(a + b, std::numeric_limits<deUint16>::max());
		case VK_COMPONENT_TYPE_UINT32_KHR:	return (a + b >= a) ? a + b : std::numeric_limits<deUint32>::max();
		case VK_COMPONENT_TYPE_SINT8_KHR:	return (deUint32)satAdd((deInt8)a,  (deInt8)b);
		case VK_COMPONENT_TYPE_SINT16_KHR:	return (deUint32)satAdd((deInt16)a, (deInt16)b);
		case VK_COMPONENT_TYPE_SINT32_KHR:	return (deUint32)satAdd((deInt32)a, (deInt32)b);
		default:							TCU_THROW(InternalError, "Unsupported type");
	}
}

deUint32 getLimit (VkComponentTypeKHR dt, bool positive)
{
	DE_ASSERT(componentTypeInfo[dt].bits <= 32);

	switch (dt)
	{
		case VK_COMPONENT_TYPE_UINT8_KHR:	return deUint32(positive ? std::numeric_limits<deUint8>::max()  : std::numeric_limits<deUint8>::min());
		case VK_COMPONENT_TYPE_UINT16_KHR:	return deUint32(positive ? std::numeric_limits<deUint16>::max() : std::numeric_limits<deUint16>::min());
		case VK_COMPONENT_TYPE_UINT32_KHR:	return deUint32(positive ? std::numeric_limits<deUint32>::max() : std::numeric_limits<deUint32>::min());
		case VK_COMPONENT_TYPE_SINT8_KHR:	return deUint32(positive ? std::numeric_limits<deInt8>::max()   : std::numeric_limits<deInt8>::min());
		case VK_COMPONENT_TYPE_SINT16_KHR:	return deUint32(positive ? std::numeric_limits<deInt16>::max()  : std::numeric_limits<deInt16>::min());
		case VK_COMPONENT_TYPE_SINT32_KHR:	return deUint32(positive ? std::numeric_limits<deInt32>::max()  : std::numeric_limits<deInt32>::min());
		default:							TCU_THROW(InternalError, "Unsupported type");
	}
}

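// Write 'val' at position 'at' of a row/column of 'count' elements spaced
// 'step' apart starting at 'start', zeroing the other elements.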
void setSingleElementInt (void *data, VkComponentTypeKHR dt, deUint32 start, deUint32 count, deUint32 step, deUint32 at, deUint32 val)
{
	for (deUint32 i = 0; i < count; i++)
		setDataInt(data, dt, start + i * step, (i == at) ? val : 0);
}

#ifdef COOPERATIVE_MATRIX_EXTENDED_DEBUG
string dumpWholeMatrix (void* data, VkComponentTypeKHR dt, bool colMajor, deUint32 matrixElemCount, deUint32 stride)
{
	const deUint32		rowsCount	= colMajor ? stride : matrixElemCount / stride;
	const deUint32		colsCount	= colMajor ? matrixElemCount / stride : stride;
	bool				floatType	= isFloatType(dt);
	bool				sIntType	= isSIntType(dt);
	std::stringstream	ss;

	DE_ASSERT(rowsCount * colsCount == matrixElemCount);

	for (deUint32 r = 0; r < rowsCount; r++)
	{
		for (deUint32 c = 0; c < colsCount; c++)
		{
			const deUint32 i = colMajor ? rowsCount * c + r : colsCount * r + c;

			if (floatType)
				ss << getDataFloat(data, dt, i) << "\t";
			else if (sIntType)
				ss << (deInt32)getDataInt(data, dt, i) << "\t";
			else
				ss << getDataInt(data, dt, i) << "\t";
		}

		ss << std::endl;
	}

	return ss.str();
}
#endif

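// Run the case: for each supported matrix size, fill the input buffers,
// dispatch the generated compute shader, and check the output against a
// host-side reference computation.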
tcu::TestStatus CooperativeMatrixTestInstance::iterate (void)
{
	const DeviceInterface&	vk						= m_context.getDeviceInterface();
	const VkDevice			device					= m_context.getDevice();
	Allocator&				allocator				= m_context.getDefaultAllocator();
	MemoryRequirement		memoryDeviceAddress		= m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER &&
													  m_context.isDeviceFunctionalitySupported("VK_KHR_buffer_device_address") ? MemoryRequirement::DeviceAddress : MemoryRequirement::Any;
	qpTestResult			finalres				= QP_TEST_RESULT_NOT_SUPPORTED;
	tcu::TestLog&			log						= m_context.getTestContext().getLog();
	const bool				saturated				= (m_data.testType == TT_MATRIXMULADD_SATURATED);
	const deUint32			subgroupSize			= getSubgroupSizeFromMode(m_context, m_data.subgroupSizeMode);
	const float				epsilon					= 1.0f / float(1ull<<17); // 1/131072, an epsilon of roughly 1e-5

	deRandom rnd;
	deRandom_init(&rnd, 1234);

	std::vector<VkCooperativeMatrixPropertiesKHR>	properties = getCooperativeMatrixPropertiesConverted(m_context, isKhr(m_data.useType));

	struct TestTuple
	{
		TestTuple() {}
		TestTuple(deUint32 m, deUint32 n, deUint32 k) : M(m), N(n), K(k) {}

		bool operator<(const TestTuple &other) const
		{
			return M < other.M ||
				   (M == other.M && N < other.N) ||
				   (M == other.M && N == other.N && K < other.K);
		}

		deUint32 M, N, K;
	};

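	// Collect the matrix sizes to test from the implementation's advertised
	// cooperative matrix properties.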
	vector<TestTuple> testSizes;

	if (isMatrixMulAddOp(m_data.testType))
	{
		for (size_t i = 0; i < properties.size(); ++i)
		{
			VkCooperativeMatrixPropertiesKHR *p = &properties[i];

			if (p->AType == m_data.inputType &&
				p->BType == m_data.inputType &&
				p->CType == m_data.outputType &&
				p->ResultType == m_data.outputType &&
				p->scope == VK_SCOPE_SUBGROUP_KHR)
			{
				testSizes.push_back(TestTuple(p->MSize, p->NSize, p->KSize));
			}
		}
	}
	else
	{
		set<TestTuple> typeSizes[2];
		VkComponentTypeKHR types[2] = { m_data.inputType, m_data.outputType };
		const bool aType = (m_data.useType == UT_KHR_A) || (m_data.useType == UT_NV);
		const bool bType = (m_data.useType == UT_KHR_B) || (m_data.useType == UT_NV);
		const bool rType = (m_data.useType == UT_KHR_Result) || (m_data.useType == UT_NV);

		for (deUint32 i = 0; i < properties.size(); ++i)
		{
			VkCooperativeMatrixPropertiesKHR *p = &properties[i];

			if (p->scope != VK_SCOPE_SUBGROUP_KHR)
				continue;

			for (deUint32 j = 0; j < 2; ++j)
			{
				// For these tests, MxN is always the matrix size. Check if it matches
				// any input or output in the list.
				if (aType && p->AType == types[j]) typeSizes[j].insert(TestTuple(p->MSize, p->KSize, 0));
				if (bType && p->BType == types[j]) typeSizes[j].insert(TestTuple(p->KSize, p->NSize, 0));
				if (rType && (p->CType == types[j] || p->ResultType == types[j])) typeSizes[j].insert(TestTuple(p->MSize, p->NSize, 0));
			}
		}
		// Test those sizes that are supported for both the input and output type.
		std::set_intersection(typeSizes[0].begin(), typeSizes[0].end(),
							  typeSizes[1].begin(), typeSizes[1].end(),
							  std::back_inserter(testSizes));
	}

	properties.resize(0);

	for (unsigned int s = 0; s < testSizes.size(); ++s)
	{
		// When testing a multiply, MxNxK is the type of matrix multiply.
		// Otherwise, MxN is the size of the input/output matrices
		deUint32 M, N, K;
		M = testSizes[s].M;
		N = testSizes[s].N;
		K = testSizes[s].K;

		log << tcu::TestLog::Message << "Testing M = " << M << ", N = " << N << ", K = " << K << tcu::TestLog::EndMessage;

		struct
		{
			deUint32 rows, cols;
		} dims[4];

		if (isMatrixMulAddOp(m_data.testType))
		{
			dims[0].rows = M;
			dims[0].cols = K;
			dims[1].rows = K;
			dims[1].cols = N;
			dims[2].rows = M;
			dims[2].cols = N;
			dims[3].rows = M;
			dims[3].cols = N;
		}
		else
		{
			dims[0].rows = M;
			dims[0].cols = N;
			dims[1].rows = M;
			dims[1].cols = N;
			dims[2].rows = M;
			dims[2].cols = N;
			dims[3].rows = M;
			dims[3].cols = N;
		}

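		// One buffer per matrix (A, B, C, Output) plus, for the physical-
		// storage-buffer path, a fifth buffer holding the four buffer device
		// addresses consumed by the shader's Params block.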
		VkComponentTypeKHR dataTypes[4];
		size_t elementSize[4];
		VkDeviceSize bufferSizes[5];
		de::MovePtr<BufferWithMemory> buffers[5];
		vk::VkDescriptorBufferInfo bufferDescriptors[5];
		deUint32 strides[4]; // in elements
		deUint32 loadStrides[4];
		deUint32 totalElements[4];

		for (deUint32 i = 0; i < 5; ++i)
		{
			if (i < 4)
			{
				// A/B use input type, C/D use output type
				dataTypes[i] = (i < 2) ? m_data.inputType : m_data.outputType;
				elementSize[i] = componentTypeInfo[dataTypes[i]].bits / 8;

				strides[i] = (m_data.colMajor ? dims[i].rows : dims[i].cols) * m_data.subgroupsPerWorkgroupX * m_data.workgroupsX;
				loadStrides[i] = strides[i];
				totalElements[i] = strides[i] * (m_data.colMajor ? dims[i].cols : dims[i].rows) * m_data.subgroupsPerWorkgroupY * m_data.workgroupsY;

				bufferSizes[i] = totalElements[i] * elementSize[i];
			}
			else
			{
				bufferSizes[4] = sizeof(VkDeviceAddress)*4;
			}

			try
			{
				buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
					vk, device, allocator, makeBufferCreateInfo(bufferSizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT|
					(memoryDeviceAddress == MemoryRequirement::DeviceAddress ?  VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_EXT : 0)),
					MemoryRequirement::HostVisible | MemoryRequirement::Cached | MemoryRequirement::Coherent | memoryDeviceAddress));
			}
			catch (const tcu::NotSupportedError&)
			{
				buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
					vk, device, allocator, makeBufferCreateInfo(bufferSizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT|
					(memoryDeviceAddress == MemoryRequirement::DeviceAddress ?  VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_EXT : 0)),
					MemoryRequirement::HostVisible | memoryDeviceAddress));
			}

			bufferDescriptors[i] = makeDescriptorBufferInfo(**buffers[i], 0, bufferSizes[i]);
		}

		// Load with a stride of 0
		if (m_data.testType == TT_MATRIXMULADD_STRIDE0)
			loadStrides[0] = loadStrides[1] = loadStrides[2] = loadStrides[3] = 0;

		void *ptrs[5];
		for (deUint32 i = 0; i < 5; ++i)
		{
			ptrs[i] = buffers[i]->getAllocation().getHostPtr();
		}

		vk::DescriptorSetLayoutBuilder layoutBuilder;

		layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
		layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
		layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
		layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
		layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);

		vk::Unique<vk::VkDescriptorSetLayout>	descriptorSetLayout(layoutBuilder.build(vk, device));

		vk::Unique<vk::VkDescriptorPool>		descriptorPool(vk::DescriptorPoolBuilder()
			.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 5u)
			.build(vk, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u));
		vk::Unique<vk::VkDescriptorSet>			descriptorSet		(makeDescriptorSet(vk, device, *descriptorPool, *descriptorSetLayout));

		vk::DescriptorSetUpdateBuilder setUpdateBuilder;
		if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
		{
			VkBufferDeviceAddressInfo info
			{
				VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO,		// VkStructureType	 sType;
				DE_NULL,											// const void*		 pNext;
				0,													// VkBuffer			 buffer;
			};
			VkDeviceAddress *addrsInMemory = (VkDeviceAddress *)ptrs[4];
			for (deUint32 i = 0; i < 4; ++i)
			{
				info.buffer = **buffers[i];
				VkDeviceAddress addr = vk.getBufferDeviceAddress(device, &info);
				addrsInMemory[i] = addr;
			}
			setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(4),
				VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[4]);
		}
		else
		{
			setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(0),
				VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[0]);
			setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(1),
				VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[1]);
			setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(2),
				VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[2]);
			setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(3),
				VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[3]);
		}

		setUpdateBuilder.update(vk, device);

		VkPipelineBindPoint bindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;

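		// Specialization constants: ids 0/1 set the workgroup local size,
		// ids 2-5 the four matrix strides, and ids 6-8 the M/N/K dimensions
		// declared in the shader.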
		const deUint32 specData[9] =
		{
			subgroupSize * m_data.subgroupsPerWorkgroupX,
			m_data.subgroupsPerWorkgroupY,
			strides[0],
			strides[1],
			strides[2],
			strides[3],
			M,
			N,
			K,
		};

		const vk::VkSpecializationMapEntry entries[9] =
		{
			{0, (deUint32)(sizeof(deUint32) * 0), sizeof(deUint32)},
			{1, (deUint32)(sizeof(deUint32) * 1), sizeof(deUint32)},
			{2, (deUint32)(sizeof(deUint32) * 2), sizeof(deUint32)},
			{3, (deUint32)(sizeof(deUint32) * 3), sizeof(deUint32)},
			{4, (deUint32)(sizeof(deUint32) * 4), sizeof(deUint32)},
			{5, (deUint32)(sizeof(deUint32) * 5), sizeof(deUint32)},
			{6, (deUint32)(sizeof(deUint32) * 6), sizeof(deUint32)},
			{7, (deUint32)(sizeof(deUint32) * 7), sizeof(deUint32)},
			{8, (deUint32)(sizeof(deUint32) * 8), sizeof(deUint32)},
		};

		const vk::VkSpecializationInfo specInfo =
		{
			9,						// mapEntryCount
			entries,				// pMapEntries
			sizeof(specData),		// dataSize
			specData				// pData
		};

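		// Fill the input buffers with pseudo-random values; the value ranges
		// differ per test type (see the branches below).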
		for (deUint32 i = 0; i < 4; ++i)
			for (deUint32 j = 0; j < totalElements[i]; ++j)
			{
				if (isFloatType(dataTypes[i]))
				{
					if (!isMatrixMulAddOp(m_data.testType))
						setDataFloat(ptrs[i], dataTypes[i], j, ((float)(deRandom_getUint32(&rnd) & 0xff) - 64.0f)/2.0f);
					else
						setDataFloat(ptrs[i], dataTypes[i], j, ((float)(deRandom_getUint32(&rnd) & 0xf) - 4.0f)/2.0f);
				}
				else
				{
					if (m_data.testType == TT_MATRIXMULADD_WRAPPING)
					{
						// Choose matrix values that should cause overflow and underflow, to
						// verify wrapping behavior. Use the full range of values for A and B.
						// For matrix C, use values clustered near where the type wraps (zero
						// for unsigned, 2^(N-1) for signed).
						deUint32 bits = componentTypeInfo[dataTypes[i]].bits;
						deUint32 value;
						if (i == 2) {
							value = (deRandom_getUint32(&rnd) & 0xff) - 128;
							if (componentTypeInfo[dataTypes[i]].isSigned)
								value += (1U << (bits - 1));
						} else {
							deUint32 mask = (bits == 32) ? 0xFFFFFFFFU : ((1U << bits) - 1U);
							value = deRandom_getUint32(&rnd) & mask;
						}
						setDataInt(ptrs[i], dataTypes[i], j, value);
					}
					else if (m_data.testType == TT_MATRIXMULADD_SATURATED)
					{
						setDataInt(ptrs[i], dataTypes[i], j, 0);
					}
					else
					{
						deUint32 value = (deRandom_getUint32(&rnd) & 0xff) - 128;
						setDataInt(ptrs[i], dataTypes[i], j, value);
					}
				}
			}

		if (m_data.testType == TT_MATRIXMULADD_SATURATED)
		{
			// Set 1st row of A to 1,0,0...
			setSingleElementInt(ptrs[0], dataTypes[0], 0, dims[0].cols, (m_data.colMajor ? strides[0] : 1), 0, 1);

			// Set 1st column of B to 1,0,0...
			setSingleElementInt(ptrs[1], dataTypes[1], 0, dims[1].rows, (m_data.colMajor ? 1 : strides[1]), 0, 1);

			// Set the C element at {0,0} to the maximum value of the type, so the addition in D=A*B+C overflows for this element
			setDataInt(ptrs[2], dataTypes[2], 0, getLimit(dataTypes[2], true));

			// Check underflow if all involved elements support negative values
			if (isSIntType(dataTypes[1]) && isSIntType(dataTypes[2]) && isSIntType(dataTypes[3]))
			{
				// Set 2nd row of A to 0,1,0,0...
				setSingleElementInt(ptrs[0], dataTypes[0], (m_data.colMajor ? 1 : strides[0]), dims[0].cols, (m_data.colMajor ? strides[0] : 1), 1, 1);

				// Set 2nd column of B to 0,-1,0,0...
				setSingleElementInt(ptrs[1], dataTypes[1], (m_data.colMajor ? strides[1] : 1), dims[1].rows, (m_data.colMajor ? 1 : strides[1]), 1, -1);

				// Set the C element at {1,1} to the minimum value of the type, so the addition in D=A*B+C underflows for this element
				setDataInt(ptrs[2], dataTypes[2], strides[2] + 1, getLimit(dataTypes[2], false));
			}
		}

		flushAlloc(vk, device, buffers[0]->getAllocation());
		flushAlloc(vk, device, buffers[1]->getAllocation());
		flushAlloc(vk, device, buffers[2]->getAllocation());
		flushAlloc(vk, device, buffers[3]->getAllocation());

		ComputePipelineWrapper			pipeline(vk, device, m_data.computePipelineConstructionType, m_context.getBinaryCollection().get("test"));
		pipeline.setDescriptorSetLayout(descriptorSetLayout.get());
		pipeline.setSpecializationInfo(specInfo);
		pipeline.setSubgroupSize(m_data.subgroupSizeMode == SUBGROUP_SIZE_NONE ? 0 : getSubgroupSizeFromMode(m_context, m_data.subgroupSizeMode));
		pipeline.buildPipeline();

		const VkQueue					queue					= m_context.getUniversalQueue();
		Move<VkCommandPool>				cmdPool					= createCommandPool(vk, device, 0, m_context.getUniversalQueueFamilyIndex());
		Move<VkCommandBuffer>			cmdBuffer				= allocateCommandBuffer(vk, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);

		beginCommandBuffer(vk, *cmdBuffer, 0u);

		vk.cmdBindDescriptorSets(*cmdBuffer, bindPoint, pipeline.getPipelineLayout(), 0u, 1, &*descriptorSet, 0u, DE_NULL);
		pipeline.bind(*cmdBuffer);

		vk.cmdDispatch(*cmdBuffer, m_data.workgroupsX, m_data.workgroupsY, 1);

		endCommandBuffer(vk, *cmdBuffer);

		submitCommandsAndWait(vk, device, queue, cmdBuffer.get());

		invalidateAlloc(vk, device, buffers[3]->getAllocation());

		qpTestResult res = QP_TEST_RESULT_PASS;

1340 		if (m_data.testType == TT_CONVERT)
1341 		{
1342 			for (deUint32 i = 0; i < totalElements[3]; ++i)
1343 			{
1344 				// Store results as double, which has enough range to hold all the other types exactly.
1345 				double inputA, output;
1346 
1347 				// This loads the data according to dataTypes[0], and then converts to the template parameter type
1348 				switch (dataTypes[3]) {
1349 				case VK_COMPONENT_TYPE_UINT8_KHR:	inputA = getDataConvertedToT<uint8_t>(ptrs[0], dataTypes[0], i); break;
1350 				case VK_COMPONENT_TYPE_UINT16_KHR:	inputA = getDataConvertedToT<uint16_t>(ptrs[0], dataTypes[0], i); break;
1351 				case VK_COMPONENT_TYPE_UINT32_KHR:	inputA = getDataConvertedToT<uint32_t>(ptrs[0], dataTypes[0], i); break;
1352 				case VK_COMPONENT_TYPE_SINT8_KHR:	inputA = getDataConvertedToT<int8_t>(ptrs[0], dataTypes[0], i); break;
1353 				case VK_COMPONENT_TYPE_SINT16_KHR:	inputA = getDataConvertedToT<int16_t>(ptrs[0], dataTypes[0], i); break;
1354 				case VK_COMPONENT_TYPE_SINT32_KHR:	inputA = getDataConvertedToT<int32_t>(ptrs[0], dataTypes[0], i); break;
1355 				case VK_COMPONENT_TYPE_FLOAT32_KHR: inputA = getDataConvertedToT<float>(ptrs[0], dataTypes[0], i); break;
1356 				case VK_COMPONENT_TYPE_FLOAT16_KHR:
1357 				{
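					// Round-trip the reference through fp16 so it matches the precision
					// the implementation can actually store.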
					float temp = getDataConvertedToT<float>(ptrs[0], dataTypes[0], i);
					inputA = deFloat16To32(deFloat32To16(temp));
					break;
				}
				default: TCU_THROW(InternalError, "Unexpected type");
				}

				switch (dataTypes[3]) {
				case VK_COMPONENT_TYPE_UINT8_KHR:	output = getDataConvertedToT<uint8_t>(ptrs[3], dataTypes[3], i); break;
				case VK_COMPONENT_TYPE_UINT16_KHR:	output = getDataConvertedToT<uint16_t>(ptrs[3], dataTypes[3], i); break;
				case VK_COMPONENT_TYPE_UINT32_KHR:	output = getDataConvertedToT<uint32_t>(ptrs[3], dataTypes[3], i); break;
				case VK_COMPONENT_TYPE_SINT8_KHR:	output = getDataConvertedToT<int8_t>(ptrs[3], dataTypes[3], i); break;
				case VK_COMPONENT_TYPE_SINT16_KHR:	output = getDataConvertedToT<int16_t>(ptrs[3], dataTypes[3], i); break;
				case VK_COMPONENT_TYPE_SINT32_KHR:	output = getDataConvertedToT<int32_t>(ptrs[3], dataTypes[3], i); break;
				case VK_COMPONENT_TYPE_FLOAT32_KHR: output = getDataConvertedToT<float>(ptrs[3], dataTypes[3], i); break;
				case VK_COMPONENT_TYPE_FLOAT16_KHR:
				{
					float temp = getDataConvertedToT<float>(ptrs[3], dataTypes[3], i);
					output = deFloat16To32(deFloat32To16(temp));
					break;
				}
				default: TCU_THROW(InternalError, "Unexpected type");
				}

				if (inputA != output) {
					res = QP_TEST_RESULT_FAIL;
					break;
				}
			}
		}
		else if (isFloatType(dataTypes[0]))
		{
			if (!isMatrixMulAddOp(m_data.testType))
			{
				for (deUint32 i = 0; i < totalElements[3]; ++i)
				{
					float inputA = getDataFloat(ptrs[0], dataTypes[0], i);
					float inputB = getDataFloat(ptrs[1], dataTypes[1], i);
					float output = getDataFloat(ptrs[3], dataTypes[3], i);
					switch (m_data.testType)
					{
					case TT_LENGTH:
						if (output < 1.0f || output > (float)(N*M))
							res = QP_TEST_RESULT_FAIL;
						// We expect the matrix to be spread evenly across invocations; it is
						// surprising (but not necessarily illegal) if it is not.
						if (output != (float)(N*M/subgroupSize) &&
							res == QP_TEST_RESULT_PASS)
							res = QP_TEST_RESULT_QUALITY_WARNING;
						break;
					case TT_CONSTANT:
						if (output != 1.0f)
							res = QP_TEST_RESULT_FAIL;
						break;
					case TT_COMPOSITE:
					case TT_COMPOSITE_RVALUE:
					case TT_COMPOSITE_ARRAY:
					case TT_ADD:
						if (output != inputA + inputB)
							res = QP_TEST_RESULT_FAIL;
						break;
					case TT_SUB:
						if (output != inputA - inputB)
							res = QP_TEST_RESULT_FAIL;
						break;
					case TT_DIV:
						{
							float ulp = (m_data.inputType == VK_COMPONENT_TYPE_FLOAT16_KHR) ? 1.0f/1024.0f : 1.0f/(8.0f*1024.0f*1024.0f);
							// Division is allowed 2.5 ULP of error, but we allow 3 to be safe.
							ulp *= 3;
							if (inputB != 0 && fabs(output - inputA / inputB) > ulp * fabs(inputA / inputB))
								res = QP_TEST_RESULT_FAIL;
						}
						break;
					case TT_MUL:
					{
						if (dataTypes[0] == VK_COMPONENT_TYPE_FLOAT16_KHR)
						{
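							// For fp16 results, round the fp32 reference product to fp16
							// so the comparison matches the device's storage precision.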
							const float		expected32	= inputA * inputB;
							const deFloat16	expected16	= deFloat32To16(expected32);
							const float		expected	= deFloat16To32(expected16);

							if (output != expected)
								res = QP_TEST_RESULT_FAIL;
						}
						else
						{
							if (output != inputA * inputB)
								res = QP_TEST_RESULT_FAIL;
						}
						break;
					}
					case TT_NEGATE:
					case TT_FUNC:
						if (output != -inputA)
							res = QP_TEST_RESULT_FAIL;
						break;
					case TT_MATRIXTIMESSCALAR:
						if (output != 6.0f*inputA)
							res = QP_TEST_RESULT_FAIL;
						break;
					default:
						break;
					}
				}
			}
			else
			{
				deUint32 ik, kj, ij;
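				// Reference: D[i][j] = sum_k A[i][k]*B[k][j] + C[i][j], evaluated per
				// (mX, mY) subgroup tile. For row-major data, element (i, k) of A's tile
				// sits at mX*K + k + strides[0]*(mY*M) + loadStrides[0]*i; the colMajor
				// branches swap the roles of rows and columns, and loadStrides[] can be
				// zero for the stride==0 variants.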
				for (deUint32 mX = 0; mX < m_data.subgroupsPerWorkgroupX*m_data.workgroupsX; ++mX)
				{
					for (deUint32 mY = 0; mY < m_data.subgroupsPerWorkgroupY*m_data.workgroupsY; ++mY)
					{
						for (deUint32 i = 0; i < M; ++i)
						{
							for (deUint32 j = 0; j < N; ++j)
							{
								float ref = 0;
								for (deUint32 k = 0; k < K; ++k)
								{
									if (m_data.colMajor)
										ik = mX * M + i + strides[0] * mY * K + loadStrides[0] * k;
									else
										ik = mX * K + k + strides[0] * mY * M + loadStrides[0] * i;

									float Aik = getDataFloat(ptrs[0], dataTypes[0], ik);

									if (m_data.colMajor)
										kj = mX * K + k + strides[1] * mY * N + loadStrides[1] * j;
									else
										kj = mX * N + j + strides[1] * mY * K + loadStrides[1] * k;

									float Bkj = getDataFloat(ptrs[1], dataTypes[1], kj);

									ref += Aik*Bkj;
								}

								if (m_data.colMajor)
									ij = mX * M + i + strides[2] * mY * N + loadStrides[2] * j;
								else
									ij = mX * N + j + strides[2] * mY * M + loadStrides[2] * i;

								float Cij = getDataFloat(ptrs[2], dataTypes[2], ij);

								ref += Cij;

								// When loading with stride 0, the index ij for matrix D differs from the one used for matrix C.
								if (m_data.colMajor)
									ij = mX * M + i + strides[2] * (mY * N + j);
								else
									ij = mX * N + j + strides[2] * (mY * M + i);

								float Dij = getDataFloat(ptrs[3], dataTypes[3], ij);

								if (fabs(ref - Dij) > epsilon)
								{
									res = QP_TEST_RESULT_FAIL;
								}
							}
						}
					}
				}
			}
		}
		else
		{
			if (!isMatrixMulAddOp(m_data.testType))
			{
				for (deUint32 i = 0; i < totalElements[3]; ++i)
				{
					deUint32 inputA = getDataInt(ptrs[0], dataTypes[0], i);
					deUint32 inputB = getDataInt(ptrs[1], dataTypes[1], i);
					deUint32 output = getDataInt(ptrs[3], dataTypes[3], i);
					int resultSize = componentTypeInfo[dataTypes[3]].bits;
					deUint32 mask = (resultSize == 32) ? ~0u : ((1u << resultSize) - 1u);
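					// Integer results are compared modulo 2^resultSize: the reference is
					// computed at 32-bit width and both sides are truncated to the width
					// of the result component type.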
					switch (m_data.testType)
					{
					case TT_LENGTH:
						if (output < 1 || output > N*M)
							res = QP_TEST_RESULT_FAIL;
						// We expect the matrix to be spread evenly across invocations; it is
						// surprising (but not necessarily illegal) if it is not.
						if (output != N*M/subgroupSize &&
							res == QP_TEST_RESULT_PASS)
							res = QP_TEST_RESULT_QUALITY_WARNING;
						break;
					case TT_CONSTANT:
						if (output != 1)
							res = QP_TEST_RESULT_FAIL;
						break;
					case TT_COMPOSITE:
					case TT_COMPOSITE_RVALUE:
					case TT_COMPOSITE_ARRAY:
					case TT_ADD:
						if ((output & mask) != ((inputA + inputB) & mask)) {
							res = QP_TEST_RESULT_FAIL;
						}
						break;
					case TT_SUB:
						if ((output & mask) != ((inputA - inputB) & mask))
							res = QP_TEST_RESULT_FAIL;
						break;
					case TT_DIV:
						{
							if (isSIntType(dataTypes[3]))
							{
								if (inputB != 0 && ((deInt32)output & mask) != (((deInt32)inputA / (deInt32)inputB) & mask))
									res = QP_TEST_RESULT_FAIL;
							}
							else
							{
								if (inputB != 0 && output != inputA / inputB)
									res = QP_TEST_RESULT_FAIL;
							}
						}
						break;
					case TT_MUL:
					{
						if (((deInt32)output & mask) != (((deInt32)inputA * (deInt32)inputB) & mask))
						{
							res = QP_TEST_RESULT_FAIL;
						}

						break;
					}
					case TT_NEGATE:
					case TT_FUNC:
						if ((output & mask) != ((-(deInt32)inputA) & mask))
							res = QP_TEST_RESULT_FAIL;
						break;
					case TT_MATRIXTIMESSCALAR:
						if ((output & mask) != ((6*inputA) & mask)) {
							res = QP_TEST_RESULT_FAIL;
						}
						break;
					default:
						break;
					}
				}
			}
			else
			{
				deUint32 ik, kj, ij;
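				// Same per-tile indexing as the floating-point path above; only the
				// accumulation differs (saturating vs. wrapping addition).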
				for (deUint32 mX = 0; mX < m_data.subgroupsPerWorkgroupX*m_data.workgroupsX; ++mX)
				{
					for (deUint32 mY = 0; mY < m_data.subgroupsPerWorkgroupY*m_data.workgroupsY; ++mY)
					{
						for (deUint32 i = 0; i < M; ++i)
						{
							for (deUint32 j = 0; j < N; ++j)
							{
								deUint32 ref = 0;

								for (deUint32 k = 0; k < K; ++k)
								{
									if (m_data.colMajor)
										ik = mX * M + i + strides[0] * mY * K + loadStrides[0] * k;
									else
										ik = mX * K + k + strides[0] * mY * M + loadStrides[0] * i;

									deUint32 Aik = getDataInt(ptrs[0], dataTypes[0], ik);

									if (m_data.colMajor)
										kj = mX * K + k + strides[1] * mY * N + loadStrides[1] * j;
									else
										kj = mX * N + j + strides[1] * mY * K + loadStrides[1] * k;

									deUint32 Bkj = getDataInt(ptrs[1], dataTypes[1], kj);

									ref += Aik*Bkj;
								}

								if (m_data.colMajor)
									ij = mX * M + i + strides[2] * mY * N + loadStrides[2] * j;
								else
									ij = mX * N + j + strides[2] * mY * M + loadStrides[2] * i;

								deUint32 Cij = getDataInt(ptrs[2], dataTypes[2], ij);

								if (saturated)
								{
									ref = satAddData(dataTypes[2], ref, Cij);
								}
								else
								{
									ref += Cij;
									// Truncate the result to the width of the result type
									// (C and D share the output component type).
									deUint32 bits = componentTypeInfo[dataTypes[3]].bits;
									deUint32 mask = (bits == 32) ? 0xFFFFFFFFU : ((1U << bits) - 1U);
									ref &= mask;
								}

								// When loading with stride 0, the index ij for matrix D differs from the one used for matrix C.
								if (m_data.colMajor)
									ij = mX * M + i + strides[2] * (mY * N + j);
								else
									ij = mX * N + j + strides[2] * (mY * M + i);

								deUint32 Dij = getDataInt(ptrs[3], dataTypes[3], ij);

								if (ref != Dij)
								{
									res = QP_TEST_RESULT_FAIL;
								}
							}
						}
					}
				}
			}
		}

		if (res != QP_TEST_RESULT_PASS)
		{
			finalres = res;

			log << tcu::TestLog::Message << "failed with M = " << M << ", N = " << N << ", K = " << K << tcu::TestLog::EndMessage;

#ifdef COOPERATIVE_MATRIX_EXTENDED_DEBUG
			for (int i = 0; i < 4; i++)
			{
				const char* matrixNames[] = { "A", "B", "C", "D" };

				log << tcu::TestLog::Message
					<< "Matrix " << matrixNames[i]
					<< "[rows="
					<< m_data.subgroupsPerWorkgroupY * m_data.workgroupsY * dims[i].rows
					<< ", cols="
					<< m_data.subgroupsPerWorkgroupX * m_data.workgroupsX * dims[i].cols << "]:\n"
					<< dumpWholeMatrix(ptrs[i], dataTypes[i], m_data.colMajor, totalElements[i], strides[i])
					<< tcu::TestLog::EndMessage;
			}
#endif
		}
		else
		{
			if (finalres == QP_TEST_RESULT_NOT_SUPPORTED)
				finalres = res;
		}
	}

	return tcu::TestStatus(finalres, qpGetTestResultName(finalres));
}

const char* getUseType (UseType useType)
{
	switch (useType)
	{
		case UT_NV:			return "nv";
		case UT_KHR_A:		return "khr_a";
		case UT_KHR_B:		return "khr_b";
		case UT_KHR_Result:	return "khr_r";
		default:			TCU_THROW(InternalError, "Unknown use type");
	}
}

tcu::TestCaseGroup*	createCooperativeMatrixTestsInternal (tcu::TestContext& testCtx, vk::ComputePipelineConstructionType computePipelineConstructionType, UseType useType)
{
	de::MovePtr<tcu::TestCaseGroup> group	(new tcu::TestCaseGroup(testCtx, getUseType(useType)));

	typedef struct
	{
		deUint32				value;
		const char*				name;
	} TestGroupCase;

	typedef struct
	{
		deUint32				value[2];
		const char*				name;
	} TestGroupCase2;

	typedef struct
	{
		SubgroupSizeMode		value;
		const char*				name;
	} SubGroupSizes;

	TestGroupCase ttCases[] =
	{
		// OpCooperativeMatrixLength
		{ TT_LENGTH,				"length"},
		// OpConstantComposite
		{ TT_CONSTANT,				"constant"},
		// OpCompositeConstruct
		{ TT_COMPOSITE,				"composite"},
		// OpCompositeExtract
		{ TT_COMPOSITE_RVALUE,		"composite_rvalue"},
		// OpFAdd/OpIAdd
		{ TT_ADD,					"add"},
		// OpFSub/OpISub
		{ TT_SUB,					"sub"},
		// OpFDiv/OpSDiv/OpUDiv
		{ TT_DIV,					"div"},
		// OpFMul/OpIMul
		{ TT_MUL,					"mul"},
		// OpFNegate/OpSNegate
		{ TT_NEGATE,				"negate"},
		// OpMatrixTimesScalar
		{ TT_MATRIXTIMESSCALAR,		"matrixtimesscalar"},
		// OpFunctionParameter
		{ TT_FUNC,					"func"},
		// OpCooperativeMatrixMulAdd
		{ TT_MATRIXMULADD,			"matrixmuladd"},
		// OpCompositeConstruct w/array
		{ TT_COMPOSITE_ARRAY,		"composite_array"},
		// OpCooperativeMatrixMulAdd w/array
		{ TT_MATRIXMULADD_ARRAY,	"matrixmuladd_array"},
		// OpCooperativeMatrixMulAdd w/saturation
		{ TT_MATRIXMULADD_SATURATED,"matrixmuladd_saturated"},
		// OpCooperativeMatrixMulAdd w/wrapping
		{ TT_MATRIXMULADD_WRAPPING,	"matrixmuladd_wrapping"},
		// OpCooperativeMatrixMulAdd w/stride==0
		{ TT_MATRIXMULADD_STRIDE0,	"matrixmuladd_stride0"},
	};
	TestGroupCase2 dtCases[] =
	{
		// A/B are fp32, C/D are fp32
		{ { VK_COMPONENT_TYPE_FLOAT32_KHR,	VK_COMPONENT_TYPE_FLOAT32_KHR },	"float32_float32"},
		// A/B are fp32, C/D are fp16
		{ { VK_COMPONENT_TYPE_FLOAT32_KHR,	VK_COMPONENT_TYPE_FLOAT16_KHR },	"float32_float16"},
		// A/B are fp16, C/D are fp32
		{ { VK_COMPONENT_TYPE_FLOAT16_KHR,	VK_COMPONENT_TYPE_FLOAT32_KHR },	"float16_float32"},
		// A/B are fp16, C/D are fp16
		{ { VK_COMPONENT_TYPE_FLOAT16_KHR,	VK_COMPONENT_TYPE_FLOAT16_KHR },	"float16_float16"},
		// A/B are u8, C/D are u8
		{ { VK_COMPONENT_TYPE_UINT8_KHR,	VK_COMPONENT_TYPE_UINT8_KHR },		"uint8_uint8"},
		// A/B are u8, C/D are u32
		{ { VK_COMPONENT_TYPE_UINT8_KHR,	VK_COMPONENT_TYPE_UINT32_KHR },		"uint8_uint32"},
		// A/B are s8, C/D are s8
		{ { VK_COMPONENT_TYPE_SINT8_KHR,	VK_COMPONENT_TYPE_SINT8_KHR },		"sint8_sint8"},
		// A/B are s8, C/D are s32
		{ { VK_COMPONENT_TYPE_SINT8_KHR,	VK_COMPONENT_TYPE_SINT32_KHR },		"sint8_sint32"},
		// A/B are u8, C/D are s32
		{ { VK_COMPONENT_TYPE_UINT8_KHR,	VK_COMPONENT_TYPE_SINT32_KHR },		"uint8_sint32"},
		// A/B are u32, C/D are u32
		{ { VK_COMPONENT_TYPE_UINT32_KHR,	VK_COMPONENT_TYPE_UINT32_KHR },		"uint32_uint32"},
		// A/B are u32, C/D are u8
		{ { VK_COMPONENT_TYPE_UINT32_KHR,	VK_COMPONENT_TYPE_UINT8_KHR },		"uint32_uint8"},
		// A/B are s32, C/D are s32
		{ { VK_COMPONENT_TYPE_SINT32_KHR,	VK_COMPONENT_TYPE_SINT32_KHR },		"sint32_sint32"},
		// A/B are s32, C/D are s8
		{ { VK_COMPONENT_TYPE_SINT32_KHR,	VK_COMPONENT_TYPE_SINT8_KHR },		"sint32_sint8"},
	};
	SubGroupSizes sgsCases[] =
	{
		// Default subgroup size
		{ SUBGROUP_SIZE_NONE,	"" },
		// Minimum subgroup size
		{ SUBGROUP_SIZE_MIN,	"_min"},
		// Maximum subgroup size
		{ SUBGROUP_SIZE_MAX,	"_max"},
	};

	TestGroupCase colCases[] =
	{
		{ 0,		"rowmajor"},
		{ 1,		"colmajor"},
	};

	TestGroupCase scCases[] =
	{
		// SSBO
		{ SC_BUFFER,						"buffer"},
		// shared memory
		{ SC_WORKGROUP,						"workgroup"},
		// SSBO w/variable pointers
		{ SC_BUFFER_VARIABLE_POINTERS,		"buffer_varptr"},
		// shared memory w/variable pointers
		{ SC_WORKGROUP_VARIABLE_POINTERS,	"workgroup_varptr"},
		// physical_storage_buffer
		{ SC_PHYSICAL_STORAGE_BUFFER,		"physical_buffer"},
	};

	// Types tested for conversions. Excludes 64b types.
	VkComponentTypeKHR allTypes[] =
	{
		VK_COMPONENT_TYPE_FLOAT16_KHR,
		VK_COMPONENT_TYPE_FLOAT32_KHR,
		VK_COMPONENT_TYPE_SINT8_KHR,
		VK_COMPONENT_TYPE_SINT16_KHR,
		VK_COMPONENT_TYPE_SINT32_KHR,
		VK_COMPONENT_TYPE_UINT8_KHR,
		VK_COMPONENT_TYPE_UINT16_KHR,
		VK_COMPONENT_TYPE_UINT32_KHR,
	};

	for (int ttNdx = 0; ttNdx < DE_LENGTH_OF_ARRAY(ttCases); ttNdx++)
	{
		const TestType	testType = (TestType)ttCases[ttNdx].value;

		for (int sgsNdx = 0; sgsNdx < DE_LENGTH_OF_ARRAY(sgsCases); sgsNdx++)
		{
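			// Min/max subgroup-size variants are only generated for matrixmuladd, and
			// not for the NV extension path.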
			if (testType != TT_MATRIXMULADD && sgsCases[sgsNdx].value != SUBGROUP_SIZE_NONE)
				continue;

			if (testType == TT_MATRIXMULADD && sgsCases[sgsNdx].value != SUBGROUP_SIZE_NONE && useType == UT_NV)
				continue;

			const string					name	= string(ttCases[ttNdx].name) + sgsCases[sgsNdx].name;
			de::MovePtr<tcu::TestCaseGroup>	ttGroup	(new tcu::TestCaseGroup(testCtx, name.c_str()));

			for (int dtNdx = 0; dtNdx < DE_LENGTH_OF_ARRAY(dtCases); dtNdx++)
			{
				de::MovePtr<tcu::TestCaseGroup> dtGroup(new tcu::TestCaseGroup(testCtx, dtCases[dtNdx].name));
				for (int scNdx = 0; scNdx < DE_LENGTH_OF_ARRAY(scCases); scNdx++)
				{
					de::MovePtr<tcu::TestCaseGroup> scGroup(new tcu::TestCaseGroup(testCtx, scCases[scNdx].name));
					for (int colNdx = 0; colNdx < DE_LENGTH_OF_ARRAY(colCases); colNdx++)
					{
						const VkComponentTypeKHR	inputType = (VkComponentTypeKHR)dtCases[dtNdx].value[0];
						const VkComponentTypeKHR	outputType = (VkComponentTypeKHR)dtCases[dtNdx].value[1];
						const bool					isMatrixMul = isMatrixMulAddOp(testType);

						// useType isn't used for matrixmul shaders, so don't generate three copies of those tests.
						if (isMatrixMul && (useType == UT_KHR_A || useType == UT_KHR_B)) {
							continue;
						}

						// The NV extension doesn't support mixing signedness.
						if (isMatrixMul && (useType == UT_NV) && isSIntType(inputType) != isSIntType(outputType)) {
							continue;
						}

						if (!isMatrixMul && inputType != outputType)
							continue;

						if (isMatrixMul && componentTypeInfo[inputType].bits > componentTypeInfo[outputType].bits)
							continue;

						if (testType == TT_MUL && useType == UT_NV)
							continue;

						if (testType == TT_MATRIXMULADD_SATURATED && (isFloatType(inputType) || useType == UT_NV))
							continue;

						if (testType == TT_MATRIXMULADD_WRAPPING && (isFloatType(inputType) || useType == UT_NV))
							continue;

						if (testType == TT_MATRIXMULADD_STRIDE0 && useType == UT_NV)
							continue;

						if (testType == TT_LENGTH && useType != UT_NV && (outputType == VK_COMPONENT_TYPE_SINT8_KHR || outputType == VK_COMPONENT_TYPE_UINT8_KHR))
							continue;

						CaseDef c =
						{
							testType,							//  TestType							testtype;
							2u,									//  deUint32							subgroupsPerWorkgroupX;
							2u,									//  deUint32							subgroupsPerWorkgroupY;
							4u,									//  deUint32							workgroupsX;
							4u,									//  deUint32							workgroupsY;
							inputType,							//  VkComponentTypeKHR					inputType;
							outputType,							//  VkComponentTypeKHR					outputType;
							!!colCases[colNdx].value,			//  bool								colMajor;
							(StorageClass)scCases[scNdx].value,	//  StorageClass						storageClass;
							useType,							//  UseType								useType;
							sgsCases[sgsNdx].value,				//  SubgroupSizeMode					subgroupSizeMode;
							computePipelineConstructionType,	//  vk::ComputePipelineConstructionType	computePipelineConstructionType;
						};

						scGroup->addChild(new CooperativeMatrixTestCase(testCtx, colCases[colNdx].name, c));
					}
					dtGroup->addChild(scGroup.release());
				}
				ttGroup->addChild(dtGroup.release());
			}
			group->addChild(ttGroup.release());
		}
	}

	{
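		// The conversion tests pair every (non-64-bit) component type with every
		// other, including same-type conversions.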
		const string					name	= string("convert");
		// Covers OpFConvert/OpSConvert/OpUConvert/OpBitcast.
		de::MovePtr<tcu::TestCaseGroup>	ttGroup	(new tcu::TestCaseGroup(testCtx, name.c_str()));

		for (int dtNdx1 = 0; dtNdx1 < DE_LENGTH_OF_ARRAY(allTypes); dtNdx1++)
		{
			for (int dtNdx2 = 0; dtNdx2 < DE_LENGTH_OF_ARRAY(allTypes); dtNdx2++)
			{
				const VkComponentTypeKHR	inputType = (VkComponentTypeKHR)allTypes[dtNdx1];
				const VkComponentTypeKHR	outputType = (VkComponentTypeKHR)allTypes[dtNdx2];
				const string			name2	= string("input_") + string(componentTypeInfo[inputType].typeName) + string("_output_") + string(componentTypeInfo[outputType].typeName);
				de::MovePtr<tcu::TestCaseGroup> dtGroup(new tcu::TestCaseGroup(testCtx, name2.c_str()));
				for (int scNdx = 0; scNdx < DE_LENGTH_OF_ARRAY(scCases); scNdx++)
				{
					de::MovePtr<tcu::TestCaseGroup> scGroup(new tcu::TestCaseGroup(testCtx, scCases[scNdx].name));
					for (int colNdx = 0; colNdx < DE_LENGTH_OF_ARRAY(colCases); colNdx++)
					{
						CaseDef c =
						{
							TT_CONVERT,							//  TestType							testtype;
							2u,									//  deUint32							subgroupsPerWorkgroupX;
							2u,									//  deUint32							subgroupsPerWorkgroupY;
							4u,									//  deUint32							workgroupsX;
							4u,									//  deUint32							workgroupsY;
							inputType,							//  VkComponentTypeKHR					inputType;
							outputType,							//  VkComponentTypeKHR					outputType;
							!!colCases[colNdx].value,			//  bool								colMajor;
							(StorageClass)scCases[scNdx].value,	//  StorageClass						storageClass;
							useType,							//  UseType								useType;
							SUBGROUP_SIZE_NONE,					//  SubgroupSizeMode					subgroupSizeMode;
							computePipelineConstructionType,	//  vk::ComputePipelineConstructionType	computePipelineConstructionType;
						};

						scGroup->addChild(new CooperativeMatrixTestCase(testCtx, colCases[colNdx].name, c));
					}
					dtGroup->addChild(scGroup.release());
				}
				ttGroup->addChild(dtGroup.release());
			}
		}
		group->addChild(ttGroup.release());
	}

	return group.release();
}

}	// anonymous

tcu::TestCaseGroup* createCooperativeMatrixTests (tcu::TestContext& testCtx, vk::ComputePipelineConstructionType computePipelineConstructionType)
{
	de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "cooperative_matrix"));

	group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_NV));
	group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_KHR_A));
	group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_KHR_B));
	group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_KHR_Result));

	return group.release();
}
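
// A sketch of how a caller could register this group from a parent test module
// ("parentGroup" is a hypothetical caller-side variable):
//
//   parentGroup->addChild(createCooperativeMatrixTests(testCtx, pipelineConstructionType));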

}	// compute
}	// vkt