1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
// Tests for the validation rules of arithmetic instructions.
16
17 #include <string>
18
19 #include "gmock/gmock.h"
20 #include "test/unit_spirv.h"
21 #include "test/val/val_fixtures.h"
22
23 namespace spvtools {
24 namespace val {
25 namespace {
26
27 using ::testing::HasSubstr;
28 using ::testing::Not;
29
30 using ValidateArithmetics = spvtest::ValidateBase<bool>;
31
// Wraps |main_body| in a complete SPIR-V assembly module.
//
// The module header enables Shader, Int64, Float64 and Matrix capabilities,
// declares the scalar/vector/matrix/struct types and named constants the tests
// refer to (e.g. %f32vec2_12, %u64_1, %f32mat22_1212), and opens the %main
// entry point. |main_body| is inserted directly after the %main_entry label,
// followed by OpReturn / OpFunctionEnd.
std::string GenerateCode(const std::string& main_body) {
  static const char kModuleHead[] =
      R"(
OpCapability Shader
OpCapability Int64
OpCapability Float64
OpCapability Matrix
%ext_inst = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main"
OpExecutionMode %main OriginUpperLeft
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f32 = OpTypeFloat 32
%u32 = OpTypeInt 32 0
%s32 = OpTypeInt 32 1
%f64 = OpTypeFloat 64
%u64 = OpTypeInt 64 0
%s64 = OpTypeInt 64 1
%boolvec2 = OpTypeVector %bool 2
%s32vec2 = OpTypeVector %s32 2
%u32vec2 = OpTypeVector %u32 2
%u64vec2 = OpTypeVector %u64 2
%f32vec2 = OpTypeVector %f32 2
%f64vec2 = OpTypeVector %f64 2
%boolvec3 = OpTypeVector %bool 3
%u32vec3 = OpTypeVector %u32 3
%u64vec3 = OpTypeVector %u64 3
%s32vec3 = OpTypeVector %s32 3
%f32vec3 = OpTypeVector %f32 3
%f64vec3 = OpTypeVector %f64 3
%boolvec4 = OpTypeVector %bool 4
%u32vec4 = OpTypeVector %u32 4
%u64vec4 = OpTypeVector %u64 4
%s32vec4 = OpTypeVector %s32 4
%f32vec4 = OpTypeVector %f32 4
%f64vec4 = OpTypeVector %f64 4

%f32mat22 = OpTypeMatrix %f32vec2 2
%f32mat23 = OpTypeMatrix %f32vec2 3
%f32mat32 = OpTypeMatrix %f32vec3 2
%f32mat33 = OpTypeMatrix %f32vec3 3
%f64mat22 = OpTypeMatrix %f64vec2 2

%struct_f32_f32 = OpTypeStruct %f32 %f32
%struct_u32_u32 = OpTypeStruct %u32 %u32
%struct_u32_u32_u32 = OpTypeStruct %u32 %u32 %u32
%struct_s32_s32 = OpTypeStruct %s32 %s32
%struct_s32_u32 = OpTypeStruct %s32 %u32
%struct_u32vec2_u32vec2 = OpTypeStruct %u32vec2 %u32vec2
%struct_s32vec2_s32vec2 = OpTypeStruct %s32vec2 %s32vec2

%f32_0 = OpConstant %f32 0
%f32_1 = OpConstant %f32 1
%f32_2 = OpConstant %f32 2
%f32_3 = OpConstant %f32 3
%f32_4 = OpConstant %f32 4
%f32_pi = OpConstant %f32 3.14159

%s32_0 = OpConstant %s32 0
%s32_1 = OpConstant %s32 1
%s32_2 = OpConstant %s32 2
%s32_3 = OpConstant %s32 3
%s32_4 = OpConstant %s32 4
%s32_m1 = OpConstant %s32 -1

%u32_0 = OpConstant %u32 0
%u32_1 = OpConstant %u32 1
%u32_2 = OpConstant %u32 2
%u32_3 = OpConstant %u32 3
%u32_4 = OpConstant %u32 4

%f64_0 = OpConstant %f64 0
%f64_1 = OpConstant %f64 1
%f64_2 = OpConstant %f64 2
%f64_3 = OpConstant %f64 3
%f64_4 = OpConstant %f64 4

%s64_0 = OpConstant %s64 0
%s64_1 = OpConstant %s64 1
%s64_2 = OpConstant %s64 2
%s64_3 = OpConstant %s64 3
%s64_4 = OpConstant %s64 4
%s64_m1 = OpConstant %s64 -1

%u64_0 = OpConstant %u64 0
%u64_1 = OpConstant %u64 1
%u64_2 = OpConstant %u64 2
%u64_3 = OpConstant %u64 3
%u64_4 = OpConstant %u64 4

%u32vec2_01 = OpConstantComposite %u32vec2 %u32_0 %u32_1
%u32vec2_12 = OpConstantComposite %u32vec2 %u32_1 %u32_2
%u32vec3_012 = OpConstantComposite %u32vec3 %u32_0 %u32_1 %u32_2
%u32vec3_123 = OpConstantComposite %u32vec3 %u32_1 %u32_2 %u32_3
%u32vec4_0123 = OpConstantComposite %u32vec4 %u32_0 %u32_1 %u32_2 %u32_3
%u32vec4_1234 = OpConstantComposite %u32vec4 %u32_1 %u32_2 %u32_3 %u32_4

%s32vec2_01 = OpConstantComposite %s32vec2 %s32_0 %s32_1
%s32vec2_12 = OpConstantComposite %s32vec2 %s32_1 %s32_2
%s32vec3_012 = OpConstantComposite %s32vec3 %s32_0 %s32_1 %s32_2
%s32vec3_123 = OpConstantComposite %s32vec3 %s32_1 %s32_2 %s32_3
%s32vec4_0123 = OpConstantComposite %s32vec4 %s32_0 %s32_1 %s32_2 %s32_3
%s32vec4_1234 = OpConstantComposite %s32vec4 %s32_1 %s32_2 %s32_3 %s32_4

%f32vec2_01 = OpConstantComposite %f32vec2 %f32_0 %f32_1
%f32vec2_12 = OpConstantComposite %f32vec2 %f32_1 %f32_2
%f32vec3_012 = OpConstantComposite %f32vec3 %f32_0 %f32_1 %f32_2
%f32vec3_123 = OpConstantComposite %f32vec3 %f32_1 %f32_2 %f32_3
%f32vec4_0123 = OpConstantComposite %f32vec4 %f32_0 %f32_1 %f32_2 %f32_3
%f32vec4_1234 = OpConstantComposite %f32vec4 %f32_1 %f32_2 %f32_3 %f32_4

%f64vec2_01 = OpConstantComposite %f64vec2 %f64_0 %f64_1
%f64vec2_12 = OpConstantComposite %f64vec2 %f64_1 %f64_2
%f64vec3_012 = OpConstantComposite %f64vec3 %f64_0 %f64_1 %f64_2
%f64vec3_123 = OpConstantComposite %f64vec3 %f64_1 %f64_2 %f64_3
%f64vec4_0123 = OpConstantComposite %f64vec4 %f64_0 %f64_1 %f64_2 %f64_3
%f64vec4_1234 = OpConstantComposite %f64vec4 %f64_1 %f64_2 %f64_3 %f64_4

%f32mat22_1212 = OpConstantComposite %f32mat22 %f32vec2_12 %f32vec2_12
%f32mat23_121212 = OpConstantComposite %f32mat23 %f32vec2_12 %f32vec2_12 %f32vec2_12
%f32mat32_123123 = OpConstantComposite %f32mat32 %f32vec3_123 %f32vec3_123
%f32mat33_123123123 = OpConstantComposite %f32mat33 %f32vec3_123 %f32vec3_123 %f32vec3_123

%f64mat22_1212 = OpConstantComposite %f64mat22 %f64vec2_12 %f64vec2_12

%main = OpFunction %void None %func
%main_entry = OpLabel)";

  // Closes the %main function (and therefore the module).
  static const char kModuleTail[] =
      R"(
OpReturn
OpFunctionEnd)";

  std::string module_text(kModuleHead);
  module_text.append(main_body);
  module_text.append(kModuleTail);
  return module_text;
}
169
TEST_F(ValidateArithmetics, F32Success) {
  // Chains OpFMul/FSub/FAdd/FNegate/FDiv/FRem/FMod on 32-bit float scalars;
  // every instruction is well-typed, so validation must succeed.
  CompileSuccessfully(GenerateCode(R"(
%val1 = OpFMul %f32 %f32_0 %f32_1
%val2 = OpFSub %f32 %f32_2 %f32_0
%val3 = OpFAdd %f32 %val1 %val2
%val4 = OpFNegate %f32 %val3
%val5 = OpFDiv %f32 %val4 %val1
%val6 = OpFRem %f32 %val4 %f32_2
%val7 = OpFMod %f32 %val4 %f32_2
)"));
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
184
TEST_F(ValidateArithmetics, F64Success) {
  // Same opcode chain as F32Success, but on 64-bit float scalars.
  CompileSuccessfully(GenerateCode(R"(
%val1 = OpFMul %f64 %f64_0 %f64_1
%val2 = OpFSub %f64 %f64_2 %f64_0
%val3 = OpFAdd %f64 %val1 %val2
%val4 = OpFNegate %f64 %val3
%val5 = OpFDiv %f64 %val4 %val1
%val6 = OpFRem %f64 %val4 %f64_2
%val7 = OpFMod %f64 %val4 %f64_2
)"));
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
199
TEST_F(ValidateArithmetics, Int32Success) {
  // Valid 32-bit integer arithmetic.
  // - OpIMul/IAdd/ISub are mixed freely between %u32 and %s32 operands and
  //   Result Types here; all of them are expected to validate.
  // - The signed ops (SDiv/SNegate/SRem/SMod) use an %s32 Result Type.
  // - OpUDiv/OpUMod require an unsigned Result Type with operands of exactly
  //   that type. They are included so the unsigned division opcodes also have
  //   a passing case, not just the failure cases in the UDiv* tests.
  const std::string body = R"(
%val1 = OpIMul %u32 %s32_0 %u32_1
%val2 = OpIMul %s32 %s32_2 %u32_1
%val3 = OpIAdd %u32 %val1 %val2
%val4 = OpIAdd %s32 %val1 %val2
%val5 = OpISub %u32 %val3 %val4
%val6 = OpISub %s32 %val4 %val3
%val7 = OpSDiv %s32 %val4 %val3
%val8 = OpSNegate %s32 %val7
%val9 = OpSRem %s32 %val4 %val3
%val10 = OpSMod %s32 %val4 %val3
%val11 = OpUDiv %u32 %val3 %u32_2
%val12 = OpUMod %u32 %val3 %u32_2
)";

  CompileSuccessfully(GenerateCode(body));
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
217
TEST_F(ValidateArithmetics, Int64Success) {
  // 64-bit counterpart of Int32Success (the module declares the Int64
  // capability). OpUDiv/OpUMod on %u64 are included so the unsigned division
  // opcodes have a passing 64-bit case too; their operands have exactly the
  // unsigned Result Type, as those opcodes require.
  const std::string body = R"(
%val1 = OpIMul %u64 %s64_0 %u64_1
%val2 = OpIMul %s64 %s64_2 %u64_1
%val3 = OpIAdd %u64 %val1 %val2
%val4 = OpIAdd %s64 %val1 %val2
%val5 = OpISub %u64 %val3 %val4
%val6 = OpISub %s64 %val4 %val3
%val7 = OpSDiv %s64 %val4 %val3
%val8 = OpSNegate %s64 %val7
%val9 = OpSRem %s64 %val4 %val3
%val10 = OpSMod %s64 %val4 %val3
%val11 = OpUDiv %u64 %val3 %u64_2
%val12 = OpUMod %u64 %val3 %u64_2
)";

  CompileSuccessfully(GenerateCode(body));
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
235
TEST_F(ValidateArithmetics, F32Vec2Success) {
  // Float arithmetic is component-wise on vectors: the F32Success chain
  // repeated with 2-component 32-bit float vectors.
  CompileSuccessfully(GenerateCode(R"(
%val1 = OpFMul %f32vec2 %f32vec2_01 %f32vec2_12
%val2 = OpFSub %f32vec2 %f32vec2_12 %f32vec2_01
%val3 = OpFAdd %f32vec2 %val1 %val2
%val4 = OpFNegate %f32vec2 %val3
%val5 = OpFDiv %f32vec2 %val4 %val1
%val6 = OpFRem %f32vec2 %val4 %f32vec2_12
%val7 = OpFMod %f32vec2 %val4 %f32vec2_12
)"));
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
250
TEST_F(ValidateArithmetics, F64Vec2Success) {
  // Same chain as F32Vec2Success, using 2-component 64-bit float vectors.
  CompileSuccessfully(GenerateCode(R"(
%val1 = OpFMul %f64vec2 %f64vec2_01 %f64vec2_12
%val2 = OpFSub %f64vec2 %f64vec2_12 %f64vec2_01
%val3 = OpFAdd %f64vec2 %val1 %val2
%val4 = OpFNegate %f64vec2 %val3
%val5 = OpFDiv %f64vec2 %val4 %val1
%val6 = OpFRem %f64vec2 %val4 %f64vec2_12
%val7 = OpFMod %f64vec2 %val4 %f64vec2_12
)"));
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
265
TEST_F(ValidateArithmetics, U32Vec2Success) {
  // Component-wise integer arithmetic on 2-component unsigned vectors.
  // The signed opcodes (SNegate/SDiv/SRem/SMod) are expected to accept an
  // unsigned vector Result Type. OpUDiv/OpUMod (which require an unsigned
  // Result Type) are added so the unsigned division opcodes also have a
  // passing vector case.
  const std::string body = R"(
%val1 = OpIMul %u32vec2 %u32vec2_01 %u32vec2_12
%val2 = OpISub %u32vec2 %u32vec2_12 %u32vec2_01
%val3 = OpIAdd %u32vec2 %val1 %val2
%val4 = OpSNegate %u32vec2 %val3
%val5 = OpSDiv %u32vec2 %val4 %val1
%val6 = OpSRem %u32vec2 %val4 %u32vec2_12
%val7 = OpSMod %u32vec2 %val4 %u32vec2_12
%val8 = OpUDiv %u32vec2 %val4 %u32vec2_12
%val9 = OpUMod %u32vec2 %val4 %u32vec2_12
)";

  CompileSuccessfully(GenerateCode(body));
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
280
TEST_F(ValidateArithmetics, FNegateTypeIdU32) {
  // OpFNegate with an unsigned-int scalar Result Type must be rejected.
  CompileSuccessfully(GenerateCode(R"(
%val = OpFNegate %u32 %u32_0
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Expected floating scalar or vector type as Result Type: FNegate"));
}
293
TEST_F(ValidateArithmetics, FNegateTypeIdVec2U32) {
  // OpFNegate with an unsigned-int vector Result Type must be rejected.
  CompileSuccessfully(GenerateCode(R"(
%val = OpFNegate %u32vec2 %u32vec2_01
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Expected floating scalar or vector type as Result Type: FNegate"));
}
306
TEST_F(ValidateArithmetics, FNegateWrongOperand) {
  // The Result Type is %f32, but the operand (operand index 2) is a %u32 value.
  CompileSuccessfully(GenerateCode(R"(
%val = OpFNegate %f32 %u32_0
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected arithmetic operands to be of Result Type: "
                        "FNegate operand index 2"));
}
318
TEST_F(ValidateArithmetics, FMulTypeIdU32) {
  // OpFMul producing an unsigned-int scalar must be rejected.
  CompileSuccessfully(GenerateCode(R"(
%val = OpFMul %u32 %u32_0 %u32_1
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Expected floating scalar or vector type as Result Type: FMul"));
}
331
TEST_F(ValidateArithmetics, FMulTypeIdVec2U32) {
  // OpFMul producing an unsigned-int vector must be rejected.
  CompileSuccessfully(GenerateCode(R"(
%val = OpFMul %u32vec2 %u32vec2_01 %u32vec2_12
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Expected floating scalar or vector type as Result Type: FMul"));
}
344
TEST_F(ValidateArithmetics, FMulWrongOperand1) {
  // The first operand (index 2) is %u32 while the Result Type is %f32.
  CompileSuccessfully(GenerateCode(R"(
%val = OpFMul %f32 %u32_0 %f32_1
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected arithmetic operands to be of Result Type: "
                        "FMul operand index 2"));
}
356
TEST_F(ValidateArithmetics, FMulWrongOperand2) {
  // The second operand (index 3) is %u32 while the Result Type is %f32.
  CompileSuccessfully(GenerateCode(R"(
%val = OpFMul %f32 %f32_0 %u32_1
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected arithmetic operands to be of Result Type: "
                        "FMul operand index 3"));
}
368
TEST_F(ValidateArithmetics, FMulWrongVectorOperand1) {
  // The first vector operand has 32-bit float components; the Result Type is
  // a 64-bit float vector.
  CompileSuccessfully(GenerateCode(R"(
%val = OpFMul %f64vec3 %f32vec3_123 %f64vec3_012
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected arithmetic operands to be of Result Type: "
                        "FMul operand index 2"));
}
380
TEST_F(ValidateArithmetics, FMulWrongVectorOperand2) {
  // The second vector operand has 64-bit float components; the Result Type is
  // a 32-bit float vector.
  CompileSuccessfully(GenerateCode(R"(
%val = OpFMul %f32vec3 %f32vec3_123 %f64vec3_012
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected arithmetic operands to be of Result Type: "
                        "FMul operand index 3"));
}
392
TEST_F(ValidateArithmetics, IMulFloatTypeId) {
  // OpIMul requires an integer Result Type; %f32 is rejected.
  CompileSuccessfully(GenerateCode(R"(
%val = OpIMul %f32 %u32_0 %s32_1
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Expected int scalar or vector type as Result Type: IMul"));
}
404
TEST_F(ValidateArithmetics, IMulFloatOperand1) {
  // A float value as the first OpIMul operand (index 2) is rejected.
  CompileSuccessfully(GenerateCode(R"(
%val = OpIMul %u32 %f32_0 %s32_1
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected int scalar or vector type as operand: "
                        "IMul operand index 2"));
}
416
TEST_F(ValidateArithmetics, IMulFloatOperand2) {
  // A float value as the second OpIMul operand (index 3) is rejected.
  CompileSuccessfully(GenerateCode(R"(
%val = OpIMul %u32 %s32_0 %f32_1
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected int scalar or vector type as operand: "
                        "IMul operand index 3"));
}
428
TEST_F(ValidateArithmetics, IMulWrongBitWidthOperand1) {
  // A 32-bit first operand for a 64-bit Result Type: the bit widths differ.
  CompileSuccessfully(GenerateCode(R"(
%val = OpIMul %u64 %u32_0 %s64_1
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Expected arithmetic operands to have the same bit width "
                "as Result Type: IMul operand index 2"));
}
441
TEST_F(ValidateArithmetics, IMulWrongBitWidthOperand2) {
  // A 64-bit second operand for a 32-bit Result Type: the bit widths differ.
  CompileSuccessfully(GenerateCode(R"(
%val = OpIMul %u32 %u32_0 %s64_1
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Expected arithmetic operands to have the same bit width "
                "as Result Type: IMul operand index 3"));
}
454
TEST_F(ValidateArithmetics, IMulWrongBitWidthVector) {
  // Vectors of 32-bit components multiplied into a 64-bit component vector;
  // the first operand (index 2) is the one reported.
  CompileSuccessfully(GenerateCode(R"(
%val = OpIMul %u64vec3 %u32vec3_012 %u32vec3_123
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Expected arithmetic operands to have the same bit width "
                "as Result Type: IMul operand index 2"));
}
467
TEST_F(ValidateArithmetics, IMulVectorScalarOperand1) {
  // The Result Type is a vector, but the first operand (index 2) is a scalar.
  CompileSuccessfully(GenerateCode(R"(
%val = OpIMul %u32vec2 %u32_0 %u32vec2_01
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Expected arithmetic operands to have the same dimension "
                "as Result Type: IMul operand index 2"));
}
480
TEST_F(ValidateArithmetics, IMulVectorScalarOperand2) {
  // The Result Type is a vector, but the second operand (index 3) is a scalar.
  CompileSuccessfully(GenerateCode(R"(
%val = OpIMul %u32vec2 %u32vec2_01 %u32_0
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Expected arithmetic operands to have the same dimension "
                "as Result Type: IMul operand index 3"));
}
493
TEST_F(ValidateArithmetics, IMulScalarVectorOperand1) {
  // The Result Type is a scalar, but the first operand (index 2) is a vector.
  CompileSuccessfully(GenerateCode(R"(
%val = OpIMul %s32 %u32vec2_01 %u32_0
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Expected arithmetic operands to have the same dimension "
                "as Result Type: IMul operand index 2"));
}
506
TEST_F(ValidateArithmetics, IMulScalarVectorOperand2) {
  // The Result Type is a scalar, but the second operand (index 3) is a vector.
  CompileSuccessfully(GenerateCode(R"(
%val = OpIMul %u32 %u32_0 %s32vec2_01
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Expected arithmetic operands to have the same dimension "
                "as Result Type: IMul operand index 3"));
}
519
TEST_F(ValidateArithmetics, SNegateFloat) {
  // OpSNegate applied to a float value (operand index 2) is rejected.
  CompileSuccessfully(GenerateCode(R"(
%val = OpSNegate %s32 %f32_1
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected int scalar or vector type as operand: "
                        "SNegate operand index 2"));
}
531
TEST_F(ValidateArithmetics, UDivFloatType) {
  // OpUDiv requires an unsigned-int Result Type; %f32 is rejected.
  CompileSuccessfully(GenerateCode(R"(
%val = OpUDiv %f32 %u32_2 %u32_1
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Expected unsigned int scalar or vector type as Result Type: UDiv"));
}
544
TEST_F(ValidateArithmetics, UDivSignedIntType) {
  // A signed-int Result Type is also rejected for OpUDiv.
  CompileSuccessfully(GenerateCode(R"(
%val = OpUDiv %s32 %u32_2 %u32_1
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Expected unsigned int scalar or vector type as Result Type: UDiv"));
}
557
TEST_F(ValidateArithmetics, UDivWrongOperand1) {
  // The first OpUDiv operand (index 2) is %f64, not the %u64 Result Type.
  CompileSuccessfully(GenerateCode(R"(
%val = OpUDiv %u64 %f64_2 %u64_1
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected arithmetic operands to be of Result Type: "
                        "UDiv operand index 2"));
}
569
TEST_F(ValidateArithmetics, UDivWrongOperand2) {
  // The second OpUDiv operand (index 3) is %u32, not the %u64 Result Type.
  CompileSuccessfully(GenerateCode(R"(
%val = OpUDiv %u64 %u64_2 %u32_1
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected arithmetic operands to be of Result Type: "
                        "UDiv operand index 3"));
}
581
TEST_F(ValidateArithmetics, DotSuccess) {
  // Dot product of two f32 2-vectors yields an f32 scalar.
  CompileSuccessfully(GenerateCode(R"(
%val = OpDot %f32 %f32vec2_01 %f32vec2_12
)"));
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
590
TEST_F(ValidateArithmetics, DotWrongTypeId) {
  // OpDot requires a float scalar Result Type; %u32 is rejected.
  CompileSuccessfully(GenerateCode(R"(
%val = OpDot %u32 %u32vec2_01 %u32vec2_12
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected float scalar type as Result Type: Dot"));
}
601
TEST_F(ValidateArithmetics, DotNotVectorTypeOperand1) {
  // The first operand is the type %f32 itself, not a value. This trips the
  // generic "operand cannot be a type" ID check (SPV_ERROR_INVALID_ID) before
  // any OpDot-specific vector check is reached.
  CompileSuccessfully(GenerateCode(R"(
%val = OpDot %f32 %f32 %f32vec2_12
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr("Operand 6[%float] cannot be a "
                                               "type"));
}
612
TEST_F(ValidateArithmetics, DotNotVectorTypeOperand2) {
  // The second operand (index 3) is a float scalar instead of a float vector.
  CompileSuccessfully(GenerateCode(R"(
%val = OpDot %f32 %f32vec3_012 %f32_1
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Expected float vector as operand: Dot operand index 3"));
}
624
TEST_F(ValidateArithmetics, DotWrongComponentOperand1) {
  // The Result Type is %f64, but the first vector (index 2) has f32 components.
  CompileSuccessfully(GenerateCode(R"(
%val = OpDot %f64 %f32vec2_01 %f64vec2_12
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected component type to be equal to Result Type: "
                        "Dot operand index 2"));
}
636
TEST_F(ValidateArithmetics, DotWrongComponentOperand2) {
  // The Result Type is %f32, but the second vector (index 3) has f64
  // components.
  CompileSuccessfully(GenerateCode(R"(
%val = OpDot %f32 %f32vec2_01 %f64vec2_12
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected component type to be equal to Result Type: "
                        "Dot operand index 3"));
}
648
TEST_F(ValidateArithmetics, DotDifferentVectorSize) {
  // A 2-vector dotted with a 3-vector: the component counts differ.
  CompileSuccessfully(GenerateCode(R"(
%val = OpDot %f32 %f32vec2_01 %f32vec3_123
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  // NOTE(review): "componenets" is misspelled. It presumably mirrors the
  // validator's diagnostic text verbatim, so fix both together.
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Expected operands to have the same number of componenets: Dot"));
}
661
TEST_F(ValidateArithmetics, VectorTimesScalarSuccess) {
  // Scaling an f32 2-vector by an f32 scalar is valid.
  CompileSuccessfully(GenerateCode(R"(
%val = OpVectorTimesScalar %f32vec2 %f32vec2_01 %f32_2
)"));
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
670
TEST_F(ValidateArithmetics, VectorTimesScalarWrongTypeId) {
  // The Result Type must be a float vector; %u32vec2 is rejected.
  CompileSuccessfully(GenerateCode(R"(
%val = OpVectorTimesScalar %u32vec2 %f32vec2_01 %f32_2
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected float vector type as Result Type: "
                        "VectorTimesScalar"));
}
682
TEST_F(ValidateArithmetics, VectorTimesScalarWrongVector) {
  // The vector operand is a 3-vector while the Result Type is a 2-vector.
  CompileSuccessfully(GenerateCode(R"(
%val = OpVectorTimesScalar %f32vec2 %f32vec3_012 %f32_2
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Expected vector operand type to be equal to Result Type: "
                "VectorTimesScalar"));
}
695
TEST_F(ValidateArithmetics, VectorTimesScalarWrongScalar) {
  // An f64 scalar cannot scale an f32 vector.
  CompileSuccessfully(GenerateCode(R"(
%val = OpVectorTimesScalar %f32vec2 %f32vec2_01 %f64_2
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Expected scalar operand type to be equal to the component "
                "type of the vector operand: VectorTimesScalar"));
}
708
TEST_F(ValidateArithmetics, MatrixTimesScalarSuccess) {
  // Scaling an f32 2x2 matrix by an f32 scalar is valid.
  CompileSuccessfully(GenerateCode(R"(
%val = OpMatrixTimesScalar %f32mat22 %f32mat22_1212 %f32_2
)"));
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
717
TEST_F(ValidateArithmetics, MatrixTimesScalarWrongTypeId) {
  // The Result Type must be a float matrix; a vector type is rejected.
  CompileSuccessfully(GenerateCode(R"(
%val = OpMatrixTimesScalar %f32vec2 %f32mat22_1212 %f32_2
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected float matrix type as Result Type: "
                        "MatrixTimesScalar"));
}
729
TEST_F(ValidateArithmetics, MatrixTimesScalarWrongMatrix) {
  // A vector is passed where the matrix operand should be.
  CompileSuccessfully(GenerateCode(R"(
%val = OpMatrixTimesScalar %f32mat22 %f32vec2_01 %f32_2
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Expected matrix operand type to be equal to Result Type: "
                "MatrixTimesScalar"));
}
742
TEST_F(ValidateArithmetics, MatrixTimesScalarWrongScalar) {
  // An f64 scalar cannot scale an f32 matrix.
  CompileSuccessfully(GenerateCode(R"(
%val = OpMatrixTimesScalar %f32mat22 %f32mat22_1212 %f64_2
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Expected scalar operand type to be equal to the component "
                "type of the matrix operand: MatrixTimesScalar"));
}
755
TEST_F(ValidateArithmetics, VectorTimesMatrix2x22Success) {
  // A 2-vector times a matrix with two %f32vec2 columns yields a 2-vector.
  CompileSuccessfully(GenerateCode(R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat22_1212
)"));
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
764
TEST_F(ValidateArithmetics, VectorTimesMatrix3x32Success) {
  // %f32mat32 has two %f32vec3 columns. A 3-vector times it yields a 2-vector
  // (one component per column).
  CompileSuccessfully(GenerateCode(R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec3_123 %f32mat32_123123
)"));
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
773
TEST_F(ValidateArithmetics, VectorTimesMatrixWrongTypeId) {
  // The result of OpVectorTimesMatrix must be a float vector, not a matrix.
  CompileSuccessfully(GenerateCode(R"(
%val = OpVectorTimesMatrix %f32mat22 %f32vec2_12 %f32mat22_1212
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected float vector type as Result Type: "
                        "VectorTimesMatrix"));
}
785
TEST_F(ValidateArithmetics, VectorTimesMatrixNotFloatVector) {
  // The left operand is an unsigned-int vector rather than a float vector.
  CompileSuccessfully(GenerateCode(R"(
%val = OpVectorTimesMatrix %f32vec2 %u32vec2_12 %f32mat22_1212
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected float vector type as left operand: "
                        "VectorTimesMatrix"));
}
797
TEST_F(ValidateArithmetics, VectorTimesMatrixWrongVectorComponent) {
  // The left vector has f64 components; the Result Type has f32 components.
  CompileSuccessfully(GenerateCode(R"(
%val = OpVectorTimesMatrix %f32vec2 %f64vec2_12 %f32mat22_1212
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Expected component types of Result Type and vector to be equal: "
          "VectorTimesMatrix"));
}
811
TEST_F(ValidateArithmetics, VectorTimesMatrixWrongMatrix) {
  // The right operand is a vector where a float matrix is required.
  CompileSuccessfully(GenerateCode(R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32vec2_12
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected float matrix type as right operand: "
                        "VectorTimesMatrix"));
}
823
TEST_F(ValidateArithmetics, VectorTimesMatrixWrongMatrixComponent) {
  // The right matrix has f64 components; the Result Type has f32 components.
  CompileSuccessfully(GenerateCode(R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f64mat22_1212
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Expected component types of Result Type and matrix to be equal: "
          "VectorTimesMatrix"));
}
837
TEST_F(ValidateArithmetics, VectorTimesMatrix2eq2x23Fail) {
  // %f32mat23 has three %f32vec2 columns. The product would therefore have 3
  // components, but the Result Type is a 2-vector.
  CompileSuccessfully(GenerateCode(R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat23_121212
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Expected number of columns of the matrix to be equal to Result Type "
          "vector size: VectorTimesMatrix"));
}
851
TEST_F(ValidateArithmetics, VectorTimesMatrix2x32Fail) {
  // %f32mat32 columns are 3-vectors (3 rows), but the left operand is a
  // 2-vector.
  CompileSuccessfully(GenerateCode(R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat32_123123
)"));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Expected number of rows of the matrix to be equal to the vector "
          "operand size: VectorTimesMatrix"));
}
865
TEST_F(ValidateArithmetics, MatrixTimesVector22x2Success) {
  // A 2x2 f32 matrix times an f32 2-vector yields an f32 2-vector.
  CompileSuccessfully(GenerateCode(R"(
%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f32vec2_12
)"));
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
874
TEST_F(ValidateArithmetics, MatrixTimesVector23x3Success) {
  // Valid: %f32mat23_121212 times %f32vec3_123 yields an %f32vec2.
  const std::string body = R"(
%val = OpMatrixTimesVector %f32vec2 %f32mat23_121212 %f32vec3_123
)";

  CompileSuccessfully(GenerateCode(body));
  // Stream the validator's message so an unexpected rejection explains itself
  // instead of only reporting the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
883
TEST_F(ValidateArithmetics, MatrixTimesVectorWrongTypeId) {
  // A matrix Result Type is rejected: OpMatrixTimesVector must produce a
  // float vector.
  const std::string instructions = R"(
%val = OpMatrixTimesVector %f32mat22 %f32mat22_1212 %f32vec2_12
)";
  const std::string expected_error =
      "Expected float vector type as Result Type: MatrixTimesVector";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
895
TEST_F(ValidateArithmetics, MatrixTimesVectorWrongMatrix) {
  // A vector placed in the matrix (left) operand position is rejected.
  const std::string instructions = R"(
%val = OpMatrixTimesVector %f32vec3 %f32vec3_123 %f32vec3_123
)";
  const std::string expected_error =
      "Expected float matrix type as left operand: MatrixTimesVector";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
907
TEST_F(ValidateArithmetics, MatrixTimesVectorWrongMatrixCol) {
  // The column type of %f32mat23_121212 differs from the %f32vec3 Result
  // Type.
  const std::string instructions = R"(
%val = OpMatrixTimesVector %f32vec3 %f32mat23_121212 %f32vec3_123
)";
  const std::string expected_error =
      "Expected column type of the matrix to be equal to Result Type: "
      "MatrixTimesVector";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
921
TEST_F(ValidateArithmetics, MatrixTimesVectorWrongVector) {
  // A non-float vector as the right operand is rejected.
  const std::string instructions = R"(
%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %u32vec2_12
)";
  const std::string expected_error =
      "Expected float vector type as right operand: MatrixTimesVector";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
933
TEST_F(ValidateArithmetics, MatrixTimesVectorDifferentComponents) {
  // Matrix and vector operands with different component types are rejected.
  const std::string instructions = R"(
%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f64vec2_12
)";
  const std::string expected_error =
      "Expected component types of the operands to be equal: "
      "MatrixTimesVector";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
945
TEST_F(ValidateArithmetics, MatrixTimesVector22x3Fail) {
  // The column count of %f32mat22_1212 does not match the size of the vector
  // operand %f32vec3_123.
  const std::string instructions = R"(
%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f32vec3_123
)";
  const std::string expected_error =
      "Expected number of columns of the matrix to be equal to the vector "
      "size: MatrixTimesVector";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
959
TEST_F(ValidateArithmetics, MatrixTimesMatrix22x22Success) {
  // Valid: %f32mat22_1212 times itself yields an %f32mat22.
  const std::string body = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f32mat22_1212
)";

  CompileSuccessfully(GenerateCode(body));
  // Stream the validator's message so an unexpected rejection explains itself
  // instead of only reporting the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
968
TEST_F(ValidateArithmetics, MatrixTimesMatrix23x32Success) {
  // Valid: %f32mat23_121212 times %f32mat32_123123 yields an %f32mat22.
  const std::string body = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat32_123123
)";

  CompileSuccessfully(GenerateCode(body));
  // Stream the validator's message so an unexpected rejection explains itself
  // instead of only reporting the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
977
TEST_F(ValidateArithmetics, MatrixTimesMatrix33x33Success) {
  // Valid: %f32mat33_123123123 times itself yields an %f32mat33.
  const std::string body = R"(
%val = OpMatrixTimesMatrix %f32mat33 %f32mat33_123123123 %f32mat33_123123123
)";

  CompileSuccessfully(GenerateCode(body));
  // Stream the validator's message so an unexpected rejection explains itself
  // instead of only reporting the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
986
TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongTypeId) {
  // A vector Result Type is rejected: OpMatrixTimesMatrix must produce a
  // float matrix.
  const std::string instructions = R"(
%val = OpMatrixTimesMatrix %f32vec2 %f32mat22_1212 %f32mat22_1212
)";
  const std::string expected_error =
      "Expected float matrix type as Result Type: MatrixTimesMatrix";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
999
TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongLeftOperand) {
  // A vector in the left operand position is rejected.
  const std::string instructions = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32vec2_12 %f32mat22_1212
)";
  const std::string expected_error =
      "Expected float matrix type as left operand: MatrixTimesMatrix";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1012
TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongRightOperand) {
  // A vector in the right operand position is rejected.
  const std::string instructions = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f32vec2_12
)";
  const std::string expected_error =
      "Expected float matrix type as right operand: MatrixTimesMatrix";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1025
TEST_F(ValidateArithmetics, MatrixTimesMatrix32x23Fail) {
  // The column type of the left matrix %f32mat32_123123 differs from the
  // column type of the %f32mat22 Result Type.
  const std::string instructions = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat32_123123 %f32mat23_121212
)";
  const std::string expected_error =
      "Expected column types of Result Type and left matrix to be equal: "
      "MatrixTimesMatrix";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1039
TEST_F(ValidateArithmetics, MatrixTimesMatrixDifferentComponents) {
  // The right matrix's component type differs from the Result Type's.
  const std::string instructions = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f64mat22_1212
)";
  const std::string expected_error =
      "Expected component types of Result Type and right matrix to be "
      "equal: MatrixTimesMatrix";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1052
TEST_F(ValidateArithmetics, MatrixTimesMatrix23x23Fail) {
  // The right matrix and the %f32mat22 Result Type differ in column count.
  const std::string instructions = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat23_121212
)";
  const std::string expected_error =
      "Expected number of columns of Result Type and right matrix to be "
      "equal: MatrixTimesMatrix";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1065
TEST_F(ValidateArithmetics, MatrixTimesMatrix23x22Fail) {
  // The left matrix's column count does not match the right matrix's row
  // count.
  const std::string instructions = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat22_1212
)";
  const std::string expected_error =
      "Expected number of columns of left matrix and number of rows of "
      "right matrix to be equal: MatrixTimesMatrix";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1078
TEST_F(ValidateArithmetics, OuterProduct2x2Success) {
  // Valid: the outer product of two %f32vec2 values yields an %f32mat22.
  const std::string body = R"(
%val = OpOuterProduct %f32mat22 %f32vec2_12 %f32vec2_01
)";

  CompileSuccessfully(GenerateCode(body));
  // Stream the validator's message so an unexpected rejection explains itself
  // instead of only reporting the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
1087
TEST_F(ValidateArithmetics, OuterProduct3x2Success) {
  // Valid: %f32vec3_123 (left) outer %f32vec2_01 (right) yields an %f32mat32.
  const std::string body = R"(
%val = OpOuterProduct %f32mat32 %f32vec3_123 %f32vec2_01
)";

  CompileSuccessfully(GenerateCode(body));
  // Stream the validator's message so an unexpected rejection explains itself
  // instead of only reporting the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
1096
TEST_F(ValidateArithmetics, OuterProduct2x3Success) {
  // Valid: %f32vec2_01 (left) outer %f32vec3_123 (right) yields an %f32mat23.
  const std::string body = R"(
%val = OpOuterProduct %f32mat23 %f32vec2_01 %f32vec3_123
)";

  CompileSuccessfully(GenerateCode(body));
  // Stream the validator's message so an unexpected rejection explains itself
  // instead of only reporting the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
1105
TEST_F(ValidateArithmetics, OuterProductWrongTypeId) {
  // A vector Result Type is rejected: OpOuterProduct must produce a float
  // matrix.
  const std::string instructions = R"(
%val = OpOuterProduct %f32vec2 %f32vec2_01 %f32vec3_123
)";
  const std::string expected_error =
      "Expected float matrix type as Result Type: OuterProduct";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1117
TEST_F(ValidateArithmetics, OuterProductWrongLeftOperand) {
  // The left operand's type does not match the column type of the %f32mat22
  // Result Type.
  const std::string instructions = R"(
%val = OpOuterProduct %f32mat22 %f32vec3_123 %f32vec2_01
)";
  const std::string expected_error =
      "Expected column type of Result Type to be equal to the type of the "
      "left operand: OuterProduct";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1130
TEST_F(ValidateArithmetics, OuterProductRightOperandNotFloatVector) {
  // A non-float vector as the right operand is rejected.
  const std::string instructions = R"(
%val = OpOuterProduct %f32mat22 %f32vec2_12 %u32vec2_01
)";
  const std::string expected_error =
      "Expected float vector type as right operand: OuterProduct";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1142
TEST_F(ValidateArithmetics, OuterProductRightOperandWrongComponent) {
  // Operands with different component types are rejected.
  const std::string instructions = R"(
%val = OpOuterProduct %f32mat22 %f32vec2_12 %f64vec2_01
)";
  const std::string expected_error =
      "Expected component types of the operands to be equal: OuterProduct";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1154
TEST_F(ValidateArithmetics, OuterProductRightOperandWrongDimension) {
  // The right operand's size does not match the column count of the
  // %f32mat22 Result Type.
  const std::string instructions = R"(
%val = OpOuterProduct %f32mat22 %f32vec2_12 %f32vec3_123
)";
  const std::string expected_error =
      "Expected number of columns of the matrix to be equal to the vector "
      "size of the right operand: OuterProduct";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1167
// Assembles a complete SPIR-V module (assembly text) for the
// SPV_NV_cooperative_matrix tests.
//
// The returned text consists of, in order:
//   1. the Shader / Float16 / CooperativeMatrixNV capabilities, the
//      extension, a GLCompute entry point %main, and the shared declarations:
//      scalar types, f16/u32/s32 cooperative matrix types sized by %u32_8,
//      one constant of each built from a single scalar constituent, a matrix
//      type sized by spec constants (%f16matc), and the 16x4 / 4x16 / 16x16
//      f16 matrix types with their constants;
//   2. |extra_types|, verbatim (extra declarations for a single test);
//   3. the header of %main (OpFunction plus its entry label);
//   4. |main_body|, verbatim;
//   5. OpReturn and OpFunctionEnd.
std::string GenerateCoopMatCode(const std::string& extra_types,
                                const std::string& main_body) {
  // Module header and every type/constant shared by the cooperative matrix
  // tests. All matrix types here use %subgroup as their scope operand.
  const std::string declarations =
      R"(
OpCapability Shader
OpCapability Float16
OpCapability CooperativeMatrixNV
OpExtension "SPV_NV_cooperative_matrix"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f16 = OpTypeFloat 16
%f32 = OpTypeFloat 32
%u32 = OpTypeInt 32 0
%s32 = OpTypeInt 32 1

%u32_8 = OpConstant %u32 8
%u32_16 = OpConstant %u32 16
%u32_4 = OpConstant %u32 4
%subgroup = OpConstant %u32 3

%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
%u32mat = OpTypeCooperativeMatrixNV %u32 %subgroup %u32_8 %u32_8
%s32mat = OpTypeCooperativeMatrixNV %s32 %subgroup %u32_8 %u32_8

%f16_1 = OpConstant %f16 1
%f32_1 = OpConstant %f32 1
%u32_1 = OpConstant %u32 1
%s32_1 = OpConstant %s32 1

%f16mat_1 = OpConstantComposite %f16mat %f16_1
%u32mat_1 = OpConstantComposite %u32mat %u32_1
%s32mat_1 = OpConstantComposite %s32mat %s32_1

%u32_c1 = OpSpecConstant %u32 1
%u32_c2 = OpSpecConstant %u32 2

%f16matc = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_c1 %u32_c2
%f16matc_1 = OpConstantComposite %f16matc %f16_1

%mat16x4 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_16 %u32_4
%mat4x16 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_4 %u32_16
%mat16x16 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_16 %u32_16
%f16mat_16x4_1 = OpConstantComposite %mat16x4 %f16_1
%f16mat_4x16_1 = OpConstantComposite %mat4x16 %f16_1
%f16mat_16x16_1 = OpConstantComposite %mat16x16 %f16_1)";

  std::string module_text = declarations;
  module_text += extra_types;

  // Open %main; the caller supplies its body.
  module_text += R"(
%main = OpFunction %void None %func
%main_entry = OpLabel)";
  module_text += main_body;

  // Every test body falls through to this common terminator.
  module_text += R"(
OpReturn
OpFunctionEnd)";
  return module_text;
}
1229
TEST_F(ValidateArithmetics, CoopMatSuccess) {
  // Element-wise add/sub/div/negate, OpMatrixTimesScalar, and
  // OpCooperativeMatrixMulAddNV (including on the spec-constant-sized
  // %f16matc) are all accepted on cooperative matrix types.
  const std::string body = R"(
%val1 = OpFAdd %f16mat %f16mat_1 %f16mat_1
%val2 = OpFSub %f16mat %f16mat_1 %f16mat_1
%val3 = OpFDiv %f16mat %f16mat_1 %f16mat_1
%val4 = OpFNegate %f16mat %f16mat_1
%val5 = OpIAdd %u32mat %u32mat_1 %u32mat_1
%val6 = OpISub %u32mat %u32mat_1 %u32mat_1
%val7 = OpUDiv %u32mat %u32mat_1 %u32mat_1
%val8 = OpIAdd %s32mat %s32mat_1 %s32mat_1
%val9 = OpISub %s32mat %s32mat_1 %s32mat_1
%val10 = OpSDiv %s32mat %s32mat_1 %s32mat_1
%val11 = OpSNegate %s32mat %s32mat_1
%val12 = OpMatrixTimesScalar %f16mat %f16mat_1 %f16_1
%val13 = OpMatrixTimesScalar %u32mat %u32mat_1 %u32_1
%val14 = OpMatrixTimesScalar %s32mat %s32mat_1 %s32_1
%val15 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16mat_16x16_1
%val16 = OpCooperativeMatrixMulAddNV %f16matc %f16matc_1 %f16matc_1 %f16matc_1
)";

  CompileSuccessfully(GenerateCoopMatCode("", body));
  // Stream the validator's message so an unexpected rejection names the
  // offending instruction instead of only reporting the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
1253
TEST_F(ValidateArithmetics, CoopMatFMulFail) {
  // OpFMul is not accepted on cooperative matrices; the validator reports
  // that it expected a float scalar or vector Result Type.
  const std::string instructions = R"(
%val1 = OpFMul %f16mat %f16mat_1 %f16mat_1
)";
  const std::string expected_error =
      "Expected floating scalar or vector type as Result Type: FMul";

  CompileSuccessfully(GenerateCoopMatCode("", instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1266
TEST_F(ValidateArithmetics, CoopMatMatrixTimesScalarMismatchFail) {
  // The scalar operand %f32_1 is an f32 while %f16mat has f16 components.
  const std::string instructions = R"(
%val1 = OpMatrixTimesScalar %f16mat %f16mat_1 %f32_1
)";
  const std::string expected_error =
      "Expected scalar operand type to be equal to the component type of "
      "the matrix operand: MatrixTimesScalar";

  CompileSuccessfully(GenerateCoopMatCode("", instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1279
TEST_F(ValidateArithmetics, CoopMatScopeFail) {
  // The accumulator is declared with %workgroup scope while A and B use
  // %subgroup, so the MulAdd operands' scopes do not match.
  const std::string extra_types = R"(
%workgroup = OpConstant %u32 2

%mat16x16_wg = OpTypeCooperativeMatrixNV %f16 %workgroup %u32_16 %u32_16
%f16matwg_16x16_1 = OpConstantComposite %mat16x16_wg %f16_1
)";
  const std::string instructions = R"(
%val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16matwg_16x16_1
)";
  const std::string expected_error =
      "Cooperative matrix scopes must match: CooperativeMatrixMulAddNV";

  CompileSuccessfully(GenerateCoopMatCode(extra_types, instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1299
TEST_F(ValidateArithmetics, CoopMatDimFail) {
  // A and B are swapped relative to the valid MulAdd in CoopMatSuccess, so
  // the 'M' dimension no longer agrees with the %mat16x16 result.
  const std::string instructions = R"(
%val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_4x16_1 %f16mat_16x4_1 %f16mat_16x16_1
)";
  const std::string expected_error =
      "Cooperative matrix 'M' mismatch: CooperativeMatrixMulAddNV";

  CompileSuccessfully(GenerateCoopMatCode("", instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1311
TEST_F(ValidateArithmetics, IAddCarrySuccess) {
  // Valid: scalar and two-component vector unsigned forms of OpIAddCarry.
  const std::string body = R"(
%val1 = OpIAddCarry %struct_u32_u32 %u32_0 %u32_1
%val2 = OpIAddCarry %struct_u32vec2_u32vec2 %u32vec2_01 %u32vec2_12
)";

  CompileSuccessfully(GenerateCode(body));
  // Stream the validator's message so an unexpected rejection explains itself
  // instead of only reporting the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
1321
TEST_F(ValidateArithmetics, IAddCarryResultTypeNotStruct) {
  // A scalar Result Type is rejected: OpIAddCarry must produce a struct.
  const std::string instructions = R"(
%val = OpIAddCarry %u32 %u32_0 %u32_1
)";
  const std::string expected_error =
      "Expected a struct as Result Type: IAddCarry";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1332
TEST_F(ValidateArithmetics, IAddCarryResultTypeNotTwoMembers) {
  // A struct Result Type with other than two members is rejected.
  const std::string instructions = R"(
%val = OpIAddCarry %struct_u32_u32_u32 %u32_0 %u32_1
)";
  const std::string expected_error =
      "Expected Result Type struct to have two members: IAddCarry";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1344
TEST_F(ValidateArithmetics, IAddCarryResultTypeMemberNotUnsignedInt) {
  // Struct members that are not unsigned integers are rejected.
  const std::string instructions = R"(
%val = OpIAddCarry %struct_s32_s32 %s32_0 %s32_1
)";
  const std::string expected_error =
      "Expected Result Type struct member types to be unsigned integer "
      "scalar or vector: IAddCarry";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1357
TEST_F(ValidateArithmetics, IAddCarryWrongLeftOperand) {
  // The left operand's type differs from the Result Type member type.
  const std::string instructions = R"(
%val = OpIAddCarry %struct_u32_u32 %s32_0 %u32_1
)";
  const std::string expected_error =
      "Expected both operands to be of Result Type member type: IAddCarry";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1369
TEST_F(ValidateArithmetics, IAddCarryWrongRightOperand) {
  // The right operand's type differs from the Result Type member type.
  const std::string instructions = R"(
%val = OpIAddCarry %struct_u32_u32 %u32_0 %s32_1
)";
  const std::string expected_error =
      "Expected both operands to be of Result Type member type: IAddCarry";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1381
TEST_F(ValidateArithmetics, OpSMulExtendedSuccess) {
  // Valid: OpSMulExtended on unsigned and signed integers, in both scalar and
  // two-component vector forms.
  const std::string body = R"(
%val1 = OpSMulExtended %struct_u32_u32 %u32_0 %u32_1
%val2 = OpSMulExtended %struct_s32_s32 %s32_0 %s32_1
%val3 = OpSMulExtended %struct_u32vec2_u32vec2 %u32vec2_01 %u32vec2_12
%val4 = OpSMulExtended %struct_s32vec2_s32vec2 %s32vec2_01 %s32vec2_12
)";

  CompileSuccessfully(GenerateCode(body));
  // Stream the validator's message so an unexpected rejection explains itself
  // instead of only reporting the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
1393
TEST_F(ValidateArithmetics, SMulExtendedResultTypeMemberNotInt) {
  // Non-integer struct members are rejected for OpSMulExtended.
  const std::string instructions = R"(
%val = OpSMulExtended %struct_f32_f32 %f32_0 %f32_1
)";
  const std::string expected_error =
      "Expected Result Type struct member types to be integer scalar or "
      "vector: SMulExtended";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1406
TEST_F(ValidateArithmetics, SMulExtendedResultTypeMembersNotIdentical) {
  // A Result Type struct whose two members differ in type is rejected.
  const std::string instructions = R"(
%val = OpSMulExtended %struct_s32_u32 %s32_0 %s32_1
)";
  const std::string expected_error =
      "Expected Result Type struct member types to be identical: "
      "SMulExtended";

  CompileSuccessfully(GenerateCode(instructions));
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
1419
1420 } // namespace
1421 } // namespace val
1422 } // namespace spvtools
1423