• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
// Tests for the validation rules of arithmetic instructions.
16 
17 #include <string>
18 
19 #include "gmock/gmock.h"
20 #include "test/unit_spirv.h"
21 #include "test/val/val_fixtures.h"
22 
23 namespace spvtools {
24 namespace val {
25 namespace {
26 
27 using ::testing::HasSubstr;
28 using ::testing::Not;
29 
30 using ValidateArithmetics = spvtest::ValidateBase<bool>;
31 
// Builds a complete SPIR-V assembly module around |main_body|.
//
// The prologue declares the capabilities (Shader, Int64, Float64, Matrix),
// the memory model, a Fragment entry point %main, and every scalar, vector,
// matrix and struct type plus the constants the tests refer to by name. It
// then opens %main and its entry block %main_entry. The epilogue returns from
// %main and closes the function. |main_body| is spliced in verbatim between
// the two. The prologue ends right after "OpLabel" with no trailing newline,
// so |main_body| is expected to begin with one (every caller here does).
std::string GenerateCode(const std::string& main_body) {
  static const char kPrologue[] = R"(
OpCapability Shader
OpCapability Int64
OpCapability Float64
OpCapability Matrix
%ext_inst = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main"
OpExecutionMode %main OriginUpperLeft
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f32 = OpTypeFloat 32
%u32 = OpTypeInt 32 0
%s32 = OpTypeInt 32 1
%f64 = OpTypeFloat 64
%u64 = OpTypeInt 64 0
%s64 = OpTypeInt 64 1
%boolvec2 = OpTypeVector %bool 2
%s32vec2 = OpTypeVector %s32 2
%u32vec2 = OpTypeVector %u32 2
%u64vec2 = OpTypeVector %u64 2
%f32vec2 = OpTypeVector %f32 2
%f64vec2 = OpTypeVector %f64 2
%boolvec3 = OpTypeVector %bool 3
%u32vec3 = OpTypeVector %u32 3
%u64vec3 = OpTypeVector %u64 3
%s32vec3 = OpTypeVector %s32 3
%f32vec3 = OpTypeVector %f32 3
%f64vec3 = OpTypeVector %f64 3
%boolvec4 = OpTypeVector %bool 4
%u32vec4 = OpTypeVector %u32 4
%u64vec4 = OpTypeVector %u64 4
%s32vec4 = OpTypeVector %s32 4
%f32vec4 = OpTypeVector %f32 4
%f64vec4 = OpTypeVector %f64 4

%f32mat22 = OpTypeMatrix %f32vec2 2
%f32mat23 = OpTypeMatrix %f32vec2 3
%f32mat32 = OpTypeMatrix %f32vec3 2
%f32mat33 = OpTypeMatrix %f32vec3 3
%f64mat22 = OpTypeMatrix %f64vec2 2

%struct_f32_f32 = OpTypeStruct %f32 %f32
%struct_u32_u32 = OpTypeStruct %u32 %u32
%struct_u32_u32_u32 = OpTypeStruct %u32 %u32 %u32
%struct_s32_s32 = OpTypeStruct %s32 %s32
%struct_s32_u32 = OpTypeStruct %s32 %u32
%struct_u32vec2_u32vec2 = OpTypeStruct %u32vec2 %u32vec2
%struct_s32vec2_s32vec2 = OpTypeStruct %s32vec2 %s32vec2

%f32_0 = OpConstant %f32 0
%f32_1 = OpConstant %f32 1
%f32_2 = OpConstant %f32 2
%f32_3 = OpConstant %f32 3
%f32_4 = OpConstant %f32 4
%f32_pi = OpConstant %f32 3.14159

%s32_0 = OpConstant %s32 0
%s32_1 = OpConstant %s32 1
%s32_2 = OpConstant %s32 2
%s32_3 = OpConstant %s32 3
%s32_4 = OpConstant %s32 4
%s32_m1 = OpConstant %s32 -1

%u32_0 = OpConstant %u32 0
%u32_1 = OpConstant %u32 1
%u32_2 = OpConstant %u32 2
%u32_3 = OpConstant %u32 3
%u32_4 = OpConstant %u32 4

%f64_0 = OpConstant %f64 0
%f64_1 = OpConstant %f64 1
%f64_2 = OpConstant %f64 2
%f64_3 = OpConstant %f64 3
%f64_4 = OpConstant %f64 4

%s64_0 = OpConstant %s64 0
%s64_1 = OpConstant %s64 1
%s64_2 = OpConstant %s64 2
%s64_3 = OpConstant %s64 3
%s64_4 = OpConstant %s64 4
%s64_m1 = OpConstant %s64 -1

%u64_0 = OpConstant %u64 0
%u64_1 = OpConstant %u64 1
%u64_2 = OpConstant %u64 2
%u64_3 = OpConstant %u64 3
%u64_4 = OpConstant %u64 4

%u32vec2_01 = OpConstantComposite %u32vec2 %u32_0 %u32_1
%u32vec2_12 = OpConstantComposite %u32vec2 %u32_1 %u32_2
%u32vec3_012 = OpConstantComposite %u32vec3 %u32_0 %u32_1 %u32_2
%u32vec3_123 = OpConstantComposite %u32vec3 %u32_1 %u32_2 %u32_3
%u32vec4_0123 = OpConstantComposite %u32vec4 %u32_0 %u32_1 %u32_2 %u32_3
%u32vec4_1234 = OpConstantComposite %u32vec4 %u32_1 %u32_2 %u32_3 %u32_4

%s32vec2_01 = OpConstantComposite %s32vec2 %s32_0 %s32_1
%s32vec2_12 = OpConstantComposite %s32vec2 %s32_1 %s32_2
%s32vec3_012 = OpConstantComposite %s32vec3 %s32_0 %s32_1 %s32_2
%s32vec3_123 = OpConstantComposite %s32vec3 %s32_1 %s32_2 %s32_3
%s32vec4_0123 = OpConstantComposite %s32vec4 %s32_0 %s32_1 %s32_2 %s32_3
%s32vec4_1234 = OpConstantComposite %s32vec4 %s32_1 %s32_2 %s32_3 %s32_4

%f32vec2_01 = OpConstantComposite %f32vec2 %f32_0 %f32_1
%f32vec2_12 = OpConstantComposite %f32vec2 %f32_1 %f32_2
%f32vec3_012 = OpConstantComposite %f32vec3 %f32_0 %f32_1 %f32_2
%f32vec3_123 = OpConstantComposite %f32vec3 %f32_1 %f32_2 %f32_3
%f32vec4_0123 = OpConstantComposite %f32vec4 %f32_0 %f32_1 %f32_2 %f32_3
%f32vec4_1234 = OpConstantComposite %f32vec4 %f32_1 %f32_2 %f32_3 %f32_4

%f64vec2_01 = OpConstantComposite %f64vec2 %f64_0 %f64_1
%f64vec2_12 = OpConstantComposite %f64vec2 %f64_1 %f64_2
%f64vec3_012 = OpConstantComposite %f64vec3 %f64_0 %f64_1 %f64_2
%f64vec3_123 = OpConstantComposite %f64vec3 %f64_1 %f64_2 %f64_3
%f64vec4_0123 = OpConstantComposite %f64vec4 %f64_0 %f64_1 %f64_2 %f64_3
%f64vec4_1234 = OpConstantComposite %f64vec4 %f64_1 %f64_2 %f64_3 %f64_4

%f32mat22_1212 = OpConstantComposite %f32mat22 %f32vec2_12 %f32vec2_12
%f32mat23_121212 = OpConstantComposite %f32mat23 %f32vec2_12 %f32vec2_12 %f32vec2_12
%f32mat32_123123 = OpConstantComposite %f32mat32 %f32vec3_123 %f32vec3_123
%f32mat33_123123123 = OpConstantComposite %f32mat33 %f32vec3_123 %f32vec3_123 %f32vec3_123

%f64mat22_1212 = OpConstantComposite %f64mat22 %f64vec2_12 %f64vec2_12

%main = OpFunction %void None %func
%main_entry = OpLabel)";

  static const char kEpilogue[] = R"(
OpReturn
OpFunctionEnd)";

  std::string module(kPrologue);
  module.append(main_body);
  module.append(kEpilogue);
  return module;
}
169 
// Scalar 32-bit float arithmetic (FMul, FSub, FAdd, FNegate, FDiv, FRem,
// FMod) whose operands all have the Result Type passes validation.
TEST_F(ValidateArithmetics, F32Success) {
  const std::string body = R"(
%val1 = OpFMul %f32 %f32_0 %f32_1
%val2 = OpFSub %f32 %f32_2 %f32_0
%val3 = OpFAdd %f32 %val1 %val2
%val4 = OpFNegate %f32 %val3
%val5 = OpFDiv %f32 %val4 %val1
%val6 = OpFRem %f32 %val4 %f32_2
%val7 = OpFMod %f32 %val4 %f32_2
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  // Stream the validator's diagnostic so an unexpected failure is explained
  // instead of reporting only the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
184 
// Scalar 64-bit float arithmetic (FMul, FSub, FAdd, FNegate, FDiv, FRem,
// FMod) whose operands all have the Result Type passes validation.
TEST_F(ValidateArithmetics, F64Success) {
  const std::string body = R"(
%val1 = OpFMul %f64 %f64_0 %f64_1
%val2 = OpFSub %f64 %f64_2 %f64_0
%val3 = OpFAdd %f64 %val1 %val2
%val4 = OpFNegate %f64 %val3
%val5 = OpFDiv %f64 %val4 %val1
%val6 = OpFRem %f64 %val4 %f64_2
%val7 = OpFMod %f64 %val4 %f64_2
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  // Stream the validator's diagnostic so an unexpected failure is explained
  // instead of reporting only the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
199 
// 32-bit integer arithmetic passes validation. The IMul/IAdd/ISub cases
// mix signed and unsigned operands of the same width as the Result Type.
TEST_F(ValidateArithmetics, Int32Success) {
  const std::string body = R"(
%val1 = OpIMul %u32 %s32_0 %u32_1
%val2 = OpIMul %s32 %s32_2 %u32_1
%val3 = OpIAdd %u32 %val1 %val2
%val4 = OpIAdd %s32 %val1 %val2
%val5 = OpISub %u32 %val3 %val4
%val6 = OpISub %s32 %val4 %val3
%val7 = OpSDiv %s32 %val4 %val3
%val8 = OpSNegate %s32 %val7
%val9 = OpSRem %s32 %val4 %val3
%val10 = OpSMod %s32 %val4 %val3
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  // Stream the validator's diagnostic so an unexpected failure is explained
  // instead of reporting only the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
217 
// 64-bit integer arithmetic passes validation. The IMul/IAdd/ISub cases
// mix signed and unsigned operands of the same width as the Result Type.
TEST_F(ValidateArithmetics, Int64Success) {
  const std::string body = R"(
%val1 = OpIMul %u64 %s64_0 %u64_1
%val2 = OpIMul %s64 %s64_2 %u64_1
%val3 = OpIAdd %u64 %val1 %val2
%val4 = OpIAdd %s64 %val1 %val2
%val5 = OpISub %u64 %val3 %val4
%val6 = OpISub %s64 %val4 %val3
%val7 = OpSDiv %s64 %val4 %val3
%val8 = OpSNegate %s64 %val7
%val9 = OpSRem %s64 %val4 %val3
%val10 = OpSMod %s64 %val4 %val3
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  // Stream the validator's diagnostic so an unexpected failure is explained
  // instead of reporting only the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
235 
// Component-wise float arithmetic on 2-component 32-bit float vectors
// passes validation.
TEST_F(ValidateArithmetics, F32Vec2Success) {
  const std::string body = R"(
%val1 = OpFMul %f32vec2 %f32vec2_01 %f32vec2_12
%val2 = OpFSub %f32vec2 %f32vec2_12 %f32vec2_01
%val3 = OpFAdd %f32vec2 %val1 %val2
%val4 = OpFNegate %f32vec2 %val3
%val5 = OpFDiv %f32vec2 %val4 %val1
%val6 = OpFRem %f32vec2 %val4 %f32vec2_12
%val7 = OpFMod %f32vec2 %val4 %f32vec2_12
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  // Stream the validator's diagnostic so an unexpected failure is explained
  // instead of reporting only the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
250 
// Component-wise float arithmetic on 2-component 64-bit float vectors
// passes validation.
TEST_F(ValidateArithmetics, F64Vec2Success) {
  const std::string body = R"(
%val1 = OpFMul %f64vec2 %f64vec2_01 %f64vec2_12
%val2 = OpFSub %f64vec2 %f64vec2_12 %f64vec2_01
%val3 = OpFAdd %f64vec2 %val1 %val2
%val4 = OpFNegate %f64vec2 %val3
%val5 = OpFDiv %f64vec2 %val4 %val1
%val6 = OpFRem %f64vec2 %val4 %f64vec2_12
%val7 = OpFMod %f64vec2 %val4 %f64vec2_12
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  // Stream the validator's diagnostic so an unexpected failure is explained
  // instead of reporting only the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
265 
// Integer arithmetic on 2-component unsigned vectors passes validation,
// including the S-prefixed opcodes (SNegate, SDiv, SRem, SMod) with an
// unsigned Result Type.
TEST_F(ValidateArithmetics, U32Vec2Success) {
  const std::string body = R"(
%val1 = OpIMul %u32vec2 %u32vec2_01 %u32vec2_12
%val2 = OpISub %u32vec2 %u32vec2_12 %u32vec2_01
%val3 = OpIAdd %u32vec2 %val1 %val2
%val4 = OpSNegate %u32vec2 %val3
%val5 = OpSDiv %u32vec2 %val4 %val1
%val6 = OpSRem %u32vec2 %val4 %u32vec2_12
%val7 = OpSMod %u32vec2 %val4 %u32vec2_12
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  // Stream the validator's diagnostic so an unexpected failure is explained
  // instead of reporting only the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
280 
// OpFNegate with an unsigned integer scalar Result Type is rejected.
TEST_F(ValidateArithmetics, FNegateTypeIdU32) {
  const std::string spirv = GenerateCode(R"(
%val = OpFNegate %u32 %u32_0
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected floating scalar or vector type as "
                        "Result Type: FNegate"));
}
293 
// OpFNegate with an unsigned integer vector Result Type is rejected.
TEST_F(ValidateArithmetics, FNegateTypeIdVec2U32) {
  const std::string spirv = GenerateCode(R"(
%val = OpFNegate %u32vec2 %u32vec2_01
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected floating scalar or vector type as "
                        "Result Type: FNegate"));
}
306 
// OpFNegate whose operand (index 2) is u32 while the Result Type is f32 is
// rejected.
TEST_F(ValidateArithmetics, FNegateWrongOperand) {
  const std::string spirv = GenerateCode(R"(
%val = OpFNegate %f32 %u32_0
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected arithmetic operands to be of Result Type: "
                        "FNegate operand index 2"));
}
318 
// OpFMul with an unsigned integer scalar Result Type is rejected.
TEST_F(ValidateArithmetics, FMulTypeIdU32) {
  const std::string spirv = GenerateCode(R"(
%val = OpFMul %u32 %u32_0 %u32_1
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected floating scalar or vector type as "
                        "Result Type: FMul"));
}
331 
// OpFMul with an unsigned integer vector Result Type is rejected.
TEST_F(ValidateArithmetics, FMulTypeIdVec2U32) {
  const std::string spirv = GenerateCode(R"(
%val = OpFMul %u32vec2 %u32vec2_01 %u32vec2_12
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected floating scalar or vector type as "
                        "Result Type: FMul"));
}
344 
// OpFMul whose first operand (index 2) is u32 instead of the f32 Result
// Type is rejected.
TEST_F(ValidateArithmetics, FMulWrongOperand1) {
  const std::string spirv = GenerateCode(R"(
%val = OpFMul %f32 %u32_0 %f32_1
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected arithmetic operands to be of Result Type: "
                        "FMul operand index 2"));
}
356 
// OpFMul whose second operand (index 3) is u32 instead of the f32 Result
// Type is rejected.
TEST_F(ValidateArithmetics, FMulWrongOperand2) {
  const std::string spirv = GenerateCode(R"(
%val = OpFMul %f32 %f32_0 %u32_1
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected arithmetic operands to be of Result Type: "
                        "FMul operand index 3"));
}
368 
// OpFMul producing an f64vec3 from an f32vec3 first operand (index 2) is
// rejected.
TEST_F(ValidateArithmetics, FMulWrongVectorOperand1) {
  const std::string spirv = GenerateCode(R"(
%val = OpFMul %f64vec3 %f32vec3_123 %f64vec3_012
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected arithmetic operands to be of Result Type: "
                        "FMul operand index 2"));
}
380 
// OpFMul producing an f32vec3 from an f64vec3 second operand (index 3) is
// rejected.
TEST_F(ValidateArithmetics, FMulWrongVectorOperand2) {
  const std::string spirv = GenerateCode(R"(
%val = OpFMul %f32vec3 %f32vec3_123 %f64vec3_012
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected arithmetic operands to be of Result Type: "
                        "FMul operand index 3"));
}
392 
// OpIMul with a float Result Type is rejected.
TEST_F(ValidateArithmetics, IMulFloatTypeId) {
  const std::string spirv = GenerateCode(R"(
%val = OpIMul %f32 %u32_0 %s32_1
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected int scalar or vector type as "
                        "Result Type: IMul"));
}
404 
// OpIMul with a float first operand (index 2) is rejected.
TEST_F(ValidateArithmetics, IMulFloatOperand1) {
  const std::string spirv = GenerateCode(R"(
%val = OpIMul %u32 %f32_0 %s32_1
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected int scalar or vector type as operand: IMul "
                        "operand index 2"));
}
416 
// OpIMul with a float second operand (index 3) is rejected.
TEST_F(ValidateArithmetics, IMulFloatOperand2) {
  const std::string spirv = GenerateCode(R"(
%val = OpIMul %u32 %s32_0 %f32_1
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected int scalar or vector type as operand: IMul "
                        "operand index 3"));
}
428 
// OpIMul with a 64-bit Result Type and a 32-bit first operand (index 2) is
// rejected.
TEST_F(ValidateArithmetics, IMulWrongBitWidthOperand1) {
  const std::string spirv = GenerateCode(R"(
%val = OpIMul %u64 %u32_0 %s64_1
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected arithmetic operands to have the same bit "
                        "width as Result Type: IMul operand index 2"));
}
441 
// OpIMul with a 32-bit Result Type and a 64-bit second operand (index 3) is
// rejected.
TEST_F(ValidateArithmetics, IMulWrongBitWidthOperand2) {
  const std::string spirv = GenerateCode(R"(
%val = OpIMul %u32 %u32_0 %s64_1
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected arithmetic operands to have the same bit "
                        "width as Result Type: IMul operand index 3"));
}
454 
// OpIMul producing a u64vec3 from u32vec3 operands is rejected. The first
// operand (index 2) is the one reported.
TEST_F(ValidateArithmetics, IMulWrongBitWidthVector) {
  const std::string spirv = GenerateCode(R"(
%val = OpIMul %u64vec3 %u32vec3_012 %u32vec3_123
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected arithmetic operands to have the same bit "
                        "width as Result Type: IMul operand index 2"));
}
467 
// OpIMul with a vector Result Type and a scalar first operand (index 2) is
// rejected.
TEST_F(ValidateArithmetics, IMulVectorScalarOperand1) {
  const std::string spirv = GenerateCode(R"(
%val = OpIMul %u32vec2 %u32_0 %u32vec2_01
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected arithmetic operands to have the same "
                        "dimension as Result Type: IMul operand index 2"));
}
480 
// OpIMul with a vector Result Type and a scalar second operand (index 3) is
// rejected.
TEST_F(ValidateArithmetics, IMulVectorScalarOperand2) {
  const std::string spirv = GenerateCode(R"(
%val = OpIMul %u32vec2 %u32vec2_01 %u32_0
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected arithmetic operands to have the same "
                        "dimension as Result Type: IMul operand index 3"));
}
493 
// OpIMul with a scalar Result Type and a vector first operand (index 2) is
// rejected.
TEST_F(ValidateArithmetics, IMulScalarVectorOperand1) {
  const std::string spirv = GenerateCode(R"(
%val = OpIMul %s32 %u32vec2_01 %u32_0
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected arithmetic operands to have the same "
                        "dimension as Result Type: IMul operand index 2"));
}
506 
// OpIMul with a scalar Result Type and a vector second operand (index 3) is
// rejected.
TEST_F(ValidateArithmetics, IMulScalarVectorOperand2) {
  const std::string spirv = GenerateCode(R"(
%val = OpIMul %u32 %u32_0 %s32vec2_01
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected arithmetic operands to have the same "
                        "dimension as Result Type: IMul operand index 3"));
}
519 
// OpSNegate applied to a float operand (index 2) is rejected.
TEST_F(ValidateArithmetics, SNegateFloat) {
  const std::string spirv = GenerateCode(R"(
%val = OpSNegate %s32 %f32_1
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected int scalar or vector type as operand: "
                        "SNegate operand index 2"));
}
531 
// OpUDiv with a float Result Type is rejected.
TEST_F(ValidateArithmetics, UDivFloatType) {
  const std::string spirv = GenerateCode(R"(
%val = OpUDiv %f32 %u32_2 %u32_1
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected unsigned int scalar or vector type as "
                        "Result Type: UDiv"));
}
544 
// OpUDiv with a signed integer Result Type is rejected.
TEST_F(ValidateArithmetics, UDivSignedIntType) {
  const std::string spirv = GenerateCode(R"(
%val = OpUDiv %s32 %u32_2 %u32_1
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected unsigned int scalar or vector type as "
                        "Result Type: UDiv"));
}
557 
// OpUDiv whose first operand (index 2) is f64 instead of the u64 Result
// Type is rejected.
TEST_F(ValidateArithmetics, UDivWrongOperand1) {
  const std::string spirv = GenerateCode(R"(
%val = OpUDiv %u64 %f64_2 %u64_1
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected arithmetic operands to be of Result Type: "
                        "UDiv operand index 2"));
}
569 
// OpUDiv whose second operand (index 3) is u32 instead of the u64 Result
// Type is rejected.
TEST_F(ValidateArithmetics, UDivWrongOperand2) {
  const std::string spirv = GenerateCode(R"(
%val = OpUDiv %u64 %u64_2 %u32_1
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected arithmetic operands to be of Result Type: "
                        "UDiv operand index 3"));
}
581 
// OpDot of two f32vec2 operands producing an f32 scalar passes validation.
TEST_F(ValidateArithmetics, DotSuccess) {
  const std::string body = R"(
%val = OpDot %f32 %f32vec2_01 %f32vec2_12
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  // Stream the validator's diagnostic so an unexpected failure is explained
  // instead of reporting only the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
590 
// OpDot with an integer Result Type is rejected. Only a float scalar is
// accepted.
TEST_F(ValidateArithmetics, DotWrongTypeId) {
  const std::string spirv = GenerateCode(R"(
%val = OpDot %u32 %u32vec2_01 %u32vec2_12
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected float scalar type as "
                        "Result Type: Dot"));
}
601 
// Despite its name, this test passes the type %f32 itself as the first
// operand. It therefore checks the generic rule that a type id cannot be
// used as an operand, which is reported as SPV_ERROR_INVALID_ID.
// NOTE(review): the expected '6[%float]' text depends on the numeric id and
// friendly name the tools assign to %f32 in GenerateCode's module.
TEST_F(ValidateArithmetics, DotNotVectorTypeOperand1) {
  const std::string spirv = GenerateCode(R"(
%val = OpDot %f32 %f32 %f32vec2_12
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Operand '6[%float]' cannot be a type"));
}
613 
// OpDot whose second operand (index 3) is a float scalar rather than a
// float vector is rejected.
TEST_F(ValidateArithmetics, DotNotVectorTypeOperand2) {
  const std::string spirv = GenerateCode(R"(
%val = OpDot %f32 %f32vec3_012 %f32_1
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected float vector as operand: "
                        "Dot operand index 3"));
}
625 
// OpDot with an f64 Result Type and an f32vec2 first operand (index 2) is
// rejected because the component type differs from the Result Type.
TEST_F(ValidateArithmetics, DotWrongComponentOperand1) {
  const std::string spirv = GenerateCode(R"(
%val = OpDot %f64 %f32vec2_01 %f64vec2_12
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected component type to be equal to "
                        "Result Type: Dot operand index 2"));
}
637 
// OpDot with an f32 Result Type and an f64vec2 second operand (index 3) is
// rejected because the component type differs from the Result Type.
TEST_F(ValidateArithmetics, DotWrongComponentOperand2) {
  const std::string spirv = GenerateCode(R"(
%val = OpDot %f32 %f32vec2_01 %f64vec2_12
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected component type to be equal to "
                        "Result Type: Dot operand index 3"));
}
649 
// OpDot of a 2-component and a 3-component vector is rejected.
TEST_F(ValidateArithmetics, DotDifferentVectorSize) {
  const std::string spirv = GenerateCode(R"(
%val = OpDot %f32 %f32vec2_01 %f32vec3_123
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected operands to have the same number of "
                        "components: Dot"));
}
662 
// OpVectorTimesScalar scaling an f32vec2 by an f32 passes validation.
TEST_F(ValidateArithmetics, VectorTimesScalarSuccess) {
  const std::string body = R"(
%val = OpVectorTimesScalar %f32vec2 %f32vec2_01 %f32_2
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  // Stream the validator's diagnostic so an unexpected failure is explained
  // instead of reporting only the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
671 
// OpVectorTimesScalar with an integer vector Result Type is rejected.
TEST_F(ValidateArithmetics, VectorTimesScalarWrongTypeId) {
  const std::string spirv = GenerateCode(R"(
%val = OpVectorTimesScalar %u32vec2 %f32vec2_01 %f32_2
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected float vector type as Result Type: "
                        "VectorTimesScalar"));
}
683 
// OpVectorTimesScalar whose vector operand (f32vec3) differs from the
// Result Type (f32vec2) is rejected.
TEST_F(ValidateArithmetics, VectorTimesScalarWrongVector) {
  const std::string spirv = GenerateCode(R"(
%val = OpVectorTimesScalar %f32vec2 %f32vec3_012 %f32_2
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected vector operand type to be equal to "
                        "Result Type: VectorTimesScalar"));
}
696 
// OpVectorTimesScalar scaling an f32vec2 by an f64 scalar is rejected.
TEST_F(ValidateArithmetics, VectorTimesScalarWrongScalar) {
  const std::string spirv = GenerateCode(R"(
%val = OpVectorTimesScalar %f32vec2 %f32vec2_01 %f64_2
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected scalar operand type to be equal to the "
                        "component type of the vector operand: "
                        "VectorTimesScalar"));
}
709 
// OpMatrixTimesScalar scaling an f32 2x2 matrix by an f32 passes
// validation.
TEST_F(ValidateArithmetics, MatrixTimesScalarSuccess) {
  const std::string body = R"(
%val = OpMatrixTimesScalar %f32mat22 %f32mat22_1212 %f32_2
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  // Stream the validator's diagnostic so an unexpected failure is explained
  // instead of reporting only the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
718 
// OpMatrixTimesScalar with a vector Result Type is rejected.
TEST_F(ValidateArithmetics, MatrixTimesScalarWrongTypeId) {
  const std::string spirv = GenerateCode(R"(
%val = OpMatrixTimesScalar %f32vec2 %f32mat22_1212 %f32_2
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected float matrix type as Result Type: "
                        "MatrixTimesScalar"));
}
730 
// OpMatrixTimesScalar whose "matrix" operand is actually a vector is
// rejected.
TEST_F(ValidateArithmetics, MatrixTimesScalarWrongMatrix) {
  const std::string spirv = GenerateCode(R"(
%val = OpMatrixTimesScalar %f32mat22 %f32vec2_01 %f32_2
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected matrix operand type to be equal to "
                        "Result Type: MatrixTimesScalar"));
}
743 
// OpMatrixTimesScalar scaling an f32 matrix by an f64 scalar is rejected.
TEST_F(ValidateArithmetics, MatrixTimesScalarWrongScalar) {
  const std::string spirv = GenerateCode(R"(
%val = OpMatrixTimesScalar %f32mat22 %f32mat22_1212 %f64_2
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected scalar operand type to be equal to the "
                        "component type of the matrix operand: "
                        "MatrixTimesScalar"));
}
756 
// f32vec2 times a 2-column f32vec2 matrix (%f32mat22) yielding f32vec2
// passes validation.
TEST_F(ValidateArithmetics, VectorTimesMatrix2x22Success) {
  const std::string body = R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat22_1212
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  // Stream the validator's diagnostic so an unexpected failure is explained
  // instead of reporting only the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
765 
// f32vec3 times a matrix of 2 f32vec3 columns (%f32mat32) yielding f32vec2
// passes validation. The vector size matches the column length, and the
// result size matches the column count.
TEST_F(ValidateArithmetics, VectorTimesMatrix3x32Success) {
  const std::string body = R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec3_123 %f32mat32_123123
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  // Stream the validator's diagnostic so an unexpected failure is explained
  // instead of reporting only the error code.
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
}
774 
// OpVectorTimesMatrix with a matrix Result Type is rejected.
TEST_F(ValidateArithmetics, VectorTimesMatrixWrongTypeId) {
  const std::string spirv = GenerateCode(R"(
%val = OpVectorTimesMatrix %f32mat22 %f32vec2_12 %f32mat22_1212
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected float vector type as Result Type: "
                        "VectorTimesMatrix"));
}
786 
// OpVectorTimesMatrix whose left operand is an integer vector is rejected.
TEST_F(ValidateArithmetics, VectorTimesMatrixNotFloatVector) {
  const std::string spirv = GenerateCode(R"(
%val = OpVectorTimesMatrix %f32vec2 %u32vec2_12 %f32mat22_1212
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected float vector type as left operand: "
                        "VectorTimesMatrix"));
}
798 
// OpVectorTimesMatrix with an f64vec2 left operand and an f32vec2 Result
// Type is rejected.
TEST_F(ValidateArithmetics, VectorTimesMatrixWrongVectorComponent) {
  const std::string spirv = GenerateCode(R"(
%val = OpVectorTimesMatrix %f32vec2 %f64vec2_12 %f32mat22_1212
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected component types of Result Type and vector "
                        "to be equal: VectorTimesMatrix"));
}
812 
// OpVectorTimesMatrix whose right operand is a vector instead of a matrix
// is rejected.
TEST_F(ValidateArithmetics, VectorTimesMatrixWrongMatrix) {
  const std::string spirv = GenerateCode(R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32vec2_12
)");
  CompileSuccessfully(spirv.c_str());

  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected float matrix type as right operand: "
                        "VectorTimesMatrix"));
}
824 
// OpVectorTimesMatrix: an f64 matrix with an f32 Result Type must be rejected
// because the component types differ.
TEST_F(ValidateArithmetics, VectorTimesMatrixWrongMatrixComponent) {
  const std::string body = R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f64mat22_1212
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Expected component types of Result Type and matrix to be equal: "
          "VectorTimesMatrix"));
}

// OpVectorTimesMatrix: the matrix's column count must equal the Result Type
// vector size; %f32mat23_121212 with a %f32vec2 result fails that check.
TEST_F(ValidateArithmetics, VectorTimesMatrix2eq2x23Fail) {
  const std::string body = R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat23_121212
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Expected number of columns of the matrix to be equal to Result Type "
          "vector size: VectorTimesMatrix"));
}

// OpVectorTimesMatrix: the matrix's row count must equal the vector operand
// size; %f32mat32_123123 with %f32vec2_12 fails that check.
TEST_F(ValidateArithmetics, VectorTimesMatrix2x32Fail) {
  const std::string body = R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat32_123123
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Expected number of rows of the matrix to be equal to the vector "
          "operand size: VectorTimesMatrix"));
}

// Valid OpMatrixTimesVector: %f32mat22_1212 times %f32vec2_12 yields %f32vec2.
TEST_F(ValidateArithmetics, MatrixTimesVector22x2Success) {
  const std::string body = R"(
%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f32vec2_12
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}

// Valid OpMatrixTimesVector with a non-square matrix: %f32mat23_121212 times
// %f32vec3_123 yields %f32vec2.
TEST_F(ValidateArithmetics, MatrixTimesVector23x3Success) {
  const std::string body = R"(
%val = OpMatrixTimesVector %f32vec2 %f32mat23_121212 %f32vec3_123
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}

// OpMatrixTimesVector: a matrix Result Type (%f32mat22) must be rejected; the
// result has to be a float vector.
TEST_F(ValidateArithmetics, MatrixTimesVectorWrongTypeId) {
  const std::string body = R"(
%val = OpMatrixTimesVector %f32mat22 %f32mat22_1212 %f32vec2_12
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected float vector type as Result Type: "
                        "MatrixTimesVector"));
}

// OpMatrixTimesVector: a vector in the matrix (left) operand position must be
// rejected.
TEST_F(ValidateArithmetics, MatrixTimesVectorWrongMatrix) {
  const std::string body = R"(
%val = OpMatrixTimesVector %f32vec3 %f32vec3_123 %f32vec3_123
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected float matrix type as left operand: "
                        "MatrixTimesVector"));
}

// OpMatrixTimesVector: the matrix column type must equal the Result Type;
// %f32mat23_121212 with a %f32vec3 result fails that check.
TEST_F(ValidateArithmetics, MatrixTimesVectorWrongMatrixCol) {
  const std::string body = R"(
%val = OpMatrixTimesVector %f32vec3 %f32mat23_121212 %f32vec3_123
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Expected column type of the matrix to be equal to Result Type: "
          "MatrixTimesVector"));
}

// OpMatrixTimesVector: a %u32vec2 right operand must be rejected; the vector
// operand has to be a float vector.
TEST_F(ValidateArithmetics, MatrixTimesVectorWrongVector) {
  const std::string body = R"(
%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %u32vec2_12
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected float vector type as right operand: "
                        "MatrixTimesVector"));
}

// OpMatrixTimesVector: an f32 matrix times an f64 vector must be rejected
// because the operand component types differ.
TEST_F(ValidateArithmetics, MatrixTimesVectorDifferentComponents) {
  const std::string body = R"(
%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f64vec2_12
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected component types of the operands to be equal: "
                        "MatrixTimesVector"));
}

// OpMatrixTimesVector: the matrix's column count must equal the vector size;
// %f32mat22_1212 times %f32vec3_123 fails that check.
TEST_F(ValidateArithmetics, MatrixTimesVector22x3Fail) {
  const std::string body = R"(
%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f32vec3_123
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Expected number of columns of the matrix to be equal to the vector "
          "size: MatrixTimesVector"));
}

// Valid OpMatrixTimesMatrix on two %f32mat22 operands.
TEST_F(ValidateArithmetics, MatrixTimesMatrix22x22Success) {
  const std::string body = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f32mat22_1212
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}

// Valid OpMatrixTimesMatrix on non-square operands: %f32mat23_121212 times
// %f32mat32_123123 yields %f32mat22.
TEST_F(ValidateArithmetics, MatrixTimesMatrix23x32Success) {
  const std::string body = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat32_123123
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}

// Valid OpMatrixTimesMatrix on two %f32mat33 operands.
TEST_F(ValidateArithmetics, MatrixTimesMatrix33x33Success) {
  const std::string body = R"(
%val = OpMatrixTimesMatrix %f32mat33 %f32mat33_123123123 %f32mat33_123123123
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}

// OpMatrixTimesMatrix: a vector Result Type (%f32vec2) must be rejected.
TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongTypeId) {
  const std::string body = R"(
%val = OpMatrixTimesMatrix %f32vec2 %f32mat22_1212 %f32mat22_1212
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Expected float matrix type as Result Type: MatrixTimesMatrix"));
}

// OpMatrixTimesMatrix: a vector as the left operand must be rejected.
TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongLeftOperand) {
  const std::string body = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32vec2_12 %f32mat22_1212
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Expected float matrix type as left operand: MatrixTimesMatrix"));
}

// OpMatrixTimesMatrix: a vector as the right operand must be rejected.
TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongRightOperand) {
  const std::string body = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f32vec2_12
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Expected float matrix type as right operand: MatrixTimesMatrix"));
}

// OpMatrixTimesMatrix: the left matrix column type must equal the Result Type
// column type; %f32mat32_123123 with a %f32mat22 result fails that check.
TEST_F(ValidateArithmetics, MatrixTimesMatrix32x23Fail) {
  const std::string body = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat32_123123 %f32mat23_121212
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Expected column types of Result Type and left matrix to be equal: "
          "MatrixTimesMatrix"));
}

// OpMatrixTimesMatrix: an f64 right matrix with an f32 Result Type must be
// rejected because the component types differ.
TEST_F(ValidateArithmetics, MatrixTimesMatrixDifferentComponents) {
  const std::string body = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f64mat22_1212
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected component types of Result Type and right "
                        "matrix to be equal: "
                        "MatrixTimesMatrix"));
}

// OpMatrixTimesMatrix: the right matrix must have as many columns as the
// Result Type; %f32mat23_121212 with a %f32mat22 result fails that check.
TEST_F(ValidateArithmetics, MatrixTimesMatrix23x23Fail) {
  const std::string body = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat23_121212
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected number of columns of Result Type and right "
                        "matrix to be equal: "
                        "MatrixTimesMatrix"));
}

// OpMatrixTimesMatrix: the left matrix's column count must equal the right
// matrix's row count; %f32mat23_121212 times %f32mat22_1212 fails that check.
TEST_F(ValidateArithmetics, MatrixTimesMatrix23x22Fail) {
  const std::string body = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat22_1212
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected number of columns of left matrix and number "
                        "of rows of right "
                        "matrix to be equal: MatrixTimesMatrix"));
}

// Valid OpOuterProduct of two %f32vec2 operands into %f32mat22.
TEST_F(ValidateArithmetics, OuterProduct2x2Success) {
  const std::string body = R"(
%val = OpOuterProduct %f32mat22 %f32vec2_12 %f32vec2_01
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}

// Valid OpOuterProduct: %f32vec3 (left) with %f32vec2 (right) into %f32mat32.
TEST_F(ValidateArithmetics, OuterProduct3x2Success) {
  const std::string body = R"(
%val = OpOuterProduct %f32mat32 %f32vec3_123 %f32vec2_01
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}

// Valid OpOuterProduct: %f32vec2 (left) with %f32vec3 (right) into %f32mat23.
TEST_F(ValidateArithmetics, OuterProduct2x3Success) {
  const std::string body = R"(
%val = OpOuterProduct %f32mat23 %f32vec2_01 %f32vec3_123
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}

// OpOuterProduct: a vector Result Type (%f32vec2) must be rejected.
TEST_F(ValidateArithmetics, OuterProductWrongTypeId) {
  const std::string body = R"(
%val = OpOuterProduct %f32vec2 %f32vec2_01 %f32vec3_123
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected float matrix type as Result Type: "
                        "OuterProduct"));
}

// OpOuterProduct: the Result Type column type must equal the left operand's
// type; a %f32vec3 left operand with a %f32mat22 result fails that check.
TEST_F(ValidateArithmetics, OuterProductWrongLeftOperand) {
  const std::string body = R"(
%val = OpOuterProduct %f32mat22 %f32vec3_123 %f32vec2_01
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Expected column type of Result Type to be equal to the type "
                "of the left operand: OuterProduct"));
}

// OpOuterProduct: a %u32vec2 right operand must be rejected; it has to be a
// float vector.
TEST_F(ValidateArithmetics, OuterProductRightOperandNotFloatVector) {
  const std::string body = R"(
%val = OpOuterProduct %f32mat22 %f32vec2_12 %u32vec2_01
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Expected float vector type as right operand: OuterProduct"));
}

// OpOuterProduct: f32 and f64 vector operands must be rejected because their
// component types differ.
TEST_F(ValidateArithmetics, OuterProductRightOperandWrongComponent) {
  const std::string body = R"(
%val = OpOuterProduct %f32mat22 %f32vec2_12 %f64vec2_01
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected component types of the operands to be equal: "
                        "OuterProduct"));
}

// OpOuterProduct: the Result Type column count must equal the right operand's
// size; %f32vec3_123 with a %f32mat22 result fails that check.
TEST_F(ValidateArithmetics, OuterProductRightOperandWrongDimension) {
  const std::string body = R"(
%val = OpOuterProduct %f32mat22 %f32vec2_12 %f32vec3_123
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Expected number of columns of the matrix to be equal to the "
                "vector size of the right operand: OuterProduct"));
}

// Assembles a complete SPIR-V module for the cooperative-matrix tests.
//
// The fixed prefix:
//   - enables Shader, Float16 and CooperativeMatrixNV (via
//     SPV_NV_cooperative_matrix) and declares the scalar types;
//   - declares 8x8 cooperative matrix types over f16/u32/s32 (%f16mat,
//     %u32mat, %s32mat);
//   - declares a matrix type sized by spec constants (%f16matc);
//   - declares 16x4, 4x16 and 16x16 f16 matrix types;
//   - defines an OpConstantComposite of each type (e.g. %f16mat_1,
//     %f16mat_16x4_1).
// Every matrix in the prefix uses scope %subgroup (constant 3).
//
// `extra_types` is inserted after those declarations and before %main, so
// tests can add module-scope types and constants. `main_body` is inserted into
// %main's entry block, just before OpReturn. No separator is added around
// either fragment, so a non-empty fragment should begin with a newline.
//
// Diagnostics checked by tests in this file embed assembler-assigned ids such
// as '4[%bool]' and '17[%float_1]'. Those ids follow from the order of first
// appearance in the prefix. Do not reorder declarations, or insert new ones,
// ahead of %f32_1.
std::string GenerateCoopMatCode(const std::string& extra_types,
                                const std::string& main_body) {
  const std::string prefix =
      R"(
OpCapability Shader
OpCapability Float16
OpCapability CooperativeMatrixNV
OpExtension "SPV_NV_cooperative_matrix"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f16 = OpTypeFloat 16
%f32 = OpTypeFloat 32
%u32 = OpTypeInt 32 0
%s32 = OpTypeInt 32 1

%u32_8 = OpConstant %u32 8
%u32_16 = OpConstant %u32 16
%u32_4 = OpConstant %u32 4
%subgroup = OpConstant %u32 3

%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
%u32mat = OpTypeCooperativeMatrixNV %u32 %subgroup %u32_8 %u32_8
%s32mat = OpTypeCooperativeMatrixNV %s32 %subgroup %u32_8 %u32_8

%f16_1 = OpConstant %f16 1
%f32_1 = OpConstant %f32 1
%u32_1 = OpConstant %u32 1
%s32_1 = OpConstant %s32 1

%f16mat_1 = OpConstantComposite %f16mat %f16_1
%u32mat_1 = OpConstantComposite %u32mat %u32_1
%s32mat_1 = OpConstantComposite %s32mat %s32_1

%u32_c1 = OpSpecConstant %u32 1
%u32_c2 = OpSpecConstant %u32 2

%f16matc = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_c1 %u32_c2
%f16matc_1 = OpConstantComposite %f16matc %f16_1

%mat16x4 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_16 %u32_4
%mat4x16 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_4 %u32_16
%mat16x16 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_16 %u32_16
%f16mat_16x4_1 = OpConstantComposite %mat16x4 %f16_1
%f16mat_4x16_1 = OpConstantComposite %mat4x16 %f16_1
%f16mat_16x16_1 = OpConstantComposite %mat16x16 %f16_1)";

  // Opens %main and its single entry block; `main_body` is appended after it.
  const std::string func_begin =
      R"(
%main = OpFunction %void None %func
%main_entry = OpLabel)";

  // Closes the entry block and the function.
  const std::string suffix =
      R"(
OpReturn
OpFunctionEnd)";

  return prefix + extra_types + func_begin + main_body + suffix;
}

// All of the following are accepted on cooperative matrices:
//   - element-wise add/sub/div/negate on f16, u32 and s32 matrices;
//   - OpMatrixTimesScalar;
//   - OpCooperativeMatrixMulAddNV, including on the spec-constant-sized
//     %f16matc.
TEST_F(ValidateArithmetics, CoopMatSuccess) {
  const std::string body = R"(
%val1 = OpFAdd %f16mat %f16mat_1 %f16mat_1
%val2 = OpFSub %f16mat %f16mat_1 %f16mat_1
%val3 = OpFDiv %f16mat %f16mat_1 %f16mat_1
%val4 = OpFNegate %f16mat %f16mat_1
%val5 = OpIAdd %u32mat %u32mat_1 %u32mat_1
%val6 = OpISub %u32mat %u32mat_1 %u32mat_1
%val7 = OpUDiv %u32mat %u32mat_1 %u32mat_1
%val8 = OpIAdd %s32mat %s32mat_1 %s32mat_1
%val9 = OpISub %s32mat %s32mat_1 %s32mat_1
%val10 = OpSDiv %s32mat %s32mat_1 %s32mat_1
%val11 = OpSNegate %s32mat %s32mat_1
%val12 = OpMatrixTimesScalar %f16mat %f16mat_1 %f16_1
%val13 = OpMatrixTimesScalar %u32mat %u32mat_1 %u32_1
%val14 = OpMatrixTimesScalar %s32mat %s32mat_1 %s32_1
%val15 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16mat_16x16_1
%val16 = OpCooperativeMatrixMulAddNV %f16matc %f16matc_1 %f16matc_1 %f16matc_1
)";

  CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}

// OpFMul is not accepted on cooperative matrices; the validator reports its
// scalar/vector Result Type requirement.
TEST_F(ValidateArithmetics, CoopMatFMulFail) {
  const std::string body = R"(
%val1 = OpFMul %f16mat %f16mat_1 %f16mat_1
)";

  CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Expected floating scalar or vector type as Result Type: FMul"));
}

// OpMatrixTimesScalar on a cooperative matrix: the scalar (%f32_1) must match
// the matrix component type (%f16).
TEST_F(ValidateArithmetics, CoopMatMatrixTimesScalarMismatchFail) {
  const std::string body = R"(
%val1 = OpMatrixTimesScalar %f16mat %f16mat_1 %f32_1
)";

  CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Expected scalar operand type to be equal to the component "
                "type of the matrix operand: MatrixTimesScalar"));
}

// OpCooperativeMatrixMulAddNV: all matrices must share one scope. Here the
// accumulator uses %workgroup (scope value 2), while A and B use %subgroup.
TEST_F(ValidateArithmetics, CoopMatScopeFail) {
  const std::string types = R"(
%workgroup = OpConstant %u32 2

%mat16x16_wg = OpTypeCooperativeMatrixNV %f16 %workgroup %u32_16 %u32_16
%f16matwg_16x16_1 = OpConstantComposite %mat16x16_wg %f16_1
)";

  const std::string body = R"(
%val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16matwg_16x16_1
)";

  CompileSuccessfully(GenerateCoopMatCode(types, body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "Cooperative matrix scopes must match: CooperativeMatrixMulAddNV"));
}

// OpCooperativeMatrixMulAddNV: with A and B swapped, A (%f16mat_4x16_1) has 4
// rows while the result (%mat16x16) has 16, so the 'M' dimension mismatches.
TEST_F(ValidateArithmetics, CoopMatDimFail) {
  const std::string body = R"(
%val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_4x16_1 %f16mat_16x4_1 %f16mat_16x16_1
)";

  CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Cooperative matrix 'M' mismatch: CooperativeMatrixMulAddNV"));
}

// OpTypeCooperativeMatrixNV: a non-numeric component type (%bool) is an
// invalid-id error. The expected text '4[%bool]' relies on the id that
// GenerateCoopMatCode's declaration order gives %bool.
TEST_F(ValidateArithmetics, CoopMatComponentTypeNotScalarNumeric) {
  const std::string types = R"(
%bad = OpTypeCooperativeMatrixNV %bool %subgroup %u32_8 %u32_8
)";

  CompileSuccessfully(GenerateCoopMatCode(types, "").c_str());
  // ASSERT (as elsewhere in this file): the diagnostic check below is
  // meaningless if validation did not fail.
  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("OpTypeCooperativeMatrixNV Component Type <id> "
                        "'4[%bool]' is not a scalar numerical type."));
}

// OpTypeCooperativeMatrixNV: Scope must be an integer constant; the float
// constant %f32_1 (id 17 given GenerateCoopMatCode's declaration order) is
// rejected.
TEST_F(ValidateArithmetics, CoopMatScopeNotConstantInt) {
  const std::string types = R"(
%bad = OpTypeCooperativeMatrixNV %f16 %f32_1 %u32_8 %u32_8
)";

  CompileSuccessfully(GenerateCoopMatCode(types, "").c_str());
  // ASSERT (as elsewhere in this file): the diagnostic check below is
  // meaningless if validation did not fail.
  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("OpTypeCooperativeMatrixNV Scope <id> '17[%float_1]' is not a "
                "constant instruction with scalar integer type."));
}

// OpTypeCooperativeMatrixNV: Rows must be an integer constant; the float
// constant %f32_1 (id 17 given GenerateCoopMatCode's declaration order) is
// rejected.
TEST_F(ValidateArithmetics, CoopMatRowsNotConstantInt) {
  const std::string types = R"(
%bad = OpTypeCooperativeMatrixNV %f16 %subgroup %f32_1 %u32_8
)";

  CompileSuccessfully(GenerateCoopMatCode(types, "").c_str());
  // ASSERT (as elsewhere in this file): the diagnostic check below is
  // meaningless if validation did not fail.
  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("OpTypeCooperativeMatrixNV Rows <id> '17[%float_1]' is not a "
                "constant instruction with scalar integer type."));
}

// OpTypeCooperativeMatrixNV: Cols must be an integer constant; the float
// constant %f32_1 (id 17 given GenerateCoopMatCode's declaration order) is
// rejected.
TEST_F(ValidateArithmetics, CoopMatColumnsNotConstantInt) {
  const std::string types = R"(
%bad = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %f32_1
)";

  CompileSuccessfully(GenerateCoopMatCode(types, "").c_str());
  // ASSERT (as elsewhere in this file): the diagnostic check below is
  // meaningless if validation did not fail.
  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("OpTypeCooperativeMatrixNV Cols <id> '17[%float_1]' is not a "
                "constant instruction with scalar integer type."));
}

// Valid OpIAddCarry on unsigned scalar operands and on unsigned vector
// operands.
TEST_F(ValidateArithmetics, IAddCarrySuccess) {
  const std::string body = R"(
%val1 = OpIAddCarry %struct_u32_u32 %u32_0 %u32_1
%val2 = OpIAddCarry %struct_u32vec2_u32vec2 %u32vec2_01 %u32vec2_12
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}

// OpIAddCarry: a non-struct Result Type (%u32) must be rejected.
TEST_F(ValidateArithmetics, IAddCarryResultTypeNotStruct) {
  const std::string body = R"(
%val = OpIAddCarry %u32 %u32_0 %u32_1
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected a struct as Result Type: IAddCarry"));
}

// OpIAddCarry: a Result Type struct that does not have exactly two members
// (%struct_u32_u32_u32) must be rejected.
TEST_F(ValidateArithmetics, IAddCarryResultTypeNotTwoMembers) {
  const std::string body = R"(
%val = OpIAddCarry %struct_u32_u32_u32 %u32_0 %u32_1
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Expected Result Type struct to have two members: IAddCarry"));
}

// OpIAddCarry: signed struct members (%struct_s32_s32) must be rejected; the
// members have to be unsigned integer scalars or vectors.
TEST_F(ValidateArithmetics, IAddCarryResultTypeMemberNotUnsignedInt) {
  const std::string body = R"(
%val = OpIAddCarry %struct_s32_s32 %s32_0 %s32_1
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected Result Type struct member types to be "
                        "unsigned integer scalar "
                        "or vector: IAddCarry"));
}

// OpIAddCarry: a signed left operand with unsigned Result Type members must be
// rejected.
TEST_F(ValidateArithmetics, IAddCarryWrongLeftOperand) {
  const std::string body = R"(
%val = OpIAddCarry %struct_u32_u32 %s32_0 %u32_1
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected both operands to be of Result Type member "
                        "type: IAddCarry"));
}

// OpIAddCarry: a signed right operand with unsigned Result Type members must
// be rejected.
TEST_F(ValidateArithmetics, IAddCarryWrongRightOperand) {
  const std::string body = R"(
%val = OpIAddCarry %struct_u32_u32 %u32_0 %s32_1
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected both operands to be of Result Type member "
                        "type: IAddCarry"));
}

// Valid OpSMulExtended: both unsigned and signed integer members are accepted,
// as scalars and as vectors.
TEST_F(ValidateArithmetics, OpSMulExtendedSuccess) {
  const std::string body = R"(
%val1 = OpSMulExtended %struct_u32_u32 %u32_0 %u32_1
%val2 = OpSMulExtended %struct_s32_s32 %s32_0 %s32_1
%val3 = OpSMulExtended %struct_u32vec2_u32vec2 %u32vec2_01 %u32vec2_12
%val4 = OpSMulExtended %struct_s32vec2_s32vec2 %s32vec2_01 %s32vec2_12
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}

// OpSMulExtended: float struct members (%struct_f32_f32) must be rejected.
TEST_F(ValidateArithmetics, SMulExtendedResultTypeMemberNotInt) {
  const std::string body = R"(
%val = OpSMulExtended %struct_f32_f32 %f32_0 %f32_1
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Expected Result Type struct member types to be integer scalar "
                "or vector: SMulExtended"));
}

// OpSMulExtended: Result Type struct members of different types
// (%struct_s32_u32) must be rejected.
TEST_F(ValidateArithmetics, SMulExtendedResultTypeMembersNotIdentical) {
  const std::string body = R"(
%val = OpSMulExtended %struct_s32_u32 %s32_0 %s32_1
)";

  CompileSuccessfully(GenerateCode(body).c_str());
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Expected Result Type struct member types to be identical: "
                "SMulExtended"));
}

1472 }  // namespace
1473 }  // namespace val
1474 }  // namespace spvtools
1475