/*------------------------------------------------------------------------
 * Vulkan Conformance Tests
 * ------------------------
 *
 * Copyright (c) 2019 The Khronos Group Inc.
 * Copyright (c) 2018-2019 NVIDIA Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *	  http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 *//*!
 * \file
 * \brief Vulkan Cooperative Matrix tests
 *//*--------------------------------------------------------------------*/

#include "vktComputeCooperativeMatrixTests.hpp"

#include "vkBufferWithMemory.hpp"
#include "vkImageWithMemory.hpp"
#include "vkQueryUtil.hpp"
#include "vkBuilderUtil.hpp"
#include "vkCmdUtil.hpp"
#include "vkTypeUtil.hpp"
#include "vkObjUtil.hpp"

#include "vktTestGroupUtil.hpp"
#include "vktTestCase.hpp"

#include "deDefs.h"
#include "deFloat16.h"
#include "deMath.h"
#include "deRandom.h"
#include "deSharedPtr.hpp"
#include "deString.h"

#include "tcuTestCase.hpp"
#include "tcuTestLog.hpp"

#include <string>
#include <sstream>
#include <set>
#include <algorithm>

namespace vkt
{
namespace compute
{
namespace
{
using namespace vk;
using namespace std;

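// Shader operation exercised by a test case, mapping to the SPIR-V
// instructions listed per-case in createCooperativeMatrixTests() below
// (e.g. TT_LENGTH -> OpCooperativeMatrixLengthNV, TT_MATRIXMULADD ->
// OpCooperativeMatrixMulAddNV).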
typedef enum
{
	TT_LENGTH = 0,
	TT_CONSTANT,
	TT_CONVERT,
	TT_COMPOSITE,
	TT_COMPOSITE_RVALUE,
	TT_ADD,
	TT_SUB,
	TT_DIV,
	TT_NEGATE,
	TT_MATRIXTIMESSCALAR,
	TT_FUNC,
	TT_MATRIXMULADD,
	TT_COMPOSITE_ARRAY,
	TT_MATRIXMULADD_ARRAY,
} TestType;

typedef enum
{
	SC_BUFFER = 0,
	SC_WORKGROUP,
	SC_WORKGROUP_VARIABLE_POINTERS,
	SC_BUFFER_VARIABLE_POINTERS,
	SC_PHYSICAL_STORAGE_BUFFER,
} StorageClass;

const VkFlags allShaderStages = VK_SHADER_STAGE_COMPUTE_BIT;

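// Parameters for a single test case. Each workgroup holds a
// subgroupsPerWorkgroupX x subgroupsPerWorkgroupY grid of subgroups, each
// subgroup owning one matrix (or one MxNxK multiply), and the dispatch is a
// workgroupsX x workgroupsY grid of such workgroups.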
struct CaseDef
{
	TestType testType;
	deUint32 subgroupsPerWorkgroupX;
	deUint32 subgroupsPerWorkgroupY;
	deUint32 workgroupsX;
	deUint32 workgroupsY;
	VkComponentTypeNV inputType;
	VkComponentTypeNV outputType;
	bool colMajor;
	StorageClass storageClass;
};

class CooperativeMatrixTestInstance : public TestInstance
{
public:
						CooperativeMatrixTestInstance	(Context& context, const CaseDef& data);
						~CooperativeMatrixTestInstance	(void);
	tcu::TestStatus		iterate				(void);
private:
	CaseDef			m_data;
};

CooperativeMatrixTestInstance::CooperativeMatrixTestInstance (Context& context, const CaseDef& data)
	: vkt::TestInstance		(context)
	, m_data				(data)
{
}

CooperativeMatrixTestInstance::~CooperativeMatrixTestInstance (void)
{
}

class CooperativeMatrixTestCase : public TestCase
{
public:
								CooperativeMatrixTestCase		(tcu::TestContext& context, const char* name, const char* desc, const CaseDef data);
								~CooperativeMatrixTestCase	(void);
	virtual	void				initPrograms		(SourceCollections& programCollection) const;
	virtual TestInstance*		createInstance		(Context& context) const;
	virtual void				checkSupport		(Context& context) const;

private:
	CaseDef					m_data;
};

CooperativeMatrixTestCase::CooperativeMatrixTestCase (tcu::TestContext& context, const char* name, const char* desc, const CaseDef data)
	: vkt::TestCase	(context, name, desc)
	, m_data		(data)
{
}

CooperativeMatrixTestCase::~CooperativeMatrixTestCase	(void)
{
}

void CooperativeMatrixTestCase::checkSupport(Context& context) const
{
	if (!context.contextSupports(vk::ApiVersion(1, 1, 0)))
	{
		TCU_THROW(NotSupportedError, "Vulkan 1.1 not supported");
	}

	if (!context.getCooperativeMatrixFeatures().cooperativeMatrix)
	{
		TCU_THROW(NotSupportedError, "cooperativeMatrix not supported");
	}

	if (!context.getVulkanMemoryModelFeatures().vulkanMemoryModel)
	{
		TCU_THROW(NotSupportedError, "vulkanMemoryModel not supported");
	}

	if ((m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS || m_data.storageClass == SC_BUFFER_VARIABLE_POINTERS) &&
		!context.getVariablePointersFeatures().variablePointers)
	{
		TCU_THROW(NotSupportedError, "variable pointers not supported");
	}

	if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER && !context.isBufferDeviceAddressSupported())
	{
		TCU_THROW(NotSupportedError, "buffer device address not supported");
	}

	if (!context.getShaderFloat16Int8Features().shaderFloat16 &&
		(m_data.inputType == VK_COMPONENT_TYPE_FLOAT16_NV || m_data.outputType == VK_COMPONENT_TYPE_FLOAT16_NV))
	{
		TCU_THROW(NotSupportedError, "shaderFloat16 not supported");
	}

	deUint32 propertyCount = 0;
	VkCooperativeMatrixPropertiesNV *pProperties;
	context.getInstanceInterface().getPhysicalDeviceCooperativeMatrixPropertiesNV(context.getPhysicalDevice(), &propertyCount, DE_NULL);
	if (propertyCount == 0)
		TCU_THROW(NotSupportedError, "cooperative matrices not supported");

	bool supported[2] = { false, false };
	pProperties = new VkCooperativeMatrixPropertiesNV[propertyCount];

	for (deUint32 i = 0; i < propertyCount; ++i)
	{
		VkCooperativeMatrixPropertiesNV *p = &pProperties[i];
		p->sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_NV;
		p->pNext = DE_NULL;
	}

	context.getInstanceInterface().getPhysicalDeviceCooperativeMatrixPropertiesNV(context.getPhysicalDevice(), &propertyCount, pProperties);

	for (deUint32 i = 0; i < propertyCount; ++i)
	{
		VkCooperativeMatrixPropertiesNV *p = &pProperties[i];
		if (m_data.testType == TT_MATRIXMULADD ||
			m_data.testType == TT_MATRIXMULADD_ARRAY)
		{
			if (p->AType == m_data.inputType &&
				p->BType == m_data.inputType &&
				p->CType == m_data.outputType &&
				p->DType == m_data.outputType &&
				p->scope == VK_SCOPE_SUBGROUP_NV)
			{
				supported[0] = supported[1] = true;
			}
		}
		else
		{
			VkComponentTypeNV types[2] = { m_data.inputType, m_data.outputType };

			for (deUint32 j = 0; j < 2; ++j)
			{
				if (p->scope == VK_SCOPE_SUBGROUP_NV && (p->AType == types[j] || p->BType == types[j] || p->CType == types[j] || p->DType == types[j]))
				{
					supported[j] = true;
				}
			}
		}
	}

	delete [] pProperties;

	if (!supported[0] || !supported[1])
		TCU_THROW(NotSupportedError, "cooperative matrix combination not supported");
}
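
// checkSupport() above uses the standard Vulkan two-call enumeration pattern:
// query the count with a NULL pointer, then fill the array. A minimal sketch
// of the same pattern with automatic storage ("vki" and "physDev" stand in
// for context.getInstanceInterface() and context.getPhysicalDevice(); the
// test itself keeps the raw new[]/delete[] array above):
//
//	std::vector<VkCooperativeMatrixPropertiesNV> props(propertyCount);
//	for (auto &p : props)
//	{
//		p.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_NV;
//		p.pNext = DE_NULL;
//	}
//	vki.getPhysicalDeviceCooperativeMatrixPropertiesNV(physDev, &propertyCount, props.data());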

// Indexed by VkComponentTypeNV.
struct {
	const char *typeName;
	const char *coopmatTypeName;
	deUint32 bits;
} componentTypeInfo[] =
{
	{ "float16_t",	"fcoopmatNV",	16 },
	{ "float32_t",	"fcoopmatNV",	32 },
	{ "float64_t",	"fcoopmatNV",	64 },
	{ "int8_t",		"icoopmatNV",	8 },
	{ "int16_t",	"icoopmatNV",	16 },
	{ "int32_t",	"icoopmatNV",	32 },
	{ "int64_t",	"icoopmatNV",	64 },
	{ "uint8_t",	"ucoopmatNV",	8 },
	{ "uint16_t",	"ucoopmatNV",	16 },
	{ "uint32_t",	"ucoopmatNV",	32 },
	{ "uint64_t",	"ucoopmatNV",	64 },
};

static bool isFloatType(VkComponentTypeNV t)
{
	switch (t)
	{
	default:
		return false;
	case VK_COMPONENT_TYPE_FLOAT16_NV:
	case VK_COMPONENT_TYPE_FLOAT32_NV:
	case VK_COMPONENT_TYPE_FLOAT64_NV:
		return true;
	}
}

static bool isSIntType(VkComponentTypeNV t)
{
	switch (t)
	{
	default:
		return false;
	case VK_COMPONENT_TYPE_SINT8_NV:
	case VK_COMPONENT_TYPE_SINT16_NV:
	case VK_COMPONENT_TYPE_SINT32_NV:
	case VK_COMPONENT_TYPE_SINT64_NV:
		return true;
	}
}

void CooperativeMatrixTestCase::initPrograms (SourceCollections& programCollection) const
{
	std::stringstream css;
	css << "#version 450 core\n";
	css << "#pragma use_vulkan_memory_model\n";
	css <<
		"#extension GL_KHR_shader_subgroup_basic : enable\n"
		"#extension GL_KHR_memory_scope_semantics : enable\n"
		"#extension GL_NV_cooperative_matrix : enable\n"
		"#extension GL_NV_integer_cooperative_matrix : enable\n"
		"#extension GL_EXT_shader_explicit_arithmetic_types_float16 : enable\n"
		"#extension GL_EXT_shader_explicit_arithmetic_types_float32 : enable\n"
		"#extension GL_EXT_shader_explicit_arithmetic_types_int8 : enable\n"
		"#extension GL_EXT_shader_explicit_arithmetic_types_int32 : enable\n"
		"#extension GL_EXT_buffer_reference : enable\n"
		"// strides overridden by spec constants\n"
		"layout(constant_id = 2) const int AStride = 1;\n"
		"layout(constant_id = 3) const int BStride = 1;\n"
		"layout(constant_id = 4) const int CStride = 1;\n"
		"layout(constant_id = 5) const int OStride = 1;\n"
		"layout(constant_id = 6) const int M = 1;\n"
		"layout(constant_id = 7) const int N = 1;\n"
		"layout(constant_id = 8) const int K = 1;\n"
		"layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;\n";
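	// Specialization constant IDs 0 and 1 set the workgroup local size;
	// IDs 2..5 are the A/B/C/O strides and IDs 6..8 are M/N/K. They are
	// filled in at pipeline creation from specData[] in iterate(), in the
	// same order.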

	if (m_data.storageClass == SC_BUFFER_VARIABLE_POINTERS || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
		css << "#pragma use_variable_pointers\n";

	struct
	{
		string rows, cols;
	} dims[4];

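	// dims[0..3] hold the shapes of matA, matB, matC, and the output matrix:
	// for the mul-add tests A is MxK, B is KxN, and C/O are MxN; for all
	// other tests the four matrices are MxN.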
	if (m_data.testType == TT_MATRIXMULADD ||
		m_data.testType == TT_MATRIXMULADD_ARRAY)
	{
		dims[0].rows = "M";
		dims[0].cols = "K";
		dims[1].rows = "K";
		dims[1].cols = "N";
		dims[2].rows = "M";
		dims[2].cols = "N";
		dims[3].rows = "M";
		dims[3].cols = "N";
	}
	else
	{
		dims[0].rows = "M";
		dims[0].cols = "N";
		dims[1].rows = "M";
		dims[1].cols = "N";
		dims[2].rows = "M";
		dims[2].cols = "N";
		dims[3].rows = "M";
		dims[3].cols = "N";
	}

	const char *typeStrA = componentTypeInfo[m_data.inputType].typeName;
	const char *typeStrB = componentTypeInfo[m_data.inputType].typeName;
	const char *typeStrC = componentTypeInfo[m_data.outputType].typeName;
	const char *typeStrO = componentTypeInfo[m_data.outputType].typeName;

	css << "const int workgroupsX = " << m_data.workgroupsX << ";\n";
	css << "const uvec2 subgroupsPerWG = uvec2(" << m_data.subgroupsPerWorkgroupX << ", " << m_data.subgroupsPerWorkgroupY << ");\n";

	if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
	{
		css << "layout(buffer_reference) buffer InputA { " << typeStrA << " x[]; };\n";
		css << "layout(buffer_reference) buffer InputB { " << typeStrB << " x[]; };\n";
		css << "layout(buffer_reference) buffer InputC { " << typeStrC << " x[]; };\n";
		css << "layout(buffer_reference) buffer Output { " << typeStrO << " x[]; };\n";
		css << "layout(set=0, binding=4) buffer Params { InputA inputA; InputB inputB; InputC inputC; Output outputO; } params;\n";
	}
	else
	{
		css << "layout(set=0, binding=0) coherent buffer InputA { " << typeStrA << " x[]; } inputA;\n";
		css << "layout(set=0, binding=1) coherent buffer InputB { " << typeStrB << " x[]; } inputB;\n";
		css << "layout(set=0, binding=2) coherent buffer InputC { " << typeStrC << " x[]; } inputC;\n";
		css << "layout(set=0, binding=3) coherent buffer Output { " << typeStrO << " x[]; } outputO;\n";
	}

	if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
	{
		css << "shared " << typeStrA << " sharedA[" << dims[0].rows << " * " << dims[0].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
		css << "shared " << typeStrB << " sharedB[" << dims[1].rows << " * " << dims[1].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
		css << "shared " << typeStrC << " sharedC[" << dims[2].rows << " * " << dims[2].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
		css << "shared " << typeStrO << " sharedO[" << dims[3].rows << " * " << dims[3].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
	}

	std::stringstream matAType, matBType, matCType, outputMatType;

	matAType  << componentTypeInfo[m_data.inputType].coopmatTypeName << "<" << componentTypeInfo[m_data.inputType].bits << ", gl_ScopeSubgroup, " << dims[0].rows << ", " << dims[0].cols << ">";
	matBType  << componentTypeInfo[m_data.inputType].coopmatTypeName << "<" << componentTypeInfo[m_data.inputType].bits << ", gl_ScopeSubgroup, " << dims[1].rows << ", " << dims[1].cols << ">";
	matCType  << componentTypeInfo[m_data.outputType].coopmatTypeName << "<" << componentTypeInfo[m_data.outputType].bits << ", gl_ScopeSubgroup, " << dims[2].rows << ", " << dims[2].cols << ">";
	outputMatType << componentTypeInfo[m_data.outputType].coopmatTypeName << "<" << componentTypeInfo[m_data.outputType].bits << ", gl_ScopeSubgroup, " << dims[3].rows << ", " << dims[3].cols << ">";
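	// For illustration: with inputType == VK_COMPONENT_TYPE_FLOAT16_NV in a
	// mul-add test, matAType.str() comes out as
	// "fcoopmatNV<16, gl_ScopeSubgroup, M, K>", i.e. a GL_NV_cooperative_matrix
	// type parameterized by the spec-constant dimensions declared above.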

	css << matAType.str() << " matA;\n";
	css << matBType.str() << " matB;\n";
	css << matCType.str() << " matC;\n";
	css << outputMatType.str() << " matO;\n";

	if (m_data.testType == TT_CONSTANT)
		css << "const " << outputMatType.str() << " matConst = " << outputMatType.str() << "(1.0);\n";

	if (m_data.testType == TT_FUNC)
		css << matAType.str() << " f(" << matAType.str() << " m) { return -m; }\n";

	css <<
		"void main()\n"
		"{\n"
		// matrixID is the x,y index of the matrix owned by this subgroup.
		"   uvec2 subgroupXY = uvec2(gl_SubgroupID % subgroupsPerWG.x, gl_SubgroupID / subgroupsPerWG.x);\n"
		"   uvec2 matrixID = uvec2(gl_WorkGroupID.xy) * subgroupsPerWG + subgroupXY;\n";

	if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
	{
		css << "   InputA inputA = params.inputA;\n";
		css << "   InputB inputB = params.inputB;\n";
		css << "   InputC inputC = params.inputC;\n";
		css << "   Output outputO = params.outputO;\n";
	}

	string strides[4];
	for (deUint32 i = 0; i < 4; ++i)
	{
		strides[i] = (m_data.colMajor ? dims[i].rows : dims[i].cols) + string(" * ") + de::toString(m_data.subgroupsPerWorkgroupX * m_data.workgroupsX);
	}

	// element<i> is the starting element in buffer memory.
	// elementS<i> is the starting element in shared memory.
	css << "   uint element0 = " << strides[0] << " * " << (m_data.colMajor ? dims[0].cols : dims[0].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[0].rows : dims[0].cols) << " * matrixID.x;\n"
		   "   uint element1 = " << strides[1] << " * " << (m_data.colMajor ? dims[1].cols : dims[1].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[1].rows : dims[1].cols) << " * matrixID.x;\n"
		   "   uint element2 = " << strides[2] << " * " << (m_data.colMajor ? dims[2].cols : dims[2].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[2].rows : dims[2].cols) << " * matrixID.x;\n"
		   "   uint element3 = " << strides[3] << " * " << (m_data.colMajor ? dims[3].cols : dims[3].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[3].rows : dims[3].cols) << " * matrixID.x;\n"
		   "   uint elementS0, elementS1, elementS2, elementS3;\n";

	// For shared memory tests, copy the matrix from buffer memory into
	// workgroup memory. For simplicity, do it all on a single thread.
	if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
	{
		const char *name[] =
		{
			"sharedA",
			"sharedB",
			"sharedC",
		};
		const char *inputName[] =
		{
			"inputA",
			"inputB",
			"inputC",
		};
		for (deUint32 m = 0; m < 4; ++m)
		{
			string sharedStride = strides[m] + " / workgroupsX";
			css << "       elementS" << m << " = " << sharedStride << " * " << (m_data.colMajor ? dims[m].cols : dims[m].rows) << " * subgroupXY.y + " << (m_data.colMajor ? dims[m].rows : dims[m].cols) << " * subgroupXY.x;\n";
		}
		css << "   if (subgroupElect()) {\n";
		// copy all three input buffers.
		for (deUint32 m = 0; m < 3; ++m)
		{
			string sharedStride = strides[m] + " / workgroupsX";
			css <<  "       for (int i = 0; i < " << dims[m].rows << "; ++i) {\n"
					"       for (int j = 0; j < " << dims[m].cols << "; ++j) {\n"
					"           int localElementInput = " << strides[m] << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
					"           int localElementShared = " << sharedStride << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
					"           " << name[m] << "[elementS" << m << " + localElementShared] = " << inputName[m] << ".x[element" << m << " + localElementInput];\n"
					"       }\n"
					"       }\n";
			strides[m] = sharedStride;
		}
		css << "   }\n";
		css << "   controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);\n";
	}

	const char *colMajor = (m_data.colMajor ? "true" : "false");

	if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
	{
		css <<  "   coopMatLoadNV(matA, sharedA, elementS0, " << strides[0] << ", " << colMajor << ");\n"
				"   coopMatLoadNV(matB, sharedB, elementS1, " << strides[1] << ", " << colMajor << ");\n"
				"   coopMatLoadNV(matC, sharedC, elementS2, " << strides[2] << ", " << colMajor << ");\n";
	}
	else
	{
		css << "   coopMatLoadNV(matA, inputA.x, element0, " << strides[0] << ", " << colMajor << ");\n"
			   "   coopMatLoadNV(matB, inputB.x, element1, " << strides[1] << ", " << colMajor << ");\n"
			   "   coopMatLoadNV(matC, inputC.x, element2, " << strides[2] << ", " << colMajor << ");\n";
	}

	if (m_data.testType == TT_COMPOSITE_ARRAY ||
		m_data.testType == TT_MATRIXMULADD_ARRAY)
	{
		css << "   " << matAType.str() << " matAArr[2];\n    matAArr[1] = matA; matAArr[0] = " << matAType.str() << "(0.0);\n"
			   "   " << matBType.str() << " matBArr[2];\n    matBArr[1] = matB; matBArr[0] = " << matBType.str() << "(0.0);\n"
			   "   " << matCType.str() << " matCArr[2];\n    matCArr[1] = matC; matCArr[0] = " << matCType.str() << "(0.0);\n"
			   "   " << outputMatType.str() << " matOArr[2];\n";
	}

	switch (m_data.testType)
	{
	default:
		DE_ASSERT(0);
		// fall through
	case TT_LENGTH:
		css << "   matO = " << outputMatType.str() << "(matO.length());\n";
		break;
	case TT_CONSTANT:
		css << "   matO = matConst;\n";
		break;
	case TT_CONVERT:
		css << "   matO = " << outputMatType.str() << "(matA);\n";
		break;
	case TT_COMPOSITE:
	case TT_COMPOSITE_RVALUE:
		css << "   for (int i = 0; i < matA.length(); ++i) {\n"
			   "       matO[i] = matA[i] + matB[i];\n"
			   "   }\n";
		if (m_data.testType == TT_COMPOSITE_RVALUE)
		{
			css << "   " << matAType.str() << " t = matA;\n"
				   "   matO[0] = (t += matB)[0];\n"
				   "   if (matA.length() > 0) {\n"
				   "       t = matA;\n"
				   "       matO[1] = (t += matB)[1];\n"
				   "   }\n";
		}
		break;
	case TT_COMPOSITE_ARRAY:
		css << "   for (int i = 0; i < matA.length(); ++i) {\n"
			   "       matOArr[1][i] = matAArr[1][i] + matBArr[1][i];\n"
			   "   }\n";
		break;
	case TT_ADD:
		css << "   matO = matA + matB;\n";
		break;
	case TT_SUB:
		css << "   matO = matA - matB;\n";
		break;
	case TT_DIV:
		css << "   matO = matA / matB;\n";
		break;
	case TT_NEGATE:
		css << "   matO = -matA;\n";
		break;
	case TT_FUNC:
		css << "   matO = f(matA);\n";
		break;
	case TT_MATRIXTIMESSCALAR:
		css << "   matO = (" << typeStrA << "(2.0)*matA)*" << typeStrA << "(3.0);\n";
		break;
	case TT_MATRIXMULADD:
		css << "   matO = coopMatMulAddNV(matA, matB, matC);\n";
		break;
	case TT_MATRIXMULADD_ARRAY:
		css << "   matOArr[1] = coopMatMulAddNV(matAArr[1], matBArr[1], matCArr[1]);\n";
		break;
	}

	if (m_data.testType == TT_COMPOSITE_ARRAY ||
		m_data.testType == TT_MATRIXMULADD_ARRAY)
	{
		css << "   matOArr[0] = " << outputMatType.str() << "(0.0);\n";
		css << "   matO = matOArr[1];\n";
	}

	if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
	{
		string sharedStride = strides[3] + " / workgroupsX";
		css << "   coopMatStoreNV(matO, sharedO, elementS3, " << sharedStride << ", " << colMajor << ");\n";
		css << "   controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);\n";
		css << "   if (subgroupElect()) {\n";
		css << "       for (int i = 0; i < " << dims[3].rows << "; ++i) {\n"
			   "       for (int j = 0; j < " << dims[3].cols << "; ++j) {\n"
			   "           int localElementInput = " << strides[3] << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
			   "           int localElementShared = " << sharedStride << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
			   "           outputO.x[element3 + localElementInput] = sharedO[elementS3 + localElementShared];\n"
			   "       }\n"
			   "       }\n";
		css << "   }\n";
	}
	else
	{
		css << "   coopMatStoreNV(matO, outputO.x, element3, " << strides[3] << ", " << colMajor << ");\n";
	}

	css <<
		"}\n";
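
	// For illustration only: with TT_ADD, SC_BUFFER, float32 input/output and
	// row-major layout, the shader assembled above looks roughly like the
	// following ("N * 8" comes from the default 2 subgroups x 4 workgroups in
	// X; this is a sketch, not verbatim output):
	//
	//	fcoopmatNV<32, gl_ScopeSubgroup, M, N> matA, matB, matC, matO;
	//	void main()
	//	{
	//	   uvec2 subgroupXY = uvec2(gl_SubgroupID % subgroupsPerWG.x, gl_SubgroupID / subgroupsPerWG.x);
	//	   uvec2 matrixID = uvec2(gl_WorkGroupID.xy) * subgroupsPerWG + subgroupXY;
	//	   uint element0 = ...; // per-matrix offsets as built above
	//	   coopMatLoadNV(matA, inputA.x, element0, N * 8, false);
	//	   coopMatLoadNV(matB, inputB.x, element1, N * 8, false);
	//	   coopMatLoadNV(matC, inputC.x, element2, N * 8, false);
	//	   matO = matA + matB;
	//	   coopMatStoreNV(matO, outputO.x, element3, N * 8, false);
	//	}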

	const vk::ShaderBuildOptions	buildOptions	(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);

	programCollection.glslSources.add("test") << glu::ComputeSource(css.str()) << buildOptions;
}

TestInstance* CooperativeMatrixTestCase::createInstance (Context& context) const
{
	return new CooperativeMatrixTestInstance(context, m_data);
}

static void setDataFloat(void *base, VkComponentTypeNV dt, deUint32 i, float value)
{
	if (dt == VK_COMPONENT_TYPE_FLOAT32_NV)
	{
		((float *)base)[i] = value;
	}
	else
	{
		DE_ASSERT(dt == VK_COMPONENT_TYPE_FLOAT16_NV);
		((deFloat16 *)base)[i] = deFloat32To16(value);
	}
}

static float getDataFloat(void *base, VkComponentTypeNV dt, deUint32 i)
{
	if (dt == VK_COMPONENT_TYPE_FLOAT32_NV)
	{
		return ((float *)base)[i];
	}
	else
	{
		DE_ASSERT(dt == VK_COMPONENT_TYPE_FLOAT16_NV);
		return deFloat16To32(((deFloat16 *)base)[i]);
	}
}

static void setDataInt(void *base, VkComponentTypeNV dt, deUint32 i, deUint32 value)
{
	DE_ASSERT(componentTypeInfo[dt].bits <= 32);
	switch (dt) {
	default: DE_ASSERT(0); // fallthrough
	case VK_COMPONENT_TYPE_UINT8_NV:	((deUint8  *)base)[i] = (deUint8)value; break;
	case VK_COMPONENT_TYPE_UINT16_NV:	((deUint16 *)base)[i] = (deUint16)value; break;
	case VK_COMPONENT_TYPE_UINT32_NV:	((deUint32 *)base)[i] = (deUint32)value; break;
	case VK_COMPONENT_TYPE_SINT8_NV:	((deInt8  *)base)[i] = (deInt8)value; break;
	case VK_COMPONENT_TYPE_SINT16_NV:	((deInt16 *)base)[i] = (deInt16)value; break;
	case VK_COMPONENT_TYPE_SINT32_NV:	((deInt32 *)base)[i] = (deInt32)value; break;
	}
}

static deUint32 getDataInt(void *base, VkComponentTypeNV dt, deUint32 i)
{
	DE_ASSERT(componentTypeInfo[dt].bits <= 32);
	switch (dt) {
	default: DE_ASSERT(0); // fallthrough
	case VK_COMPONENT_TYPE_UINT8_NV:	return ((deUint8  *)base)[i];
	case VK_COMPONENT_TYPE_UINT16_NV:	return ((deUint16 *)base)[i];
	case VK_COMPONENT_TYPE_UINT32_NV:	return ((deUint32 *)base)[i];
	case VK_COMPONENT_TYPE_SINT8_NV:	return ((deInt8  *)base)[i];
	case VK_COMPONENT_TYPE_SINT16_NV:	return ((deInt16 *)base)[i];
	case VK_COMPONENT_TYPE_SINT32_NV:	return ((deInt32 *)base)[i];
	}
}

tcu::TestStatus CooperativeMatrixTestInstance::iterate (void)
{
	const DeviceInterface&	vk						= m_context.getDeviceInterface();
	const VkDevice			device					= m_context.getDevice();
	Allocator&				allocator				= m_context.getDefaultAllocator();
	MemoryRequirement		memoryDeviceAddress		= m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER &&
													  m_context.isDeviceFunctionalitySupported("VK_KHR_buffer_device_address") ? MemoryRequirement::DeviceAddress : MemoryRequirement::Any;
	qpTestResult			finalres				= QP_TEST_RESULT_PASS;
	tcu::TestLog&			log						= m_context.getTestContext().getLog();

	deRandom rnd;
	deRandom_init(&rnd, 1234);

	vk::VkPhysicalDeviceSubgroupProperties subgroupProperties;
	deMemset(&subgroupProperties, 0, sizeof(subgroupProperties));
	subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;

	vk::VkPhysicalDeviceProperties2 properties2;
	deMemset(&properties2, 0, sizeof(properties2));
	properties2.sType = vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
	properties2.pNext = &subgroupProperties;

	m_context.getInstanceInterface().getPhysicalDeviceProperties2(m_context.getPhysicalDevice(), &properties2);

	deUint32 propertyCount = 0;
	VkCooperativeMatrixPropertiesNV *pProperties;
	m_context.getInstanceInterface().getPhysicalDeviceCooperativeMatrixPropertiesNV(m_context.getPhysicalDevice(), &propertyCount, DE_NULL);
	// Shouldn't have made it through checkSupport without any properties
	DE_ASSERT(propertyCount != 0);

	pProperties = new VkCooperativeMatrixPropertiesNV[propertyCount];

	for (deUint32 i = 0; i < propertyCount; ++i)
	{
		VkCooperativeMatrixPropertiesNV *p = &pProperties[i];
		p->sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_NV;
		p->pNext = DE_NULL;
	}

	m_context.getInstanceInterface().getPhysicalDeviceCooperativeMatrixPropertiesNV(m_context.getPhysicalDevice(), &propertyCount, pProperties);

	struct TestTuple
	{
		TestTuple() {}
		TestTuple(deUint32 m, deUint32 n, deUint32 k) : M(m), N(n), K(k) {}

		bool operator<(const TestTuple &other) const
		{
			return M < other.M ||
				   (M == other.M && N < other.N) ||
				   (M == other.M && N == other.N && K < other.K);
		}

		deUint32 M, N, K;
	};
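
	// operator< above gives lexicographic (M, N, K) ordering, so TestTuple can
	// key the std::set below and feed std::set_intersection.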

	vector<TestTuple> testSizes;

	if (m_data.testType == TT_MATRIXMULADD ||
		m_data.testType == TT_MATRIXMULADD_ARRAY)
	{
		for (deUint32 i = 0; i < propertyCount; ++i)
		{
			VkCooperativeMatrixPropertiesNV *p = &pProperties[i];

			if (p->AType == m_data.inputType &&
				p->BType == m_data.inputType &&
				p->CType == m_data.outputType &&
				p->DType == m_data.outputType &&
				p->scope == VK_SCOPE_SUBGROUP_NV)
			{
				testSizes.push_back(TestTuple(p->MSize, p->NSize, p->KSize));
			}
		}
	}
	else
	{
		set<TestTuple> typeSizes[2];
		VkComponentTypeNV types[2] = { m_data.inputType, m_data.outputType };

		for (deUint32 i = 0; i < propertyCount; ++i)
		{
			VkCooperativeMatrixPropertiesNV *p = &pProperties[i];

			if (p->scope != VK_SCOPE_SUBGROUP_NV)
				continue;

			for (deUint32 j = 0; j < 2; ++j)
			{
				// For these tests, m_data.M/N are always the matrix size. Check if they match
				// any input or output in the list.
				if (p->AType == types[j])
					typeSizes[j].insert(TestTuple(p->MSize, p->KSize, 0));
				if (p->BType == types[j])
					typeSizes[j].insert(TestTuple(p->KSize, p->NSize, 0));
				if (p->CType == types[j] ||
					p->DType == types[j])
					typeSizes[j].insert(TestTuple(p->MSize, p->NSize, 0));
			}
		}
		// Test those sizes that are supported for both the input and output type.
		std::set_intersection(typeSizes[0].begin(), typeSizes[0].end(),
							  typeSizes[1].begin(), typeSizes[1].end(),
							  std::back_inserter(testSizes));
	}

	delete [] pProperties;

	for (unsigned int s = 0; s < testSizes.size(); ++s)
	{
		// When testing a multiply, MxNxK is the type of matrix multiply.
		// Otherwise, MxN is the size of the input/output matrices
		deUint32 M, N, K;
		M = testSizes[s].M;
		N = testSizes[s].N;
		K = testSizes[s].K;

		log << tcu::TestLog::Message << "Testing M = " << M << ", N = " << N << ", K = " << K << tcu::TestLog::EndMessage;

		struct
		{
			deUint32 rows, cols;
		} dims[4];

		if (m_data.testType == TT_MATRIXMULADD ||
			m_data.testType == TT_MATRIXMULADD_ARRAY)
		{
			dims[0].rows = M;
			dims[0].cols = K;
			dims[1].rows = K;
			dims[1].cols = N;
			dims[2].rows = M;
			dims[2].cols = N;
			dims[3].rows = M;
			dims[3].cols = N;
		}
		else
		{
			dims[0].rows = M;
			dims[0].cols = N;
			dims[1].rows = M;
			dims[1].cols = N;
			dims[2].rows = M;
			dims[2].cols = N;
			dims[3].rows = M;
			dims[3].cols = N;
		}

		VkComponentTypeNV dataTypes[4];
		size_t elementSize[4];
		VkDeviceSize bufferSizes[5];
		de::MovePtr<BufferWithMemory> buffers[5];
		vk::VkDescriptorBufferInfo bufferDescriptors[5];
		deUint32 strides[4]; // in elements
		deUint32 totalElements[4];
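
		// Layout note: the matrices for all subgroups are packed into one 2D
		// grid per buffer, so the stride is (tile width) *
		// subgroupsPerWorkgroupX * workgroupsX elements. As a worked example
		// (assuming a 16x16 tile and the default 2x2 subgroups / 4x4
		// workgroups used below): stride = 16*2*4 = 128 elements,
		// totalElements = 128*16*2*4 = 16384, so a float32 buffer is 64 KiB.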

		for (deUint32 i = 0; i < 5; ++i)
		{
			if (i < 4)
			{
				// A/B use input type, C/D use output type
				dataTypes[i] = (i < 2) ? m_data.inputType : m_data.outputType;
				elementSize[i] = componentTypeInfo[dataTypes[i]].bits / 8;

				strides[i] = (m_data.colMajor ? dims[i].rows : dims[i].cols) * m_data.subgroupsPerWorkgroupX * m_data.workgroupsX;
				totalElements[i] = strides[i] * (m_data.colMajor ? dims[i].cols : dims[i].rows) * m_data.subgroupsPerWorkgroupY * m_data.workgroupsY;

				bufferSizes[i] = totalElements[i] * elementSize[i];
			}
			else
			{
				bufferSizes[4] = sizeof(VkDeviceAddress)*4;
			}

			try
			{
				buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
					vk, device, allocator, makeBufferCreateInfo(bufferSizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT|VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_EXT),
					MemoryRequirement::HostVisible | MemoryRequirement::Cached | MemoryRequirement::Coherent | memoryDeviceAddress));
			}
			catch (const tcu::NotSupportedError&)
			{
				buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
					vk, device, allocator, makeBufferCreateInfo(bufferSizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT|VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_EXT),
					MemoryRequirement::HostVisible | memoryDeviceAddress));
			}

			bufferDescriptors[i] = makeDescriptorBufferInfo(**buffers[i], 0, bufferSizes[i]);
		}

		void *ptrs[5];
		for (deUint32 i = 0; i < 5; ++i)
		{
			ptrs[i] = buffers[i]->getAllocation().getHostPtr();
		}

		vk::DescriptorSetLayoutBuilder layoutBuilder;

		layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
		layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
		layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
		layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
		layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);

		vk::Unique<vk::VkDescriptorSetLayout>	descriptorSetLayout(layoutBuilder.build(vk, device));

		vk::Unique<vk::VkDescriptorPool>		descriptorPool(vk::DescriptorPoolBuilder()
			.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 5u)
			.build(vk, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u));
		vk::Unique<vk::VkDescriptorSet>			descriptorSet		(makeDescriptorSet(vk, device, *descriptorPool, *descriptorSetLayout));

		vk::DescriptorSetUpdateBuilder setUpdateBuilder;
		if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
		{
			const bool useKHR = m_context.isDeviceFunctionalitySupported("VK_KHR_buffer_device_address");

			VkBufferDeviceAddressInfo info =
			{
				VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO,		// VkStructureType	 sType;
				DE_NULL,											// const void*		 pNext;
				0,													// VkBuffer			 buffer;
			};
			VkDeviceAddress *addrsInMemory = (VkDeviceAddress *)ptrs[4];
			for (deUint32 i = 0; i < 4; ++i)
			{
				info.buffer = **buffers[i];
				VkDeviceAddress addr;
				if (useKHR)
					addr = vk.getBufferDeviceAddress(device, &info);
				else
					addr = vk.getBufferDeviceAddressEXT(device, &info);
				addrsInMemory[i] = addr;
			}
			setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(4),
				VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[4]);
		}
		else
		{
			setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(0),
				VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[0]);
			setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(1),
				VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[1]);
			setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(2),
				VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[2]);
			setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(3),
				VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[3]);
		}

		setUpdateBuilder.update(vk, device);

		const VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo =
		{
			VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,				// sType
			DE_NULL,													// pNext
			(VkPipelineLayoutCreateFlags)0,								// flags
			1,															// setLayoutCount
			&descriptorSetLayout.get(),									// pSetLayouts
			0u,															// pushConstantRangeCount
			DE_NULL,													// pPushConstantRanges
		};

		Move<VkPipelineLayout> pipelineLayout = createPipelineLayout(vk, device, &pipelineLayoutCreateInfo, NULL);

		Move<VkPipeline> pipeline;

		VkPipelineBindPoint bindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;

		const deUint32 specData[9] =
		{
			subgroupProperties.subgroupSize * m_data.subgroupsPerWorkgroupX,
			m_data.subgroupsPerWorkgroupY,
			strides[0],
			strides[1],
			strides[2],
			strides[3],
			M,
			N,
			K,
		};

		const vk::VkSpecializationMapEntry entries[9] =
		{
			{0, (deUint32)(sizeof(deUint32) * 0), sizeof(deUint32)},
			{1, (deUint32)(sizeof(deUint32) * 1), sizeof(deUint32)},
			{2, (deUint32)(sizeof(deUint32) * 2), sizeof(deUint32)},
			{3, (deUint32)(sizeof(deUint32) * 3), sizeof(deUint32)},
			{4, (deUint32)(sizeof(deUint32) * 4), sizeof(deUint32)},
			{5, (deUint32)(sizeof(deUint32) * 5), sizeof(deUint32)},
			{6, (deUint32)(sizeof(deUint32) * 6), sizeof(deUint32)},
			{7, (deUint32)(sizeof(deUint32) * 7), sizeof(deUint32)},
			{8, (deUint32)(sizeof(deUint32) * 8), sizeof(deUint32)},
		};

		const vk::VkSpecializationInfo specInfo =
		{
			9,						// mapEntryCount
			entries,				// pMapEntries
			sizeof(specData),		// dataSize
			specData				// pData
		};

		for (deUint32 i = 0; i < 4; ++i)
			for (deUint32 j = 0; j < totalElements[i]; ++j)
			{
				if (isFloatType(dataTypes[i]))
				{
					if (m_data.testType != TT_MATRIXMULADD &&
						m_data.testType != TT_MATRIXMULADD_ARRAY)
						setDataFloat(ptrs[i], dataTypes[i], j, ((float)(deRandom_getUint32(&rnd) & 0xff) - 64.0f)/2.0f);
					else
						setDataFloat(ptrs[i], dataTypes[i], j, ((float)(deRandom_getUint32(&rnd) & 0xf) - 4.0f)/2.0f);
				}
				else
					setDataInt(ptrs[i], dataTypes[i], j, (deRandom_getUint32(&rnd) & 0xff) - 128);
			}

		flushAlloc(vk, device, buffers[0]->getAllocation());
		flushAlloc(vk, device, buffers[1]->getAllocation());
		flushAlloc(vk, device, buffers[2]->getAllocation());
		flushAlloc(vk, device, buffers[3]->getAllocation());

		const Unique<VkShaderModule>	shader						(createShaderModule(vk, device, m_context.getBinaryCollection().get("test"), 0));

		const VkPipelineShaderStageCreateInfo	shaderCreateInfo =
		{
			VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
			DE_NULL,
			(VkPipelineShaderStageCreateFlags)0,
			VK_SHADER_STAGE_COMPUTE_BIT,								// stage
			*shader,													// shader
			"main",
			&specInfo,													// pSpecializationInfo
		};

		const VkComputePipelineCreateInfo		pipelineCreateInfo =
		{
			VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
			DE_NULL,
			0u,															// flags
			shaderCreateInfo,											// cs
			*pipelineLayout,											// layout
			(vk::VkPipeline)0,											// basePipelineHandle
			0u,															// basePipelineIndex
		};
		pipeline = createComputePipeline(vk, device, DE_NULL, &pipelineCreateInfo, NULL);

		const VkQueue					queue					= m_context.getUniversalQueue();
		Move<VkCommandPool>				cmdPool					= createCommandPool(vk, device, 0, m_context.getUniversalQueueFamilyIndex());
		Move<VkCommandBuffer>			cmdBuffer				= allocateCommandBuffer(vk, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);

		beginCommandBuffer(vk, *cmdBuffer, 0u);

		vk.cmdBindDescriptorSets(*cmdBuffer, bindPoint, *pipelineLayout, 0u, 1, &*descriptorSet, 0u, DE_NULL);
		vk.cmdBindPipeline(*cmdBuffer, bindPoint, *pipeline);

		vk.cmdDispatch(*cmdBuffer, m_data.workgroupsX, m_data.workgroupsY, 1);

		endCommandBuffer(vk, *cmdBuffer);

		submitCommandsAndWait(vk, device, queue, cmdBuffer.get());

		invalidateAlloc(vk, device, buffers[3]->getAllocation());

		qpTestResult res = QP_TEST_RESULT_PASS;

		if (isFloatType(dataTypes[0]))
		{
			if (m_data.testType != TT_MATRIXMULADD &&
				m_data.testType != TT_MATRIXMULADD_ARRAY)
			{
				for (deUint32 i = 0; i < totalElements[3]; ++i)
				{
					float inputA = getDataFloat(ptrs[0], dataTypes[0], i);
					float inputB = getDataFloat(ptrs[1], dataTypes[1], i);
					float output = getDataFloat(ptrs[3], dataTypes[3], i);
					switch (m_data.testType)
					{
					case TT_LENGTH:
						if (output < 1.0f || output > (float)(N*M))
							res = QP_TEST_RESULT_FAIL;
						// We expect the matrix to be spread evenly across invocations; it is
						// surprising (but not necessarily illegal) if not.
						if (output != (float)(N*M/subgroupProperties.subgroupSize) &&
							res == QP_TEST_RESULT_PASS)
							res = QP_TEST_RESULT_QUALITY_WARNING;
						break;
					case TT_CONSTANT:
						if (output != 1.0f)
							res = QP_TEST_RESULT_FAIL;
						break;
					case TT_CONVERT:
						if (output != inputA)
							res = QP_TEST_RESULT_FAIL;
						break;
					case TT_COMPOSITE:
					case TT_COMPOSITE_RVALUE:
					case TT_COMPOSITE_ARRAY:
					case TT_ADD:
						if (output != inputA + inputB)
							res = QP_TEST_RESULT_FAIL;
						break;
					case TT_SUB:
						if (output != inputA - inputB)
							res = QP_TEST_RESULT_FAIL;
						break;
					case TT_DIV:
						{
							float ulp = (m_data.inputType == VK_COMPONENT_TYPE_FLOAT16_NV) ? 1.0f/1024.0f : 1.0f/(8.0f*1024.0f*1024.0f);
							// division allows 2.5 ULP, but we'll use 3.
							ulp *= 3;
							if (inputB != 0 && fabs(output - inputA / inputB) > ulp * fabs(inputA / inputB))
								res = QP_TEST_RESULT_FAIL;
						}
						break;
					case TT_NEGATE:
					case TT_FUNC:
						if (output != -inputA)
							res = QP_TEST_RESULT_FAIL;
						break;
					case TT_MATRIXTIMESSCALAR:
						if (output != 6.0*inputA)
							res = QP_TEST_RESULT_FAIL;
						break;
					default:
						break;
					}
				}
			}
			else
			{
				deUint32 ik, kj, ij;
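				// Reference check for the mul-add tests: recompute
				// D(i,j) = sum_k A(i,k)*B(k,j) + C(i,j) on the CPU for every
				// tile (mX, mY) and require an exact match, flattening
				// (row, col) to a linear index with the same stride and
				// major-ness the shader used.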
				for (deUint32 mX = 0; mX < m_data.subgroupsPerWorkgroupX*m_data.workgroupsX; ++mX)
				{
					for (deUint32 mY = 0; mY < m_data.subgroupsPerWorkgroupY*m_data.workgroupsY; ++mY)
					{
						for (deUint32 i = 0; i < M; ++i)
						{
							for (deUint32 j = 0; j < N; ++j)
							{
								float ref = 0;
								for (deUint32 k = 0; k < K; ++k)
								{
									if (m_data.colMajor)
										ik = mX * M + i + strides[0] * (mY * K + k);
									else
										ik = mX * K + k + strides[0] * (mY * M + i);

									float Aik = getDataFloat(ptrs[0], dataTypes[0], ik);

									if (m_data.colMajor)
										kj = mX * K + k + strides[1] * (mY * N + j);
									else
										kj = mX * N + j + strides[1] * (mY * K + k);

									float Bkj = getDataFloat(ptrs[1], dataTypes[1], kj);

									ref += Aik*Bkj;
								}

								if (m_data.colMajor)
									ij = mX * M + i + strides[2] * (mY * N + j);
								else
									ij = mX * N + j + strides[2] * (mY * M + i);

								float Cij = getDataFloat(ptrs[2], dataTypes[2], ij);

								ref += Cij;

								float Dij = getDataFloat(ptrs[3], dataTypes[3], ij);

								if (ref != Dij)
								{
									res = QP_TEST_RESULT_FAIL;
								}
							}
						}
					}
				}
			}
		}
		else
		{
			if (m_data.testType != TT_MATRIXMULADD &&
				m_data.testType != TT_MATRIXMULADD_ARRAY)
			{
				for (deUint32 i = 0; i < totalElements[3]; ++i)
				{
					deUint32 inputA = getDataInt(ptrs[0], dataTypes[0], i);
					deUint32 inputB = getDataInt(ptrs[1], dataTypes[1], i);
					deUint32 output = getDataInt(ptrs[3], dataTypes[3], i);
					int resultSize = componentTypeInfo[dataTypes[3]].bits;
					deUint32 mask = resultSize == 32 ? ~0 : ((1 << resultSize) - 1);
					switch (m_data.testType)
					{
					case TT_LENGTH:
						if (output < 1 || output > N*M)
							res = QP_TEST_RESULT_FAIL;
						// We expect the matrix to be spread evenly across invocations; it is
						// surprising (but not necessarily illegal) if not.
						if (output != N*M/subgroupProperties.subgroupSize &&
							res == QP_TEST_RESULT_PASS)
							res = QP_TEST_RESULT_QUALITY_WARNING;
						break;
					case TT_CONSTANT:
						if (output != 1)
							res = QP_TEST_RESULT_FAIL;
						break;
					case TT_CONVERT:
						if (output != inputA)
							res = QP_TEST_RESULT_FAIL;
						break;
					case TT_COMPOSITE:
					case TT_COMPOSITE_RVALUE:
					case TT_COMPOSITE_ARRAY:
					case TT_ADD:
						if ((output & mask) != ((inputA + inputB) & mask)) {
							res = QP_TEST_RESULT_FAIL;
						}
						break;
					case TT_SUB:
						if ((output & mask) != ((inputA - inputB) & mask))
							res = QP_TEST_RESULT_FAIL;
						break;
					case TT_DIV:
						{
							if (isSIntType(dataTypes[3]))
							{
								if (inputB != 0 && ((deInt32)output & mask) != (((deInt32)inputA / (deInt32)inputB) & mask))
									res = QP_TEST_RESULT_FAIL;
							}
							else
							{
								if (inputB != 0 && output != inputA / inputB)
									res = QP_TEST_RESULT_FAIL;
							}
						}
						break;
					case TT_NEGATE:
					case TT_FUNC:
						if ((output & mask) != ((-(deInt32)inputA) & mask))
							res = QP_TEST_RESULT_FAIL;
						break;
					case TT_MATRIXTIMESSCALAR:
						if ((output & mask) != ((6*inputA) & mask)) {
							res = QP_TEST_RESULT_FAIL;
						}
						break;
					default:
						break;
					}
				}
			}
			else
			{
				deUint32 ik, kj, ij;
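				// Same reference computation as the float path: recompute
				// D(i,j) = sum_k A(i,k)*B(k,j) + C(i,j), here in deUint32
				// arithmetic, and require an exact match.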
				for (deUint32 mX = 0; mX < m_data.subgroupsPerWorkgroupX*m_data.workgroupsX; ++mX)
				{
					for (deUint32 mY = 0; mY < m_data.subgroupsPerWorkgroupY*m_data.workgroupsY; ++mY)
					{
						for (deUint32 i = 0; i < M; ++i)
						{
							for (deUint32 j = 0; j < N; ++j)
							{
								deUint32 ref = 0;
								for (deUint32 k = 0; k < K; ++k)
								{
									if (m_data.colMajor)
										ik = mX * M + i + strides[0] * (mY * K + k);
									else
										ik = mX * K + k + strides[0] * (mY * M + i);

									deUint32 Aik = getDataInt(ptrs[0], dataTypes[0], ik);

									if (m_data.colMajor)
										kj = mX * K + k + strides[1] * (mY * N + j);
									else
										kj = mX * N + j + strides[1] * (mY * K + k);

									deUint32 Bkj = getDataInt(ptrs[1], dataTypes[1], kj);

									ref += Aik*Bkj;
								}

								if (m_data.colMajor)
									ij = mX * M + i + strides[2] * (mY * N + j);
								else
									ij = mX * N + j + strides[2] * (mY * M + i);

								deUint32 Cij = getDataInt(ptrs[2], dataTypes[2], ij);

								ref += Cij;

								deUint32 Dij = getDataInt(ptrs[3], dataTypes[3], ij);

								if (ref != Dij)
								{
									res = QP_TEST_RESULT_FAIL;
								}
							}
						}
					}
				}
			}
		}
		if (res != QP_TEST_RESULT_PASS)
		{
			log << tcu::TestLog::Message << "failed with M = " << M << ", N = " << N << ", K = " << K << tcu::TestLog::EndMessage;
			finalres = res;
		}
	}

	return tcu::TestStatus(finalres, qpGetTestResultName(finalres));
}

}	// anonymous

tcu::TestCaseGroup*	createCooperativeMatrixTests (tcu::TestContext& testCtx)
{
	de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
			testCtx, "cooperative_matrix", "GL_NV_cooperative_matrix tests"));

	typedef struct
	{
		deUint32				value;
		const char*				name;
		const char*				description;
	} TestGroupCase;

	typedef struct
	{
		deUint32				value[2];
		const char*				name;
		const char*				description;
	} TestGroupCase2;

	TestGroupCase ttCases[] =
	{
		{ TT_LENGTH,				"length",					"OpCooperativeMatrixLengthNV"	},
		{ TT_CONSTANT,				"constant",					"OpConstantComposite"			},
		{ TT_CONVERT,				"convert",					"OpFConvert/OpSConvert/OpUConvert"	},
		{ TT_COMPOSITE,				"composite",				"OpCompositeConstruct"			},
		{ TT_COMPOSITE_RVALUE,		"composite_rvalue",			"OpCompositeExtract"			},
		{ TT_ADD,					"add",						"OpFAdd/OpIAdd"					},
		{ TT_SUB,					"sub",						"OpFSub/OpISub"					},
		{ TT_DIV,					"div",						"OpFDiv/OpSDiv/OpUDiv"			},
		{ TT_NEGATE,				"negate",					"OpFNegate/OpSNegate"			},
		{ TT_MATRIXTIMESSCALAR,		"matrixtimesscalar",		"OpMatrixTimesScalar"			},
		{ TT_FUNC,					"func",						"OpFunctionParameter"			},
		{ TT_MATRIXMULADD,			"matrixmuladd",				"OpCooperativeMatrixMulAddNV"	},
		{ TT_COMPOSITE_ARRAY,		"composite_array",			"OpCompositeConstruct w/array"			},
		{ TT_MATRIXMULADD_ARRAY,	"matrixmuladd_array",		"OpCooperativeMatrixMulAddNV w/array"	},
	};

	TestGroupCase2 dtCases[] =
	{
		{ { VK_COMPONENT_TYPE_FLOAT32_NV,	VK_COMPONENT_TYPE_FLOAT32_NV },	"float32_float32",	"A/B are fp32 C/D are fp32"		},
		{ { VK_COMPONENT_TYPE_FLOAT32_NV,	VK_COMPONENT_TYPE_FLOAT16_NV },	"float32_float16",	"A/B are fp32 C/D are fp16"		},
		{ { VK_COMPONENT_TYPE_FLOAT16_NV,	VK_COMPONENT_TYPE_FLOAT32_NV },	"float16_float32",	"A/B are fp16 C/D are fp32"		},
		{ { VK_COMPONENT_TYPE_FLOAT16_NV,	VK_COMPONENT_TYPE_FLOAT16_NV },	"float16_float16",	"A/B are fp16 C/D are fp16"		},
		{ { VK_COMPONENT_TYPE_UINT8_NV,		VK_COMPONENT_TYPE_UINT8_NV },	"uint8_uint8",		"A/B are u8 C/D are u8"			},
		{ { VK_COMPONENT_TYPE_UINT8_NV,		VK_COMPONENT_TYPE_UINT32_NV },	"uint8_uint32",		"A/B are u8 C/D are u32"		},
		{ { VK_COMPONENT_TYPE_SINT8_NV,		VK_COMPONENT_TYPE_SINT8_NV },	"sint8_sint8",		"A/B are s8 C/D are s8"			},
		{ { VK_COMPONENT_TYPE_SINT8_NV,		VK_COMPONENT_TYPE_SINT32_NV },	"sint8_sint32",		"A/B are s8 C/D are s32"		},
		{ { VK_COMPONENT_TYPE_UINT32_NV,	VK_COMPONENT_TYPE_UINT32_NV },	"uint32_uint32",	"A/B are u32 C/D are u32"		},
		{ { VK_COMPONENT_TYPE_UINT32_NV,	VK_COMPONENT_TYPE_UINT8_NV },	"uint32_uint8",		"A/B are u32 C/D are u8"		},
		{ { VK_COMPONENT_TYPE_SINT32_NV,	VK_COMPONENT_TYPE_SINT32_NV },	"sint32_sint32",	"A/B are s32 C/D are s32"		},
		{ { VK_COMPONENT_TYPE_SINT32_NV,	VK_COMPONENT_TYPE_SINT8_NV },	"sint32_sint8",		"A/B are s32 C/D are s8"		},
	};

	TestGroupCase colCases[] =
	{
		{ 0,		"rowmajor",	"row major"		},
		{ 1,		"colmajor",	"col major"		},
	};

	TestGroupCase scCases[] =
	{
		{ SC_BUFFER,						"buffer",			"SSBO"				},
		{ SC_WORKGROUP,						"workgroup",		"shared memory"		},
		{ SC_BUFFER_VARIABLE_POINTERS,		"buffer_varptr",	"SSBO w/variable pointers"		},
		{ SC_WORKGROUP_VARIABLE_POINTERS,	"workgroup_varptr",	"shared memory w/variable pointers"		},
		{ SC_PHYSICAL_STORAGE_BUFFER,		"physical_buffer",	"physical storage buffer"				},
	};

	for (int ttNdx = 0; ttNdx < DE_LENGTH_OF_ARRAY(ttCases); ttNdx++)
	{
		de::MovePtr<tcu::TestCaseGroup> ttGroup(new tcu::TestCaseGroup(testCtx, ttCases[ttNdx].name, ttCases[ttNdx].description));
		for (int dtNdx = 0; dtNdx < DE_LENGTH_OF_ARRAY(dtCases); dtNdx++)
		{
			de::MovePtr<tcu::TestCaseGroup> dtGroup(new tcu::TestCaseGroup(testCtx, dtCases[dtNdx].name, dtCases[dtNdx].description));
			for (int scNdx = 0; scNdx < DE_LENGTH_OF_ARRAY(scCases); scNdx++)
			{
				de::MovePtr<tcu::TestCaseGroup> scGroup(new tcu::TestCaseGroup(testCtx, scCases[scNdx].name, scCases[scNdx].description));
				for (int colNdx = 0; colNdx < DE_LENGTH_OF_ARRAY(colCases); colNdx++)
				{
					TestType testType = (TestType)ttCases[ttNdx].value;
					VkComponentTypeNV inputType = (VkComponentTypeNV)dtCases[dtNdx].value[0];
					VkComponentTypeNV outputType = (VkComponentTypeNV)dtCases[dtNdx].value[1];

					bool isMatrixMul = testType == TT_MATRIXMULADD || testType == TT_MATRIXMULADD_ARRAY;

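					// Skip combinations the shader generator can't express:
					// mixed input/output types only make sense for conversions
					// and matrix mul-add, conversion tests need differing
					// types, and mul-add must not narrow (the output type must
					// be at least as wide as the input type).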
					if (!isMatrixMul && testType != TT_CONVERT && inputType != outputType)
						continue;

					if (testType == TT_CONVERT && inputType == outputType)
						continue;

					if (isMatrixMul && componentTypeInfo[inputType].bits > componentTypeInfo[outputType].bits)
						continue;

					CaseDef c =
					{
						testType,							// TestType testType;
						2u,									// deUint32 subgroupsPerWorkgroupX;
						2u,									// deUint32 subgroupsPerWorkgroupY;
						4u,									// deUint32 workgroupsX;
						4u,									// deUint32 workgroupsY;
						(VkComponentTypeNV)inputType,		// VkComponentTypeNV inputType;
						(VkComponentTypeNV)outputType,		// VkComponentTypeNV outputType;
						!!colCases[colNdx].value,			// bool colMajor;
						(StorageClass)scCases[scNdx].value,	// StorageClass storageClass;
					};

					scGroup->addChild(new CooperativeMatrixTestCase(testCtx, colCases[colNdx].name, colCases[colNdx].description, c));
				}
				dtGroup->addChild(scGroup.release());
			}
			ttGroup->addChild(dtGroup.release());
		}
		group->addChild(ttGroup.release());
	}
	return group.release();
}

}	// compute
}	// vkt