1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
// Tests for the arithmetic instruction validation rules.
16
17 #include <string>
18
19 #include "gmock/gmock.h"
20 #include "test/unit_spirv.h"
21 #include "test/val/val_fixtures.h"
22
23 namespace spvtools {
24 namespace val {
25 namespace {
26
27 using ::testing::HasSubstr;
28 using ::testing::Not;
29
30 using ValidateArithmetics = spvtest::ValidateBase<bool>;
31
// Builds a complete SPIR-V assembly module around |main_body|.
//
// The fixed header enables the Shader, Int64, Float64 and Matrix capabilities,
// declares a Fragment entry point %main, defines the scalar, vector, matrix
// and struct types plus the constants used by the tests, and opens %main with
// the %main_entry label. |main_body| is inserted verbatim right after that
// label, followed by a footer that returns and ends the function.
std::string GenerateCode(const std::string& main_body) {
  const char* const header =
      R"(
OpCapability Shader
OpCapability Int64
OpCapability Float64
OpCapability Matrix
%ext_inst = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main"
OpExecutionMode %main OriginUpperLeft
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f32 = OpTypeFloat 32
%u32 = OpTypeInt 32 0
%s32 = OpTypeInt 32 1
%f64 = OpTypeFloat 64
%u64 = OpTypeInt 64 0
%s64 = OpTypeInt 64 1
%boolvec2 = OpTypeVector %bool 2
%s32vec2 = OpTypeVector %s32 2
%u32vec2 = OpTypeVector %u32 2
%u64vec2 = OpTypeVector %u64 2
%f32vec2 = OpTypeVector %f32 2
%f64vec2 = OpTypeVector %f64 2
%boolvec3 = OpTypeVector %bool 3
%u32vec3 = OpTypeVector %u32 3
%u64vec3 = OpTypeVector %u64 3
%s32vec3 = OpTypeVector %s32 3
%f32vec3 = OpTypeVector %f32 3
%f64vec3 = OpTypeVector %f64 3
%boolvec4 = OpTypeVector %bool 4
%u32vec4 = OpTypeVector %u32 4
%u64vec4 = OpTypeVector %u64 4
%s32vec4 = OpTypeVector %s32 4
%f32vec4 = OpTypeVector %f32 4
%f64vec4 = OpTypeVector %f64 4

%f32mat22 = OpTypeMatrix %f32vec2 2
%f32mat23 = OpTypeMatrix %f32vec2 3
%f32mat32 = OpTypeMatrix %f32vec3 2
%f32mat33 = OpTypeMatrix %f32vec3 3
%f64mat22 = OpTypeMatrix %f64vec2 2

%struct_f32_f32 = OpTypeStruct %f32 %f32
%struct_u32_u32 = OpTypeStruct %u32 %u32
%struct_u32_u32_u32 = OpTypeStruct %u32 %u32 %u32
%struct_s32_s32 = OpTypeStruct %s32 %s32
%struct_s32_u32 = OpTypeStruct %s32 %u32
%struct_u32vec2_u32vec2 = OpTypeStruct %u32vec2 %u32vec2
%struct_s32vec2_s32vec2 = OpTypeStruct %s32vec2 %s32vec2

%f32_0 = OpConstant %f32 0
%f32_1 = OpConstant %f32 1
%f32_2 = OpConstant %f32 2
%f32_3 = OpConstant %f32 3
%f32_4 = OpConstant %f32 4
%f32_pi = OpConstant %f32 3.14159

%s32_0 = OpConstant %s32 0
%s32_1 = OpConstant %s32 1
%s32_2 = OpConstant %s32 2
%s32_3 = OpConstant %s32 3
%s32_4 = OpConstant %s32 4
%s32_m1 = OpConstant %s32 -1

%u32_0 = OpConstant %u32 0
%u32_1 = OpConstant %u32 1
%u32_2 = OpConstant %u32 2
%u32_3 = OpConstant %u32 3
%u32_4 = OpConstant %u32 4

%f64_0 = OpConstant %f64 0
%f64_1 = OpConstant %f64 1
%f64_2 = OpConstant %f64 2
%f64_3 = OpConstant %f64 3
%f64_4 = OpConstant %f64 4

%s64_0 = OpConstant %s64 0
%s64_1 = OpConstant %s64 1
%s64_2 = OpConstant %s64 2
%s64_3 = OpConstant %s64 3
%s64_4 = OpConstant %s64 4
%s64_m1 = OpConstant %s64 -1

%u64_0 = OpConstant %u64 0
%u64_1 = OpConstant %u64 1
%u64_2 = OpConstant %u64 2
%u64_3 = OpConstant %u64 3
%u64_4 = OpConstant %u64 4

%u32vec2_01 = OpConstantComposite %u32vec2 %u32_0 %u32_1
%u32vec2_12 = OpConstantComposite %u32vec2 %u32_1 %u32_2
%u32vec3_012 = OpConstantComposite %u32vec3 %u32_0 %u32_1 %u32_2
%u32vec3_123 = OpConstantComposite %u32vec3 %u32_1 %u32_2 %u32_3
%u32vec4_0123 = OpConstantComposite %u32vec4 %u32_0 %u32_1 %u32_2 %u32_3
%u32vec4_1234 = OpConstantComposite %u32vec4 %u32_1 %u32_2 %u32_3 %u32_4

%s32vec2_01 = OpConstantComposite %s32vec2 %s32_0 %s32_1
%s32vec2_12 = OpConstantComposite %s32vec2 %s32_1 %s32_2
%s32vec3_012 = OpConstantComposite %s32vec3 %s32_0 %s32_1 %s32_2
%s32vec3_123 = OpConstantComposite %s32vec3 %s32_1 %s32_2 %s32_3
%s32vec4_0123 = OpConstantComposite %s32vec4 %s32_0 %s32_1 %s32_2 %s32_3
%s32vec4_1234 = OpConstantComposite %s32vec4 %s32_1 %s32_2 %s32_3 %s32_4

%f32vec2_01 = OpConstantComposite %f32vec2 %f32_0 %f32_1
%f32vec2_12 = OpConstantComposite %f32vec2 %f32_1 %f32_2
%f32vec3_012 = OpConstantComposite %f32vec3 %f32_0 %f32_1 %f32_2
%f32vec3_123 = OpConstantComposite %f32vec3 %f32_1 %f32_2 %f32_3
%f32vec4_0123 = OpConstantComposite %f32vec4 %f32_0 %f32_1 %f32_2 %f32_3
%f32vec4_1234 = OpConstantComposite %f32vec4 %f32_1 %f32_2 %f32_3 %f32_4

%f64vec2_01 = OpConstantComposite %f64vec2 %f64_0 %f64_1
%f64vec2_12 = OpConstantComposite %f64vec2 %f64_1 %f64_2
%f64vec3_012 = OpConstantComposite %f64vec3 %f64_0 %f64_1 %f64_2
%f64vec3_123 = OpConstantComposite %f64vec3 %f64_1 %f64_2 %f64_3
%f64vec4_0123 = OpConstantComposite %f64vec4 %f64_0 %f64_1 %f64_2 %f64_3
%f64vec4_1234 = OpConstantComposite %f64vec4 %f64_1 %f64_2 %f64_3 %f64_4

%f32mat22_1212 = OpConstantComposite %f32mat22 %f32vec2_12 %f32vec2_12
%f32mat23_121212 = OpConstantComposite %f32mat23 %f32vec2_12 %f32vec2_12 %f32vec2_12
%f32mat32_123123 = OpConstantComposite %f32mat32 %f32vec3_123 %f32vec3_123
%f32mat33_123123123 = OpConstantComposite %f32mat33 %f32vec3_123 %f32vec3_123 %f32vec3_123

%f64mat22_1212 = OpConstantComposite %f64mat22 %f64vec2_12 %f64vec2_12

%main = OpFunction %void None %func
%main_entry = OpLabel)";

  // Closes the function opened at the end of the header.
  const char* const footer =
      R"(
OpReturn
OpFunctionEnd)";

  std::string assembly(header);
  assembly += main_body;
  assembly += footer;
  return assembly;
}
169
// Scalar 32-bit float arithmetic (FMul, FSub, FAdd, FNegate, FDiv, FRem, FMod)
// must validate successfully.
TEST_F(ValidateArithmetics, F32Success) {
  const std::string instructions = R"(
%val1 = OpFMul %f32 %f32_0 %f32_1
%val2 = OpFSub %f32 %f32_2 %f32_0
%val3 = OpFAdd %f32 %val1 %val2
%val4 = OpFNegate %f32 %val3
%val5 = OpFDiv %f32 %val4 %val1
%val6 = OpFRem %f32 %val4 %f32_2
%val7 = OpFMod %f32 %val4 %f32_2
)";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
184
// Scalar 64-bit float arithmetic (FMul, FSub, FAdd, FNegate, FDiv, FRem, FMod)
// must validate successfully.
TEST_F(ValidateArithmetics, F64Success) {
  const std::string instructions = R"(
%val1 = OpFMul %f64 %f64_0 %f64_1
%val2 = OpFSub %f64 %f64_2 %f64_0
%val3 = OpFAdd %f64 %val1 %val2
%val4 = OpFNegate %f64 %val3
%val5 = OpFDiv %f64 %val4 %val1
%val6 = OpFRem %f64 %val4 %f64_2
%val7 = OpFMod %f64 %val4 %f64_2
)";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
199
// 32-bit integer arithmetic that mixes signed and unsigned operands and result
// types must validate successfully.
TEST_F(ValidateArithmetics, Int32Success) {
  const std::string instructions = R"(
%val1 = OpIMul %u32 %s32_0 %u32_1
%val2 = OpIMul %s32 %s32_2 %u32_1
%val3 = OpIAdd %u32 %val1 %val2
%val4 = OpIAdd %s32 %val1 %val2
%val5 = OpISub %u32 %val3 %val4
%val6 = OpISub %s32 %val4 %val3
%val7 = OpSDiv %s32 %val4 %val3
%val8 = OpSNegate %s32 %val7
%val9 = OpSRem %s32 %val4 %val3
%val10 = OpSMod %s32 %val4 %val3
)";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
217
// 64-bit integer arithmetic that mixes signed and unsigned operands and result
// types must validate successfully.
TEST_F(ValidateArithmetics, Int64Success) {
  const std::string instructions = R"(
%val1 = OpIMul %u64 %s64_0 %u64_1
%val2 = OpIMul %s64 %s64_2 %u64_1
%val3 = OpIAdd %u64 %val1 %val2
%val4 = OpIAdd %s64 %val1 %val2
%val5 = OpISub %u64 %val3 %val4
%val6 = OpISub %s64 %val4 %val3
%val7 = OpSDiv %s64 %val4 %val3
%val8 = OpSNegate %s64 %val7
%val9 = OpSRem %s64 %val4 %val3
%val10 = OpSMod %s64 %val4 %val3
)";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
235
// Float arithmetic on 2-component 32-bit float vectors must validate
// successfully.
TEST_F(ValidateArithmetics, F32Vec2Success) {
  const std::string instructions = R"(
%val1 = OpFMul %f32vec2 %f32vec2_01 %f32vec2_12
%val2 = OpFSub %f32vec2 %f32vec2_12 %f32vec2_01
%val3 = OpFAdd %f32vec2 %val1 %val2
%val4 = OpFNegate %f32vec2 %val3
%val5 = OpFDiv %f32vec2 %val4 %val1
%val6 = OpFRem %f32vec2 %val4 %f32vec2_12
%val7 = OpFMod %f32vec2 %val4 %f32vec2_12
)";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
250
// Float arithmetic on 2-component 64-bit float vectors must validate
// successfully.
TEST_F(ValidateArithmetics, F64Vec2Success) {
  const std::string instructions = R"(
%val1 = OpFMul %f64vec2 %f64vec2_01 %f64vec2_12
%val2 = OpFSub %f64vec2 %f64vec2_12 %f64vec2_01
%val3 = OpFAdd %f64vec2 %val1 %val2
%val4 = OpFNegate %f64vec2 %val3
%val5 = OpFDiv %f64vec2 %val4 %val1
%val6 = OpFRem %f64vec2 %val4 %f64vec2_12
%val7 = OpFMod %f64vec2 %val4 %f64vec2_12
)";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
265
// Integer arithmetic on 2-component unsigned vectors, including the signed
// opcodes (SNegate, SDiv, SRem, SMod), must validate successfully.
TEST_F(ValidateArithmetics, U32Vec2Success) {
  const std::string instructions = R"(
%val1 = OpIMul %u32vec2 %u32vec2_01 %u32vec2_12
%val2 = OpISub %u32vec2 %u32vec2_12 %u32vec2_01
%val3 = OpIAdd %u32vec2 %val1 %val2
%val4 = OpSNegate %u32vec2 %val3
%val5 = OpSDiv %u32vec2 %val4 %val1
%val6 = OpSRem %u32vec2 %val4 %u32vec2_12
%val7 = OpSMod %u32vec2 %val4 %u32vec2_12
)";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
280
// OpFNegate with an unsigned int scalar Result Type must be rejected.
TEST_F(ValidateArithmetics, FNegateTypeIdU32) {
  const std::string instructions = R"(
%val = OpFNegate %u32 %u32_0
)";
  const std::string expected_error =
      "Expected floating scalar or vector type as Result Type: FNegate";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
293
// OpFNegate with an unsigned int vector Result Type must be rejected.
TEST_F(ValidateArithmetics, FNegateTypeIdVec2U32) {
  const std::string instructions = R"(
%val = OpFNegate %u32vec2 %u32vec2_01
)";
  const std::string expected_error =
      "Expected floating scalar or vector type as Result Type: FNegate";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
306
// OpFNegate with a float Result Type but an unsigned int operand must be
// rejected, naming operand index 2.
TEST_F(ValidateArithmetics, FNegateWrongOperand) {
  const std::string instructions = R"(
%val = OpFNegate %f32 %u32_0
)";
  const std::string expected_error =
      "Expected arithmetic operands to be of Result Type: "
      "FNegate operand index 2";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
318
// OpFMul with an unsigned int scalar Result Type must be rejected.
TEST_F(ValidateArithmetics, FMulTypeIdU32) {
  const std::string instructions = R"(
%val = OpFMul %u32 %u32_0 %u32_1
)";
  const std::string expected_error =
      "Expected floating scalar or vector type as Result Type: FMul";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
331
// OpFMul with an unsigned int vector Result Type must be rejected.
TEST_F(ValidateArithmetics, FMulTypeIdVec2U32) {
  const std::string instructions = R"(
%val = OpFMul %u32vec2 %u32vec2_01 %u32vec2_12
)";
  const std::string expected_error =
      "Expected floating scalar or vector type as Result Type: FMul";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
344
// OpFMul whose first operand is an unsigned int instead of the float Result
// Type must be rejected, naming operand index 2.
TEST_F(ValidateArithmetics, FMulWrongOperand1) {
  const std::string instructions = R"(
%val = OpFMul %f32 %u32_0 %f32_1
)";
  const std::string expected_error =
      "Expected arithmetic operands to be of Result Type: "
      "FMul operand index 2";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
356
// OpFMul whose second operand is an unsigned int instead of the float Result
// Type must be rejected, naming operand index 3.
TEST_F(ValidateArithmetics, FMulWrongOperand2) {
  const std::string instructions = R"(
%val = OpFMul %f32 %f32_0 %u32_1
)";
  const std::string expected_error =
      "Expected arithmetic operands to be of Result Type: "
      "FMul operand index 3";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
368
// OpFMul with a 64-bit float vector Result Type and a 32-bit float vector as
// first operand must be rejected, naming operand index 2.
TEST_F(ValidateArithmetics, FMulWrongVectorOperand1) {
  const std::string instructions = R"(
%val = OpFMul %f64vec3 %f32vec3_123 %f64vec3_012
)";
  const std::string expected_error =
      "Expected arithmetic operands to be of Result Type: "
      "FMul operand index 2";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
380
// OpFMul with a 32-bit float vector Result Type and a 64-bit float vector as
// second operand must be rejected, naming operand index 3.
TEST_F(ValidateArithmetics, FMulWrongVectorOperand2) {
  const std::string instructions = R"(
%val = OpFMul %f32vec3 %f32vec3_123 %f64vec3_012
)";
  const std::string expected_error =
      "Expected arithmetic operands to be of Result Type: "
      "FMul operand index 3";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
392
// OpIMul with a float Result Type must be rejected.
TEST_F(ValidateArithmetics, IMulFloatTypeId) {
  const std::string instructions = R"(
%val = OpIMul %f32 %u32_0 %s32_1
)";
  const std::string expected_error =
      "Expected int scalar or vector type as Result Type: IMul";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
404
// OpIMul with a float first operand must be rejected, naming operand index 2.
TEST_F(ValidateArithmetics, IMulFloatOperand1) {
  const std::string instructions = R"(
%val = OpIMul %u32 %f32_0 %s32_1
)";
  const std::string expected_error =
      "Expected int scalar or vector type as operand: "
      "IMul operand index 2";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
416
// OpIMul with a float second operand must be rejected, naming operand index 3.
TEST_F(ValidateArithmetics, IMulFloatOperand2) {
  const std::string instructions = R"(
%val = OpIMul %u32 %s32_0 %f32_1
)";
  const std::string expected_error =
      "Expected int scalar or vector type as operand: "
      "IMul operand index 3";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
428
// OpIMul with a 64-bit Result Type and a 32-bit first operand must be
// rejected, naming operand index 2.
TEST_F(ValidateArithmetics, IMulWrongBitWidthOperand1) {
  const std::string instructions = R"(
%val = OpIMul %u64 %u32_0 %s64_1
)";
  const std::string expected_error =
      "Expected arithmetic operands to have the same bit width "
      "as Result Type: IMul operand index 2";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
441
// OpIMul with a 32-bit Result Type and a 64-bit second operand must be
// rejected, naming operand index 3.
TEST_F(ValidateArithmetics, IMulWrongBitWidthOperand2) {
  const std::string instructions = R"(
%val = OpIMul %u32 %u32_0 %s64_1
)";
  const std::string expected_error =
      "Expected arithmetic operands to have the same bit width "
      "as Result Type: IMul operand index 3";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
454
// OpIMul with a 64-bit vector Result Type and 32-bit vector operands must be
// rejected; the diagnostic names operand index 2.
TEST_F(ValidateArithmetics, IMulWrongBitWidthVector) {
  const std::string instructions = R"(
%val = OpIMul %u64vec3 %u32vec3_012 %u32vec3_123
)";
  const std::string expected_error =
      "Expected arithmetic operands to have the same bit width "
      "as Result Type: IMul operand index 2";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
467
// OpIMul with a vector Result Type and a scalar first operand must be
// rejected, naming operand index 2.
TEST_F(ValidateArithmetics, IMulVectorScalarOperand1) {
  const std::string instructions = R"(
%val = OpIMul %u32vec2 %u32_0 %u32vec2_01
)";
  const std::string expected_error =
      "Expected arithmetic operands to have the same dimension "
      "as Result Type: IMul operand index 2";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
480
// OpIMul with a vector Result Type and a scalar second operand must be
// rejected, naming operand index 3.
TEST_F(ValidateArithmetics, IMulVectorScalarOperand2) {
  const std::string instructions = R"(
%val = OpIMul %u32vec2 %u32vec2_01 %u32_0
)";
  const std::string expected_error =
      "Expected arithmetic operands to have the same dimension "
      "as Result Type: IMul operand index 3";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
493
// OpIMul with a scalar Result Type and a vector first operand must be
// rejected, naming operand index 2.
TEST_F(ValidateArithmetics, IMulScalarVectorOperand1) {
  const std::string instructions = R"(
%val = OpIMul %s32 %u32vec2_01 %u32_0
)";
  const std::string expected_error =
      "Expected arithmetic operands to have the same dimension "
      "as Result Type: IMul operand index 2";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
506
// OpIMul with a scalar Result Type and a vector second operand must be
// rejected, naming operand index 3.
TEST_F(ValidateArithmetics, IMulScalarVectorOperand2) {
  const std::string instructions = R"(
%val = OpIMul %u32 %u32_0 %s32vec2_01
)";
  const std::string expected_error =
      "Expected arithmetic operands to have the same dimension "
      "as Result Type: IMul operand index 3";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
519
// OpSNegate applied to a float operand must be rejected, naming operand
// index 2.
TEST_F(ValidateArithmetics, SNegateFloat) {
  const std::string instructions = R"(
%val = OpSNegate %s32 %f32_1
)";
  const std::string expected_error =
      "Expected int scalar or vector type as operand: "
      "SNegate operand index 2";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
531
// OpUDiv with a float Result Type must be rejected.
TEST_F(ValidateArithmetics, UDivFloatType) {
  const std::string instructions = R"(
%val = OpUDiv %f32 %u32_2 %u32_1
)";
  const std::string expected_error =
      "Expected unsigned int scalar or vector type as Result Type: UDiv";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
544
// OpUDiv with a signed int Result Type must be rejected.
TEST_F(ValidateArithmetics, UDivSignedIntType) {
  const std::string instructions = R"(
%val = OpUDiv %s32 %u32_2 %u32_1
)";
  const std::string expected_error =
      "Expected unsigned int scalar or vector type as Result Type: UDiv";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
557
// OpUDiv whose first operand is a float instead of the unsigned Result Type
// must be rejected, naming operand index 2.
TEST_F(ValidateArithmetics, UDivWrongOperand1) {
  const std::string instructions = R"(
%val = OpUDiv %u64 %f64_2 %u64_1
)";
  const std::string expected_error =
      "Expected arithmetic operands to be of Result Type: "
      "UDiv operand index 2";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
569
// OpUDiv with a 64-bit Result Type and a 32-bit second operand must be
// rejected, naming operand index 3.
TEST_F(ValidateArithmetics, UDivWrongOperand2) {
  const std::string instructions = R"(
%val = OpUDiv %u64 %u64_2 %u32_1
)";
  const std::string expected_error =
      "Expected arithmetic operands to be of Result Type: "
      "UDiv operand index 3";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
581
// OpDot of two 2-component float vectors producing a float scalar must
// validate successfully.
TEST_F(ValidateArithmetics, DotSuccess) {
  const std::string instructions = R"(
%val = OpDot %f32 %f32vec2_01 %f32vec2_12
)";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
590
// OpDot with an unsigned int Result Type must be rejected.
TEST_F(ValidateArithmetics, DotWrongTypeId) {
  const std::string instructions = R"(
%val = OpDot %u32 %u32vec2_01 %u32vec2_12
)";
  const std::string expected_error =
      "Expected float scalar type as Result Type: Dot";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
601
// OpDot given a type id (%f32) as its first operand must be rejected with
// SPV_ERROR_INVALID_ID rather than SPV_ERROR_INVALID_DATA.
TEST_F(ValidateArithmetics, DotNotVectorTypeOperand1) {
  const std::string instructions = R"(
%val = OpDot %f32 %f32 %f32vec2_12
)";
  const std::string expected_error =
      "Operand '6[%float]' cannot be a "
      "type";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
613
// OpDot with a float scalar as second operand must be rejected, naming
// operand index 3.
TEST_F(ValidateArithmetics, DotNotVectorTypeOperand2) {
  const std::string instructions = R"(
%val = OpDot %f32 %f32vec3_012 %f32_1
)";
  const std::string expected_error =
      "Expected float vector as operand: Dot operand index 3";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
625
// OpDot with a 64-bit float Result Type and a 32-bit float vector as first
// operand must be rejected, naming operand index 2.
TEST_F(ValidateArithmetics, DotWrongComponentOperand1) {
  const std::string instructions = R"(
%val = OpDot %f64 %f32vec2_01 %f64vec2_12
)";
  const std::string expected_error =
      "Expected component type to be equal to Result Type: "
      "Dot operand index 2";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
637
// OpDot with a 32-bit float Result Type and a 64-bit float vector as second
// operand must be rejected, naming operand index 3.
TEST_F(ValidateArithmetics, DotWrongComponentOperand2) {
  const std::string instructions = R"(
%val = OpDot %f32 %f32vec2_01 %f64vec2_12
)";
  const std::string expected_error =
      "Expected component type to be equal to Result Type: "
      "Dot operand index 3";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
649
// OpDot of a 2-component and a 3-component float vector must be rejected.
TEST_F(ValidateArithmetics, DotDifferentVectorSize) {
  const std::string instructions = R"(
%val = OpDot %f32 %f32vec2_01 %f32vec3_123
)";
  const std::string expected_error =
      "Expected operands to have the same number of components: Dot";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
662
// OpVectorTimesScalar with a float vector and a float scalar of matching
// component type must validate successfully.
TEST_F(ValidateArithmetics, VectorTimesScalarSuccess) {
  const std::string instructions = R"(
%val = OpVectorTimesScalar %f32vec2 %f32vec2_01 %f32_2
)";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
671
// OpVectorTimesScalar with an unsigned int vector Result Type must be
// rejected.
TEST_F(ValidateArithmetics, VectorTimesScalarWrongTypeId) {
  const std::string instructions = R"(
%val = OpVectorTimesScalar %u32vec2 %f32vec2_01 %f32_2
)";
  const std::string expected_error =
      "Expected float vector type as Result Type: "
      "VectorTimesScalar";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
683
// OpVectorTimesScalar whose vector operand (3 components) differs from the
// Result Type (2 components) must be rejected.
TEST_F(ValidateArithmetics, VectorTimesScalarWrongVector) {
  const std::string instructions = R"(
%val = OpVectorTimesScalar %f32vec2 %f32vec3_012 %f32_2
)";
  const std::string expected_error =
      "Expected vector operand type to be equal to Result Type: "
      "VectorTimesScalar";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
696
// OpVectorTimesScalar with a 32-bit float vector and a 64-bit float scalar
// must be rejected.
TEST_F(ValidateArithmetics, VectorTimesScalarWrongScalar) {
  const std::string instructions = R"(
%val = OpVectorTimesScalar %f32vec2 %f32vec2_01 %f64_2
)";
  const std::string expected_error =
      "Expected scalar operand type to be equal to the component "
      "type of the vector operand: VectorTimesScalar";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
709
// OpMatrixTimesScalar with a float matrix and a float scalar of matching
// component type must validate successfully.
TEST_F(ValidateArithmetics, MatrixTimesScalarSuccess) {
  const std::string instructions = R"(
%val = OpMatrixTimesScalar %f32mat22 %f32mat22_1212 %f32_2
)";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
718
// OpMatrixTimesScalar with a vector Result Type must be rejected.
TEST_F(ValidateArithmetics, MatrixTimesScalarWrongTypeId) {
  const std::string instructions = R"(
%val = OpMatrixTimesScalar %f32vec2 %f32mat22_1212 %f32_2
)";
  const std::string expected_error =
      "Expected float matrix type as Result Type: "
      "MatrixTimesScalar";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
730
// OpMatrixTimesScalar given a vector where the matrix operand is expected
// must be rejected.
TEST_F(ValidateArithmetics, MatrixTimesScalarWrongMatrix) {
  const std::string instructions = R"(
%val = OpMatrixTimesScalar %f32mat22 %f32vec2_01 %f32_2
)";
  const std::string expected_error =
      "Expected matrix operand type to be equal to Result Type: "
      "MatrixTimesScalar";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
743
// OpMatrixTimesScalar with a 32-bit float matrix and a 64-bit float scalar
// must be rejected.
TEST_F(ValidateArithmetics, MatrixTimesScalarWrongScalar) {
  const std::string instructions = R"(
%val = OpMatrixTimesScalar %f32mat22 %f32mat22_1212 %f64_2
)";
  const std::string expected_error =
      "Expected scalar operand type to be equal to the component "
      "type of the matrix operand: MatrixTimesScalar";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
756
// OpVectorTimesMatrix of a 2-component vector and a matrix with two
// 2-component columns, producing a 2-component vector, must validate
// successfully.
TEST_F(ValidateArithmetics, VectorTimesMatrix2x22Success) {
  const std::string instructions = R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat22_1212
)";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
765
// OpVectorTimesMatrix of a 3-component vector and a matrix with two
// 3-component columns, producing a 2-component vector, must validate
// successfully.
TEST_F(ValidateArithmetics, VectorTimesMatrix3x32Success) {
  const std::string instructions = R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec3_123 %f32mat32_123123
)";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
774
// OpVectorTimesMatrix with a matrix Result Type must be rejected.
TEST_F(ValidateArithmetics, VectorTimesMatrixWrongTypeId) {
  const std::string instructions = R"(
%val = OpVectorTimesMatrix %f32mat22 %f32vec2_12 %f32mat22_1212
)";
  const std::string expected_error =
      "Expected float vector type as Result Type: "
      "VectorTimesMatrix";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
786
// OpVectorTimesMatrix with an unsigned int vector as left operand must be
// rejected.
TEST_F(ValidateArithmetics, VectorTimesMatrixNotFloatVector) {
  const std::string instructions = R"(
%val = OpVectorTimesMatrix %f32vec2 %u32vec2_12 %f32mat22_1212
)";
  const std::string expected_error =
      "Expected float vector type as left operand: "
      "VectorTimesMatrix";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
798
// OpVectorTimesMatrix with a 32-bit float Result Type and a 64-bit float
// vector operand must be rejected.
TEST_F(ValidateArithmetics, VectorTimesMatrixWrongVectorComponent) {
  const std::string instructions = R"(
%val = OpVectorTimesMatrix %f32vec2 %f64vec2_12 %f32mat22_1212
)";
  const std::string expected_error =
      "Expected component types of Result Type and vector to be equal: "
      "VectorTimesMatrix";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
812
// OpVectorTimesMatrix given a vector where the matrix (right) operand is
// expected must be rejected.
TEST_F(ValidateArithmetics, VectorTimesMatrixWrongMatrix) {
  const std::string instructions = R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32vec2_12
)";
  const std::string expected_error =
      "Expected float matrix type as right operand: "
      "VectorTimesMatrix";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
824
// OpVectorTimesMatrix with a 32-bit float Result Type and a 64-bit float
// matrix operand must be rejected.
TEST_F(ValidateArithmetics, VectorTimesMatrixWrongMatrixComponent) {
  const std::string instructions = R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f64mat22_1212
)";
  const std::string expected_error =
      "Expected component types of Result Type and matrix to be equal: "
      "VectorTimesMatrix";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
838
// OpVectorTimesMatrix with a 2-component Result Type and a 3-column matrix
// must be rejected for the column-count mismatch.
TEST_F(ValidateArithmetics, VectorTimesMatrix2eq2x23Fail) {
  const std::string instructions = R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat23_121212
)";
  const std::string expected_error =
      "Expected number of columns of the matrix to be equal to Result Type "
      "vector size: VectorTimesMatrix";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
852
// OpVectorTimesMatrix with a 2-component vector operand and a matrix of
// 3-component columns must be rejected for the row-count mismatch.
TEST_F(ValidateArithmetics, VectorTimesMatrix2x32Fail) {
  const std::string instructions = R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat32_123123
)";
  const std::string expected_error =
      "Expected number of rows of the matrix to be equal to the vector "
      "operand size: VectorTimesMatrix";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_error));
}
866
// OpMatrixTimesVector of a matrix with two 2-component columns and a
// 2-component vector, producing a 2-component vector, must validate
// successfully.
TEST_F(ValidateArithmetics, MatrixTimesVector22x2Success) {
  const std::string instructions = R"(
%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f32vec2_12
)";

  const std::string assembly = GenerateCode(instructions);
  CompileSuccessfully(assembly.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
875
// OpMatrixTimesVector with %f32mat23_121212 and %f32vec3_123 yielding
// %f32vec2 is expected to validate cleanly.
TEST_F(ValidateArithmetics, MatrixTimesVector23x3Success) {
  const std::string spirv = GenerateCode(R"(
%val = OpMatrixTimesVector %f32vec2 %f32mat23_121212 %f32vec3_123
)");

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
884
// Using %f32mat22 as the Result Type of OpMatrixTimesVector must be
// rejected: the validator asks for a float vector Result Type.
TEST_F(ValidateArithmetics, MatrixTimesVectorWrongTypeId) {
  const std::string spirv = GenerateCode(R"(
%val = OpMatrixTimesVector %f32mat22 %f32mat22_1212 %f32vec2_12
)");
  const std::string expected_diagnostic =
      "Expected float vector type as Result Type: "
      "MatrixTimesVector";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
896
// Passing %f32vec3_123 in the matrix (left) operand slot of
// OpMatrixTimesVector must be rejected with a "float matrix" diagnostic.
TEST_F(ValidateArithmetics, MatrixTimesVectorWrongMatrix) {
  const std::string spirv = GenerateCode(R"(
%val = OpMatrixTimesVector %f32vec3 %f32vec3_123 %f32vec3_123
)");
  const std::string expected_diagnostic =
      "Expected float matrix type as left operand: "
      "MatrixTimesVector";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
908
// OpMatrixTimesVector with Result Type %f32vec3 and matrix %f32mat23_121212
// must fail: the matrix column type has to equal the Result Type.
TEST_F(ValidateArithmetics, MatrixTimesVectorWrongMatrixCol) {
  const std::string spirv = GenerateCode(R"(
%val = OpMatrixTimesVector %f32vec3 %f32mat23_121212 %f32vec3_123
)");
  const std::string expected_diagnostic =
      "Expected column type of the matrix to be equal to Result Type: "
      "MatrixTimesVector";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
922
// Passing %u32vec2_12 as the vector (right) operand of OpMatrixTimesVector
// must be rejected with a "float vector" diagnostic.
TEST_F(ValidateArithmetics, MatrixTimesVectorWrongVector) {
  const std::string spirv = GenerateCode(R"(
%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %u32vec2_12
)");
  const std::string expected_diagnostic =
      "Expected float vector type as right operand: "
      "MatrixTimesVector";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
934
// Mixing %f32mat22_1212 with %f64vec2_12 in OpMatrixTimesVector must be
// rejected: the operands' component types are required to be equal.
TEST_F(ValidateArithmetics, MatrixTimesVectorDifferentComponents) {
  const std::string spirv = GenerateCode(R"(
%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f64vec2_12
)");
  const std::string expected_diagnostic =
      "Expected component types of the operands to be equal: "
      "MatrixTimesVector";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
946
// OpMatrixTimesVector with %f32mat22_1212 and %f32vec3_123 must fail: the
// diagnostic reports that the matrix column count differs from the vector
// size.
TEST_F(ValidateArithmetics, MatrixTimesVector22x3Fail) {
  const std::string spirv = GenerateCode(R"(
%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f32vec3_123
)");
  const std::string expected_diagnostic =
      "Expected number of columns of the matrix to be equal to the vector "
      "size: MatrixTimesVector";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
960
// OpMatrixTimesMatrix of %f32mat22_1212 with itself, yielding %f32mat22, is
// expected to validate cleanly.
TEST_F(ValidateArithmetics, MatrixTimesMatrix22x22Success) {
  const std::string spirv = GenerateCode(R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f32mat22_1212
)");

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
969
// OpMatrixTimesMatrix of %f32mat23_121212 and %f32mat32_123123, yielding
// %f32mat22, is expected to validate cleanly.
TEST_F(ValidateArithmetics, MatrixTimesMatrix23x32Success) {
  const std::string spirv = GenerateCode(R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat32_123123
)");

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
978
// OpMatrixTimesMatrix of %f32mat33_123123123 with itself, yielding
// %f32mat33, is expected to validate cleanly.
TEST_F(ValidateArithmetics, MatrixTimesMatrix33x33Success) {
  const std::string spirv = GenerateCode(R"(
%val = OpMatrixTimesMatrix %f32mat33 %f32mat33_123123123 %f32mat33_123123123
)");

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
987
// Using %f32vec2 as the Result Type of OpMatrixTimesMatrix must be rejected:
// the validator asks for a float matrix Result Type.
TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongTypeId) {
  const std::string spirv = GenerateCode(R"(
%val = OpMatrixTimesMatrix %f32vec2 %f32mat22_1212 %f32mat22_1212
)");
  const std::string expected_diagnostic =
      "Expected float matrix type as Result Type: MatrixTimesMatrix";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1000
// Passing %f32vec2_12 as the left operand of OpMatrixTimesMatrix must be
// rejected with a "float matrix" diagnostic for the left operand.
TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongLeftOperand) {
  const std::string spirv = GenerateCode(R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32vec2_12 %f32mat22_1212
)");
  const std::string expected_diagnostic =
      "Expected float matrix type as left operand: MatrixTimesMatrix";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1013
// Passing %f32vec2_12 as the right operand of OpMatrixTimesMatrix must be
// rejected with a "float matrix" diagnostic for the right operand.
TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongRightOperand) {
  const std::string spirv = GenerateCode(R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f32vec2_12
)");
  const std::string expected_diagnostic =
      "Expected float matrix type as right operand: MatrixTimesMatrix";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1026
// OpMatrixTimesMatrix with Result Type %f32mat22 and left operand
// %f32mat32_123123 must fail: the column types of the Result Type and the
// left matrix have to be equal.
TEST_F(ValidateArithmetics, MatrixTimesMatrix32x23Fail) {
  const std::string spirv = GenerateCode(R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat32_123123 %f32mat23_121212
)");
  const std::string expected_diagnostic =
      "Expected column types of Result Type and left matrix to be equal: "
      "MatrixTimesMatrix";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1040
// OpMatrixTimesMatrix with Result Type %f32mat22 and right operand
// %f64mat22_1212 must fail: the component types of the Result Type and the
// right matrix have to be equal.
TEST_F(ValidateArithmetics, MatrixTimesMatrixDifferentComponents) {
  const std::string spirv = GenerateCode(R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f64mat22_1212
)");
  const std::string expected_diagnostic =
      "Expected component types of Result Type and right "
      "matrix to be equal: "
      "MatrixTimesMatrix";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1053
// OpMatrixTimesMatrix with Result Type %f32mat22 and right operand
// %f32mat23_121212 must fail: the column counts of the Result Type and the
// right matrix have to be equal.
TEST_F(ValidateArithmetics, MatrixTimesMatrix23x23Fail) {
  const std::string spirv = GenerateCode(R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat23_121212
)");
  const std::string expected_diagnostic =
      "Expected number of columns of Result Type and right "
      "matrix to be equal: "
      "MatrixTimesMatrix";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1066
// OpMatrixTimesMatrix of %f32mat23_121212 and %f32mat22_1212 must fail: the
// left matrix's column count has to equal the right matrix's row count.
TEST_F(ValidateArithmetics, MatrixTimesMatrix23x22Fail) {
  const std::string spirv = GenerateCode(R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat22_1212
)");
  const std::string expected_diagnostic =
      "Expected number of columns of left matrix and number "
      "of rows of right "
      "matrix to be equal: MatrixTimesMatrix";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1079
// OpOuterProduct of %f32vec2_12 and %f32vec2_01 yielding %f32mat22 is
// expected to validate cleanly.
TEST_F(ValidateArithmetics, OuterProduct2x2Success) {
  const std::string spirv = GenerateCode(R"(
%val = OpOuterProduct %f32mat22 %f32vec2_12 %f32vec2_01
)");

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
1088
// OpOuterProduct of %f32vec3_123 and %f32vec2_01 yielding %f32mat32 is
// expected to validate cleanly.
TEST_F(ValidateArithmetics, OuterProduct3x2Success) {
  const std::string spirv = GenerateCode(R"(
%val = OpOuterProduct %f32mat32 %f32vec3_123 %f32vec2_01
)");

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
1097
// OpOuterProduct of %f32vec2_01 and %f32vec3_123 yielding %f32mat23 is
// expected to validate cleanly.
TEST_F(ValidateArithmetics, OuterProduct2x3Success) {
  const std::string spirv = GenerateCode(R"(
%val = OpOuterProduct %f32mat23 %f32vec2_01 %f32vec3_123
)");

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
1106
// Using %f32vec2 as the Result Type of OpOuterProduct must be rejected: the
// validator asks for a float matrix Result Type.
TEST_F(ValidateArithmetics, OuterProductWrongTypeId) {
  const std::string spirv = GenerateCode(R"(
%val = OpOuterProduct %f32vec2 %f32vec2_01 %f32vec3_123
)");
  const std::string expected_diagnostic =
      "Expected float matrix type as Result Type: "
      "OuterProduct";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1118
// OpOuterProduct with Result Type %f32mat22 and left operand %f32vec3_123
// must fail: the Result Type's column type has to equal the left operand's
// type.
TEST_F(ValidateArithmetics, OuterProductWrongLeftOperand) {
  const std::string spirv = GenerateCode(R"(
%val = OpOuterProduct %f32mat22 %f32vec3_123 %f32vec2_01
)");
  const std::string expected_diagnostic =
      "Expected column type of Result Type to be equal to the type "
      "of the left operand: OuterProduct";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1131
// Passing %u32vec2_01 as the right operand of OpOuterProduct must be
// rejected with a "float vector" diagnostic for the right operand.
TEST_F(ValidateArithmetics, OuterProductRightOperandNotFloatVector) {
  const std::string spirv = GenerateCode(R"(
%val = OpOuterProduct %f32mat22 %f32vec2_12 %u32vec2_01
)");
  const std::string expected_diagnostic =
      "Expected float vector type as right operand: OuterProduct";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1143
// Mixing %f32vec2_12 with %f64vec2_01 in OpOuterProduct must be rejected:
// the operands' component types are required to be equal.
TEST_F(ValidateArithmetics, OuterProductRightOperandWrongComponent) {
  const std::string spirv = GenerateCode(R"(
%val = OpOuterProduct %f32mat22 %f32vec2_12 %f64vec2_01
)");
  const std::string expected_diagnostic =
      "Expected component types of the operands to be equal: "
      "OuterProduct";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1155
// OpOuterProduct with Result Type %f32mat22 and right operand %f32vec3_123
// must fail: the matrix column count has to equal the right operand's vector
// size.
TEST_F(ValidateArithmetics, OuterProductRightOperandWrongDimension) {
  const std::string spirv = GenerateCode(R"(
%val = OpOuterProduct %f32mat22 %f32vec2_12 %f32vec3_123
)");
  const std::string expected_diagnostic =
      "Expected number of columns of the matrix to be equal to the "
      "vector size of the right operand: OuterProduct";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1168
// Builds the SPIR-V assembly text used by the cooperative matrix tests.
//
// The result is the concatenation, in this order, of:
//   1. a fixed preamble (capabilities, the SPV_NV_cooperative_matrix
//      extension, scalar types, constants and cooperative matrix types),
//   2. |extra_types|, appended directly after the preamble,
//   3. the opening of the %main function (OpFunction + OpLabel),
//   4. |main_body|, appended directly after the entry label,
//   5. OpReturn / OpFunctionEnd.
std::string GenerateCoopMatCode(const std::string& extra_types,
                                const std::string& main_body) {
  const char* const kPreamble =
      R"(
OpCapability Shader
OpCapability Float16
OpCapability CooperativeMatrixNV
OpExtension "SPV_NV_cooperative_matrix"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f16 = OpTypeFloat 16
%f32 = OpTypeFloat 32
%u32 = OpTypeInt 32 0
%s32 = OpTypeInt 32 1

%u32_8 = OpConstant %u32 8
%u32_16 = OpConstant %u32 16
%u32_4 = OpConstant %u32 4
%subgroup = OpConstant %u32 3

%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
%u32mat = OpTypeCooperativeMatrixNV %u32 %subgroup %u32_8 %u32_8
%s32mat = OpTypeCooperativeMatrixNV %s32 %subgroup %u32_8 %u32_8

%f16_1 = OpConstant %f16 1
%f32_1 = OpConstant %f32 1
%u32_1 = OpConstant %u32 1
%s32_1 = OpConstant %s32 1

%f16mat_1 = OpConstantComposite %f16mat %f16_1
%u32mat_1 = OpConstantComposite %u32mat %u32_1
%s32mat_1 = OpConstantComposite %s32mat %s32_1

%u32_c1 = OpSpecConstant %u32 1
%u32_c2 = OpSpecConstant %u32 2

%f16matc = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_c1 %u32_c2
%f16matc_1 = OpConstantComposite %f16matc %f16_1

%mat16x4 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_16 %u32_4
%mat4x16 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_4 %u32_16
%mat16x16 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_16 %u32_16
%f16mat_16x4_1 = OpConstantComposite %mat16x4 %f16_1
%f16mat_4x16_1 = OpConstantComposite %mat4x16 %f16_1
%f16mat_16x16_1 = OpConstantComposite %mat16x16 %f16_1)";

  const char* const kMainOpen =
      R"(
%main = OpFunction %void None %func
%main_entry = OpLabel)";

  const char* const kMainClose =
      R"(
OpReturn
OpFunctionEnd)";

  // Assemble the pieces incrementally; the order matters because the caller's
  // types must precede the function and the body must sit inside it.
  std::string module_text(kPreamble);
  module_text += extra_types;
  module_text += kMainOpen;
  module_text += main_body;
  module_text += kMainClose;
  return module_text;
}
1230
// A batch of arithmetic instructions on cooperative matrix operands
// (add/sub/div/negate, OpMatrixTimesScalar and OpCooperativeMatrixMulAddNV)
// that is expected to validate cleanly.
TEST_F(ValidateArithmetics, CoopMatSuccess) {
  const std::string spirv = GenerateCoopMatCode("", R"(
%val1 = OpFAdd %f16mat %f16mat_1 %f16mat_1
%val2 = OpFSub %f16mat %f16mat_1 %f16mat_1
%val3 = OpFDiv %f16mat %f16mat_1 %f16mat_1
%val4 = OpFNegate %f16mat %f16mat_1
%val5 = OpIAdd %u32mat %u32mat_1 %u32mat_1
%val6 = OpISub %u32mat %u32mat_1 %u32mat_1
%val7 = OpUDiv %u32mat %u32mat_1 %u32mat_1
%val8 = OpIAdd %s32mat %s32mat_1 %s32mat_1
%val9 = OpISub %s32mat %s32mat_1 %s32mat_1
%val10 = OpSDiv %s32mat %s32mat_1 %s32mat_1
%val11 = OpSNegate %s32mat %s32mat_1
%val12 = OpMatrixTimesScalar %f16mat %f16mat_1 %f16_1
%val13 = OpMatrixTimesScalar %u32mat %u32mat_1 %u32_1
%val14 = OpMatrixTimesScalar %s32mat %s32mat_1 %s32_1
%val15 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16mat_16x16_1
%val16 = OpCooperativeMatrixMulAddNV %f16matc %f16matc_1 %f16matc_1 %f16matc_1
)");

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
1254
// OpFMul with the cooperative matrix type %f16mat as Result Type must be
// rejected: the diagnostic asks for a floating scalar or vector Result Type.
TEST_F(ValidateArithmetics, CoopMatFMulFail) {
  const std::string spirv = GenerateCoopMatCode("", R"(
%val1 = OpFMul %f16mat %f16mat_1 %f16mat_1
)");
  const std::string expected_diagnostic =
      "Expected floating scalar or vector type as Result Type: FMul";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1267
// OpMatrixTimesScalar on the %f16 cooperative matrix %f16mat_1 with the %f32
// scalar %f32_1 must fail: the scalar type has to equal the matrix component
// type.
TEST_F(ValidateArithmetics, CoopMatMatrixTimesScalarMismatchFail) {
  const std::string spirv = GenerateCoopMatCode("", R"(
%val1 = OpMatrixTimesScalar %f16mat %f16mat_1 %f32_1
)");
  const std::string expected_diagnostic =
      "Expected scalar operand type to be equal to the component "
      "type of the matrix operand: MatrixTimesScalar";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1280
// Declares an extra 16x16 cooperative matrix type whose scope operand is
// %workgroup (constant 2) and feeds a constant of that type into
// OpCooperativeMatrixMulAddNV next to operands declared with %subgroup.
// Validation must fail with a scope-mismatch diagnostic.
TEST_F(ValidateArithmetics, CoopMatScopeFail) {
  const std::string extra_types = R"(
%workgroup = OpConstant %u32 2

%mat16x16_wg = OpTypeCooperativeMatrixNV %f16 %workgroup %u32_16 %u32_16
%f16matwg_16x16_1 = OpConstantComposite %mat16x16_wg %f16_1
)";
  const std::string instructions = R"(
%val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16matwg_16x16_1
)";
  const std::string spirv = GenerateCoopMatCode(extra_types, instructions);
  const std::string expected_diagnostic =
      "Cooperative matrix scopes must match: CooperativeMatrixMulAddNV";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1300
// OpCooperativeMatrixMulAddNV with the 4x16 and 16x4 operands given in that
// order (the reverse of the passing case) must fail with an 'M' dimension
// mismatch diagnostic.
TEST_F(ValidateArithmetics, CoopMatDimFail) {
  const std::string spirv = GenerateCoopMatCode("", R"(
%val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_4x16_1 %f16mat_16x4_1 %f16mat_16x16_1
)");
  const std::string expected_diagnostic =
      "Cooperative matrix 'M' mismatch: CooperativeMatrixMulAddNV";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1312
// An OpTypeCooperativeMatrixNV whose component type is %bool must be
// rejected with SPV_ERROR_INVALID_ID and a diagnostic naming the %bool id.
TEST_F(ValidateArithmetics, CoopMatComponentTypeNotScalarNumeric) {
  const std::string bad_type_decl = R"(
%bad = OpTypeCooperativeMatrixNV %bool %subgroup %u32_8 %u32_8
)";
  const std::string spirv = GenerateCoopMatCode(bad_type_decl, "");
  const std::string expected_diagnostic =
      "OpTypeCooperativeMatrixNV Component Type <id> "
      "'4[%bool]' is not a scalar numerical type.";

  CompileSuccessfully(spirv.c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1324
// An OpTypeCooperativeMatrixNV whose Scope operand is the float constant
// %f32_1 must be rejected with SPV_ERROR_INVALID_ID.
TEST_F(ValidateArithmetics, CoopMatScopeNotConstantInt) {
  const std::string bad_type_decl = R"(
%bad = OpTypeCooperativeMatrixNV %f16 %f32_1 %u32_8 %u32_8
)";
  const std::string spirv = GenerateCoopMatCode(bad_type_decl, "");
  const std::string expected_diagnostic =
      "OpTypeCooperativeMatrixNV Scope <id> '17[%float_1]' is not a "
      "constant instruction with scalar integer type.";

  CompileSuccessfully(spirv.c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1337
// An OpTypeCooperativeMatrixNV whose Rows operand is the float constant
// %f32_1 must be rejected with SPV_ERROR_INVALID_ID.
TEST_F(ValidateArithmetics, CoopMatRowsNotConstantInt) {
  const std::string bad_type_decl = R"(
%bad = OpTypeCooperativeMatrixNV %f16 %subgroup %f32_1 %u32_8
)";
  const std::string spirv = GenerateCoopMatCode(bad_type_decl, "");
  const std::string expected_diagnostic =
      "OpTypeCooperativeMatrixNV Rows <id> '17[%float_1]' is not a "
      "constant instruction with scalar integer type.";

  CompileSuccessfully(spirv.c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1350
// An OpTypeCooperativeMatrixNV whose Cols operand is the float constant
// %f32_1 must be rejected with SPV_ERROR_INVALID_ID.
TEST_F(ValidateArithmetics, CoopMatColumnsNotConstantInt) {
  const std::string bad_type_decl = R"(
%bad = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %f32_1
)";
  const std::string spirv = GenerateCoopMatCode(bad_type_decl, "");
  const std::string expected_diagnostic =
      "OpTypeCooperativeMatrixNV Cols <id> '17[%float_1]' is not a "
      "constant instruction with scalar integer type.";

  CompileSuccessfully(spirv.c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1363
// OpIAddCarry with %struct_u32_u32 and %struct_u32vec2_u32vec2 Result Types
// and matching %u32 / %u32vec2 operands is expected to validate cleanly.
TEST_F(ValidateArithmetics, IAddCarrySuccess) {
  const std::string spirv = GenerateCode(R"(
%val1 = OpIAddCarry %struct_u32_u32 %u32_0 %u32_1
%val2 = OpIAddCarry %struct_u32vec2_u32vec2 %u32vec2_01 %u32vec2_12
)");

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
1373
// OpIAddCarry with the scalar %u32 as Result Type must be rejected: the
// validator asks for a struct Result Type.
TEST_F(ValidateArithmetics, IAddCarryResultTypeNotStruct) {
  const std::string spirv = GenerateCode(R"(
%val = OpIAddCarry %u32 %u32_0 %u32_1
)");
  const std::string expected_diagnostic =
      "Expected a struct as Result Type: IAddCarry";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1384
// OpIAddCarry with %struct_u32_u32_u32 as Result Type must be rejected: the
// Result Type struct is required to have exactly two members.
TEST_F(ValidateArithmetics, IAddCarryResultTypeNotTwoMembers) {
  const std::string spirv = GenerateCode(R"(
%val = OpIAddCarry %struct_u32_u32_u32 %u32_0 %u32_1
)");
  const std::string expected_diagnostic =
      "Expected Result Type struct to have two members: IAddCarry";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1396
// OpIAddCarry with %struct_s32_s32 as Result Type must be rejected: the
// struct members are required to be unsigned integer scalars or vectors.
TEST_F(ValidateArithmetics, IAddCarryResultTypeMemberNotUnsignedInt) {
  const std::string spirv = GenerateCode(R"(
%val = OpIAddCarry %struct_s32_s32 %s32_0 %s32_1
)");
  const std::string expected_diagnostic =
      "Expected Result Type struct member types to be "
      "unsigned integer scalar "
      "or vector: IAddCarry";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1409
// OpIAddCarry with %struct_u32_u32 as Result Type and %s32_0 as the left
// operand must be rejected: both operands have to be of the struct's member
// type.
TEST_F(ValidateArithmetics, IAddCarryWrongLeftOperand) {
  const std::string spirv = GenerateCode(R"(
%val = OpIAddCarry %struct_u32_u32 %s32_0 %u32_1
)");
  const std::string expected_diagnostic =
      "Expected both operands to be of Result Type member "
      "type: IAddCarry";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1421
// OpIAddCarry with %struct_u32_u32 as Result Type and %s32_1 as the right
// operand must be rejected: both operands have to be of the struct's member
// type.
TEST_F(ValidateArithmetics, IAddCarryWrongRightOperand) {
  const std::string spirv = GenerateCode(R"(
%val = OpIAddCarry %struct_u32_u32 %u32_0 %s32_1
)");
  const std::string expected_diagnostic =
      "Expected both operands to be of Result Type member "
      "type: IAddCarry";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1433
// OpSMulExtended with two-member struct Result Types over %u32, %s32,
// %u32vec2 and %s32vec2, each with matching operands, is expected to
// validate cleanly.
TEST_F(ValidateArithmetics, OpSMulExtendedSuccess) {
  const std::string spirv = GenerateCode(R"(
%val1 = OpSMulExtended %struct_u32_u32 %u32_0 %u32_1
%val2 = OpSMulExtended %struct_s32_s32 %s32_0 %s32_1
%val3 = OpSMulExtended %struct_u32vec2_u32vec2 %u32vec2_01 %u32vec2_12
%val4 = OpSMulExtended %struct_s32vec2_s32vec2 %s32vec2_01 %s32vec2_12
)");

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
1445
// OpSMulExtended with %struct_f32_f32 as Result Type must be rejected: the
// struct members are required to be integer scalars or vectors.
TEST_F(ValidateArithmetics, SMulExtendedResultTypeMemberNotInt) {
  const std::string spirv = GenerateCode(R"(
%val = OpSMulExtended %struct_f32_f32 %f32_0 %f32_1
)");
  const std::string expected_diagnostic =
      "Expected Result Type struct member types to be integer scalar "
      "or vector: SMulExtended";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1458
// OpSMulExtended with %struct_s32_u32 as Result Type must be rejected: the
// two struct member types are required to be identical.
TEST_F(ValidateArithmetics, SMulExtendedResultTypeMembersNotIdentical) {
  const std::string spirv = GenerateCode(R"(
%val = OpSMulExtended %struct_s32_u32 %s32_0 %s32_1
)");
  const std::string expected_diagnostic =
      "Expected Result Type struct member types to be identical: "
      "SMulExtended";

  CompileSuccessfully(spirv.c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
1471
1472 } // namespace
1473 } // namespace val
1474 } // namespace spvtools
1475