• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 #include "src/writer/spirv/spv_dump.h"
16 #include "src/writer/spirv/test_helper.h"
17 
18 namespace tint {
19 namespace writer {
20 namespace spirv {
21 namespace {
22 
23 using BuilderTest = TestHelper;
24 
25 struct BinaryData {
26   ast::BinaryOp op;
27   std::string name;
28 };
operator <<(std::ostream & out,BinaryData data)29 inline std::ostream& operator<<(std::ostream& out, BinaryData data) {
30   out << data.op;
31   return out;
32 }
33 
34 using BinaryArithSignedIntegerTest = TestParamHelper<BinaryData>;
TEST_P(BinaryArithSignedIntegerTest,Scalar)35 TEST_P(BinaryArithSignedIntegerTest, Scalar) {
36   auto param = GetParam();
37 
38   auto* lhs = Expr(3);
39   auto* rhs = Expr(4);
40 
41   auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
42 
43   WrapInFunction(expr);
44 
45   spirv::Builder& b = Build();
46 
47   b.push_function(Function{});
48 
49   EXPECT_EQ(b.GenerateBinaryExpression(expr), 4u) << b.error();
50   EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1
51 %2 = OpConstant %1 3
52 %3 = OpConstant %1 4
53 )");
54   EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
55             "%4 = " + param.name + " %1 %2 %3\n");
56 }
57 
TEST_P(BinaryArithSignedIntegerTest,Vector)58 TEST_P(BinaryArithSignedIntegerTest, Vector) {
59   auto param = GetParam();
60 
61   // Skip ops that are illegal for this type
62   if (param.op == ast::BinaryOp::kAnd || param.op == ast::BinaryOp::kOr ||
63       param.op == ast::BinaryOp::kXor) {
64     return;
65   }
66 
67   auto* lhs = vec3<i32>(1, 1, 1);
68   auto* rhs = vec3<i32>(1, 1, 1);
69 
70   auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
71 
72   WrapInFunction(expr);
73 
74   spirv::Builder& b = Build();
75 
76   b.push_function(Function{});
77 
78   EXPECT_EQ(b.GenerateBinaryExpression(expr), 5u) << b.error();
79   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1
80 %1 = OpTypeVector %2 3
81 %3 = OpConstant %2 1
82 %4 = OpConstantComposite %1 %3 %3 %3
83 )");
84   EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
85             "%5 = " + param.name + " %1 %4 %4\n");
86 }
TEST_P(BinaryArithSignedIntegerTest,Scalar_Loads)87 TEST_P(BinaryArithSignedIntegerTest, Scalar_Loads) {
88   auto param = GetParam();
89 
90   auto* var = Var("param", ty.i32());
91   auto* expr =
92       create<ast::BinaryExpression>(param.op, Expr("param"), Expr("param"));
93 
94   WrapInFunction(var, expr);
95 
96   spirv::Builder& b = Build();
97 
98   b.push_function(Function{});
99   EXPECT_TRUE(b.GenerateFunctionVariable(var)) << b.error();
100   EXPECT_EQ(b.GenerateBinaryExpression(expr), 7u) << b.error();
101   ASSERT_FALSE(b.has_error()) << b.error();
102 
103   EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1
104 %2 = OpTypePointer Function %3
105 %4 = OpConstantNull %3
106 )");
107   EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
108             R"(%1 = OpVariable %2 Function %4
109 )");
110   EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
111             R"(%5 = OpLoad %3 %1
112 %6 = OpLoad %3 %1
113 %7 = )" + param.name +
114                 R"( %3 %5 %6
115 )");
116 }
117 INSTANTIATE_TEST_SUITE_P(
118     BuilderTest,
119     BinaryArithSignedIntegerTest,
120     // NOTE: No left and right shift as they require u32 for rhs operand
121     testing::Values(BinaryData{ast::BinaryOp::kAdd, "OpIAdd"},
122                     BinaryData{ast::BinaryOp::kAnd, "OpBitwiseAnd"},
123                     BinaryData{ast::BinaryOp::kDivide, "OpSDiv"},
124                     BinaryData{ast::BinaryOp::kModulo, "OpSMod"},
125                     BinaryData{ast::BinaryOp::kMultiply, "OpIMul"},
126                     BinaryData{ast::BinaryOp::kOr, "OpBitwiseOr"},
127                     BinaryData{ast::BinaryOp::kSubtract, "OpISub"},
128                     BinaryData{ast::BinaryOp::kXor, "OpBitwiseXor"}));
129 
130 using BinaryArithUnsignedIntegerTest = TestParamHelper<BinaryData>;
TEST_P(BinaryArithUnsignedIntegerTest,Scalar)131 TEST_P(BinaryArithUnsignedIntegerTest, Scalar) {
132   auto param = GetParam();
133 
134   auto* lhs = Expr(3u);
135   auto* rhs = Expr(4u);
136 
137   auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
138 
139   WrapInFunction(expr);
140 
141   spirv::Builder& b = Build();
142 
143   b.push_function(Function{});
144 
145   EXPECT_EQ(b.GenerateBinaryExpression(expr), 4u) << b.error();
146   EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0
147 %2 = OpConstant %1 3
148 %3 = OpConstant %1 4
149 )");
150   EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
151             "%4 = " + param.name + " %1 %2 %3\n");
152 }
TEST_P(BinaryArithUnsignedIntegerTest,Vector)153 TEST_P(BinaryArithUnsignedIntegerTest, Vector) {
154   auto param = GetParam();
155 
156   // Skip ops that are illegal for this type
157   if (param.op == ast::BinaryOp::kAnd || param.op == ast::BinaryOp::kOr ||
158       param.op == ast::BinaryOp::kXor) {
159     return;
160   }
161 
162   auto* lhs = vec3<u32>(1u, 1u, 1u);
163   auto* rhs = vec3<u32>(1u, 1u, 1u);
164 
165   auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
166 
167   WrapInFunction(expr);
168 
169   spirv::Builder& b = Build();
170 
171   b.push_function(Function{});
172 
173   EXPECT_EQ(b.GenerateBinaryExpression(expr), 5u) << b.error();
174   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0
175 %1 = OpTypeVector %2 3
176 %3 = OpConstant %2 1
177 %4 = OpConstantComposite %1 %3 %3 %3
178 )");
179   EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
180             "%5 = " + param.name + " %1 %4 %4\n");
181 }
182 INSTANTIATE_TEST_SUITE_P(
183     BuilderTest,
184     BinaryArithUnsignedIntegerTest,
185     testing::Values(BinaryData{ast::BinaryOp::kAdd, "OpIAdd"},
186                     BinaryData{ast::BinaryOp::kAnd, "OpBitwiseAnd"},
187                     BinaryData{ast::BinaryOp::kDivide, "OpUDiv"},
188                     BinaryData{ast::BinaryOp::kModulo, "OpUMod"},
189                     BinaryData{ast::BinaryOp::kMultiply, "OpIMul"},
190                     BinaryData{ast::BinaryOp::kOr, "OpBitwiseOr"},
191                     BinaryData{ast::BinaryOp::kShiftLeft, "OpShiftLeftLogical"},
192                     BinaryData{ast::BinaryOp::kShiftRight,
193                                "OpShiftRightLogical"},
194                     BinaryData{ast::BinaryOp::kSubtract, "OpISub"},
195                     BinaryData{ast::BinaryOp::kXor, "OpBitwiseXor"}));
196 
197 using BinaryArithFloatTest = TestParamHelper<BinaryData>;
TEST_P(BinaryArithFloatTest,Scalar)198 TEST_P(BinaryArithFloatTest, Scalar) {
199   auto param = GetParam();
200 
201   auto* lhs = Expr(3.2f);
202   auto* rhs = Expr(4.5f);
203 
204   auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
205 
206   WrapInFunction(expr);
207 
208   spirv::Builder& b = Build();
209 
210   b.push_function(Function{});
211 
212   EXPECT_EQ(b.GenerateBinaryExpression(expr), 4u) << b.error();
213   EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
214 %2 = OpConstant %1 3.20000005
215 %3 = OpConstant %1 4.5
216 )");
217   EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
218             "%4 = " + param.name + " %1 %2 %3\n");
219 }
220 
TEST_P(BinaryArithFloatTest,Vector)221 TEST_P(BinaryArithFloatTest, Vector) {
222   auto param = GetParam();
223 
224   auto* lhs = vec3<f32>(1.f, 1.f, 1.f);
225   auto* rhs = vec3<f32>(1.f, 1.f, 1.f);
226 
227   auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
228 
229   WrapInFunction(expr);
230 
231   spirv::Builder& b = Build();
232 
233   b.push_function(Function{});
234 
235   EXPECT_EQ(b.GenerateBinaryExpression(expr), 5u) << b.error();
236   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
237 %1 = OpTypeVector %2 3
238 %3 = OpConstant %2 1
239 %4 = OpConstantComposite %1 %3 %3 %3
240 )");
241   EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
242             "%5 = " + param.name + " %1 %4 %4\n");
243 }
244 INSTANTIATE_TEST_SUITE_P(
245     BuilderTest,
246     BinaryArithFloatTest,
247     testing::Values(BinaryData{ast::BinaryOp::kAdd, "OpFAdd"},
248                     BinaryData{ast::BinaryOp::kDivide, "OpFDiv"},
249                     BinaryData{ast::BinaryOp::kModulo, "OpFRem"},
250                     BinaryData{ast::BinaryOp::kMultiply, "OpFMul"},
251                     BinaryData{ast::BinaryOp::kSubtract, "OpFSub"}));
252 
253 using BinaryOperatorBoolTest = TestParamHelper<BinaryData>;
TEST_P(BinaryOperatorBoolTest,Scalar)254 TEST_P(BinaryOperatorBoolTest, Scalar) {
255   auto param = GetParam();
256 
257   auto* lhs = Expr(true);
258   auto* rhs = Expr(false);
259 
260   auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
261 
262   WrapInFunction(expr);
263 
264   spirv::Builder& b = Build();
265 
266   b.push_function(Function{});
267 
268   EXPECT_EQ(b.GenerateBinaryExpression(expr), 4u) << b.error();
269   EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeBool
270 %2 = OpConstantTrue %1
271 %3 = OpConstantFalse %1
272 )");
273   EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
274             "%4 = " + param.name + " %1 %2 %3\n");
275 }
276 
TEST_P(BinaryOperatorBoolTest,Vector)277 TEST_P(BinaryOperatorBoolTest, Vector) {
278   auto param = GetParam();
279 
280   auto* lhs = vec3<bool>(false, true, false);
281   auto* rhs = vec3<bool>(true, false, true);
282 
283   auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
284 
285   WrapInFunction(expr);
286 
287   spirv::Builder& b = Build();
288 
289   b.push_function(Function{});
290 
291   EXPECT_EQ(b.GenerateBinaryExpression(expr), 7u) << b.error();
292   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeBool
293 %1 = OpTypeVector %2 3
294 %3 = OpConstantFalse %2
295 %4 = OpConstantTrue %2
296 %5 = OpConstantComposite %1 %3 %4 %3
297 %6 = OpConstantComposite %1 %4 %3 %4
298 )");
299   EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
300             "%7 = " + param.name + " %1 %5 %6\n");
301 }
302 INSTANTIATE_TEST_SUITE_P(
303     BuilderTest,
304     BinaryOperatorBoolTest,
305     testing::Values(BinaryData{ast::BinaryOp::kEqual, "OpLogicalEqual"},
306                     BinaryData{ast::BinaryOp::kNotEqual, "OpLogicalNotEqual"},
307                     BinaryData{ast::BinaryOp::kAnd, "OpLogicalAnd"},
308                     BinaryData{ast::BinaryOp::kOr, "OpLogicalOr"}));
309 
310 using BinaryCompareUnsignedIntegerTest = TestParamHelper<BinaryData>;
TEST_P(BinaryCompareUnsignedIntegerTest,Scalar)311 TEST_P(BinaryCompareUnsignedIntegerTest, Scalar) {
312   auto param = GetParam();
313 
314   auto* lhs = Expr(3u);
315   auto* rhs = Expr(4u);
316 
317   auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
318 
319   WrapInFunction(expr);
320 
321   spirv::Builder& b = Build();
322 
323   b.push_function(Function{});
324 
325   EXPECT_EQ(b.GenerateBinaryExpression(expr), 4u) << b.error();
326   EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0
327 %2 = OpConstant %1 3
328 %3 = OpConstant %1 4
329 %5 = OpTypeBool
330 )");
331   EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
332             "%4 = " + param.name + " %5 %2 %3\n");
333 }
334 
TEST_P(BinaryCompareUnsignedIntegerTest,Vector)335 TEST_P(BinaryCompareUnsignedIntegerTest, Vector) {
336   auto param = GetParam();
337 
338   auto* lhs = vec3<u32>(1u, 1u, 1u);
339   auto* rhs = vec3<u32>(1u, 1u, 1u);
340 
341   auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
342 
343   WrapInFunction(expr);
344 
345   spirv::Builder& b = Build();
346 
347   b.push_function(Function{});
348 
349   EXPECT_EQ(b.GenerateBinaryExpression(expr), 5u) << b.error();
350   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0
351 %1 = OpTypeVector %2 3
352 %3 = OpConstant %2 1
353 %4 = OpConstantComposite %1 %3 %3 %3
354 %7 = OpTypeBool
355 %6 = OpTypeVector %7 3
356 )");
357   EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
358             "%5 = " + param.name + " %6 %4 %4\n");
359 }
360 INSTANTIATE_TEST_SUITE_P(
361     BuilderTest,
362     BinaryCompareUnsignedIntegerTest,
363     testing::Values(
364         BinaryData{ast::BinaryOp::kEqual, "OpIEqual"},
365         BinaryData{ast::BinaryOp::kGreaterThan, "OpUGreaterThan"},
366         BinaryData{ast::BinaryOp::kGreaterThanEqual, "OpUGreaterThanEqual"},
367         BinaryData{ast::BinaryOp::kLessThan, "OpULessThan"},
368         BinaryData{ast::BinaryOp::kLessThanEqual, "OpULessThanEqual"},
369         BinaryData{ast::BinaryOp::kNotEqual, "OpINotEqual"}));
370 
371 using BinaryCompareSignedIntegerTest = TestParamHelper<BinaryData>;
TEST_P(BinaryCompareSignedIntegerTest,Scalar)372 TEST_P(BinaryCompareSignedIntegerTest, Scalar) {
373   auto param = GetParam();
374 
375   auto* lhs = Expr(3);
376   auto* rhs = Expr(4);
377 
378   auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
379 
380   WrapInFunction(expr);
381 
382   spirv::Builder& b = Build();
383 
384   b.push_function(Function{});
385 
386   EXPECT_EQ(b.GenerateBinaryExpression(expr), 4u) << b.error();
387   EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1
388 %2 = OpConstant %1 3
389 %3 = OpConstant %1 4
390 %5 = OpTypeBool
391 )");
392   EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
393             "%4 = " + param.name + " %5 %2 %3\n");
394 }
395 
TEST_P(BinaryCompareSignedIntegerTest,Vector)396 TEST_P(BinaryCompareSignedIntegerTest, Vector) {
397   auto param = GetParam();
398 
399   auto* lhs = vec3<i32>(1, 1, 1);
400   auto* rhs = vec3<i32>(1, 1, 1);
401 
402   auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
403 
404   WrapInFunction(expr);
405 
406   spirv::Builder& b = Build();
407 
408   b.push_function(Function{});
409 
410   EXPECT_EQ(b.GenerateBinaryExpression(expr), 5u) << b.error();
411   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1
412 %1 = OpTypeVector %2 3
413 %3 = OpConstant %2 1
414 %4 = OpConstantComposite %1 %3 %3 %3
415 %7 = OpTypeBool
416 %6 = OpTypeVector %7 3
417 )");
418   EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
419             "%5 = " + param.name + " %6 %4 %4\n");
420 }
421 INSTANTIATE_TEST_SUITE_P(
422     BuilderTest,
423     BinaryCompareSignedIntegerTest,
424     testing::Values(
425         BinaryData{ast::BinaryOp::kEqual, "OpIEqual"},
426         BinaryData{ast::BinaryOp::kGreaterThan, "OpSGreaterThan"},
427         BinaryData{ast::BinaryOp::kGreaterThanEqual, "OpSGreaterThanEqual"},
428         BinaryData{ast::BinaryOp::kLessThan, "OpSLessThan"},
429         BinaryData{ast::BinaryOp::kLessThanEqual, "OpSLessThanEqual"},
430         BinaryData{ast::BinaryOp::kNotEqual, "OpINotEqual"}));
431 
432 using BinaryCompareFloatTest = TestParamHelper<BinaryData>;
TEST_P(BinaryCompareFloatTest,Scalar)433 TEST_P(BinaryCompareFloatTest, Scalar) {
434   auto param = GetParam();
435 
436   auto* lhs = Expr(3.2f);
437   auto* rhs = Expr(4.5f);
438 
439   auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
440 
441   WrapInFunction(expr);
442 
443   spirv::Builder& b = Build();
444 
445   b.push_function(Function{});
446 
447   EXPECT_EQ(b.GenerateBinaryExpression(expr), 4u) << b.error();
448   EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
449 %2 = OpConstant %1 3.20000005
450 %3 = OpConstant %1 4.5
451 %5 = OpTypeBool
452 )");
453   EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
454             "%4 = " + param.name + " %5 %2 %3\n");
455 }
456 
TEST_P(BinaryCompareFloatTest,Vector)457 TEST_P(BinaryCompareFloatTest, Vector) {
458   auto param = GetParam();
459 
460   auto* lhs = vec3<f32>(1.f, 1.f, 1.f);
461   auto* rhs = vec3<f32>(1.f, 1.f, 1.f);
462 
463   auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
464 
465   WrapInFunction(expr);
466 
467   spirv::Builder& b = Build();
468 
469   b.push_function(Function{});
470 
471   EXPECT_EQ(b.GenerateBinaryExpression(expr), 5u) << b.error();
472   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
473 %1 = OpTypeVector %2 3
474 %3 = OpConstant %2 1
475 %4 = OpConstantComposite %1 %3 %3 %3
476 %7 = OpTypeBool
477 %6 = OpTypeVector %7 3
478 )");
479   EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
480             "%5 = " + param.name + " %6 %4 %4\n");
481 }
482 INSTANTIATE_TEST_SUITE_P(
483     BuilderTest,
484     BinaryCompareFloatTest,
485     testing::Values(
486         BinaryData{ast::BinaryOp::kEqual, "OpFOrdEqual"},
487         BinaryData{ast::BinaryOp::kGreaterThan, "OpFOrdGreaterThan"},
488         BinaryData{ast::BinaryOp::kGreaterThanEqual, "OpFOrdGreaterThanEqual"},
489         BinaryData{ast::BinaryOp::kLessThan, "OpFOrdLessThan"},
490         BinaryData{ast::BinaryOp::kLessThanEqual, "OpFOrdLessThanEqual"},
491         BinaryData{ast::BinaryOp::kNotEqual, "OpFOrdNotEqual"}));
492 
// vec3<f32> * f32 lowers to OpVectorTimesScalar (vector operand first).
TEST_F(BuilderTest, Binary_Multiply_VectorScalar) {
  auto* lhs = vec3<f32>(1.f, 1.f, 1.f);
  auto* rhs = Expr(1.f);

  auto* expr =
      create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);

  WrapInFunction(expr);

  spirv::Builder& b = Build();

  b.push_function(Function{});

  EXPECT_EQ(b.GenerateBinaryExpression(expr), 5u) << b.error();
  EXPECT_EQ(DumpInstructions(b.types()),
            R"(%2 = OpTypeFloat 32
%1 = OpTypeVector %2 3
%3 = OpConstant %2 1
%4 = OpConstantComposite %1 %3 %3 %3
)");
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            "%5 = OpVectorTimesScalar %1 %4 %3\n");
}
516 
// f32 * vec3<f32> also lowers to OpVectorTimesScalar. The expectation places
// the vector (%4) before the scalar (%2), i.e. the AST operands are swapped.
TEST_F(BuilderTest, Binary_Multiply_ScalarVector) {
  auto* lhs = Expr(1.f);
  auto* rhs = vec3<f32>(1.f, 1.f, 1.f);

  auto* expr =
      create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);

  WrapInFunction(expr);

  spirv::Builder& b = Build();

  b.push_function(Function{});

  EXPECT_EQ(b.GenerateBinaryExpression(expr), 5u) << b.error();
  EXPECT_EQ(DumpInstructions(b.types()),
            R"(%1 = OpTypeFloat 32
%2 = OpConstant %1 1
%3 = OpTypeVector %1 3
%4 = OpConstantComposite %3 %2 %2 %2
)");
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            "%5 = OpVectorTimesScalar %3 %4 %2\n");
}
540 
// mat3x3<f32> * f32 lowers to OpMatrixTimesScalar on a loaded matrix.
TEST_F(BuilderTest, Binary_Multiply_MatrixScalar) {
  auto* var = Var("mat", ty.mat3x3<f32>());

  auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply,
                                             Expr("mat"), Expr(1.f));

  WrapInFunction(var, expr);

  spirv::Builder& b = Build();

  b.push_function(Function{});
  // The variable is emitted through GenerateGlobalVariable, so its OpVariable
  // appears in b.types() in the expectation below.
  ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();

  EXPECT_EQ(b.GenerateBinaryExpression(expr), 8u) << b.error();
  EXPECT_EQ(DumpInstructions(b.types()),
            R"(%5 = OpTypeFloat 32
%4 = OpTypeVector %5 3
%3 = OpTypeMatrix %4 3
%2 = OpTypePointer Function %3
%1 = OpVariable %2 Function
%7 = OpConstant %5 1
)");
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            R"(%6 = OpLoad %3 %1
%8 = OpMatrixTimesScalar %3 %6 %7
)");
}
568 
// f32 * mat3x3<f32> also lowers to OpMatrixTimesScalar. The expectation
// places the matrix (%7) before the scalar (%6).
TEST_F(BuilderTest, Binary_Multiply_ScalarMatrix) {
  auto* var = Var("mat", ty.mat3x3<f32>());

  auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply,
                                             Expr(1.f), Expr("mat"));

  WrapInFunction(var, expr);

  spirv::Builder& b = Build();

  b.push_function(Function{});
  // The variable is emitted through GenerateGlobalVariable, so its OpVariable
  // appears in b.types() in the expectation below.
  ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();

  EXPECT_EQ(b.GenerateBinaryExpression(expr), 8u) << b.error();
  EXPECT_EQ(DumpInstructions(b.types()),
            R"(%5 = OpTypeFloat 32
%4 = OpTypeVector %5 3
%3 = OpTypeMatrix %4 3
%2 = OpTypePointer Function %3
%1 = OpVariable %2 Function
%6 = OpConstant %5 1
)");
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            R"(%7 = OpLoad %3 %1
%8 = OpMatrixTimesScalar %3 %7 %6
)");
}
596 
// mat3x3<f32> * vec3<f32> lowers to OpMatrixTimesVector with a vec3 result.
TEST_F(BuilderTest, Binary_Multiply_MatrixVector) {
  auto* var = Var("mat", ty.mat3x3<f32>());
  auto* rhs = vec3<f32>(1.f, 1.f, 1.f);

  auto* expr =
      create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, Expr("mat"), rhs);

  WrapInFunction(var, expr);

  spirv::Builder& b = Build();

  b.push_function(Function{});
  // The variable is emitted through GenerateGlobalVariable, so its OpVariable
  // appears in b.types() in the expectation below.
  ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();

  EXPECT_EQ(b.GenerateBinaryExpression(expr), 9u) << b.error();
  EXPECT_EQ(DumpInstructions(b.types()),
            R"(%5 = OpTypeFloat 32
%4 = OpTypeVector %5 3
%3 = OpTypeMatrix %4 3
%2 = OpTypePointer Function %3
%1 = OpVariable %2 Function
%7 = OpConstant %5 1
%8 = OpConstantComposite %4 %7 %7 %7
)");
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            R"(%6 = OpLoad %3 %1
%9 = OpMatrixTimesVector %4 %6 %8
)");
}
626 
// vec3<f32> * mat3x3<f32> lowers to OpVectorTimesMatrix with a vec3 result.
TEST_F(BuilderTest, Binary_Multiply_VectorMatrix) {
  auto* var = Var("mat", ty.mat3x3<f32>());
  auto* lhs = vec3<f32>(1.f, 1.f, 1.f);

  auto* expr =
      create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, Expr("mat"));

  WrapInFunction(var, expr);

  spirv::Builder& b = Build();

  b.push_function(Function{});
  // The variable is emitted through GenerateGlobalVariable, so its OpVariable
  // appears in b.types() in the expectation below.
  ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();

  EXPECT_EQ(b.GenerateBinaryExpression(expr), 9u) << b.error();
  EXPECT_EQ(DumpInstructions(b.types()),
            R"(%5 = OpTypeFloat 32
%4 = OpTypeVector %5 3
%3 = OpTypeMatrix %4 3
%2 = OpTypePointer Function %3
%1 = OpVariable %2 Function
%6 = OpConstant %5 1
%7 = OpConstantComposite %4 %6 %6 %6
)");
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            R"(%8 = OpLoad %3 %1
%9 = OpVectorTimesMatrix %4 %7 %8
)");
}
656 
// mat3x3<f32> * mat3x3<f32> lowers to OpMatrixTimesMatrix. The expectation
// shows a separate OpLoad for each use of the variable.
TEST_F(BuilderTest, Binary_Multiply_MatrixMatrix) {
  auto* var = Var("mat", ty.mat3x3<f32>());

  auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply,
                                             Expr("mat"), Expr("mat"));

  WrapInFunction(var, expr);

  spirv::Builder& b = Build();

  b.push_function(Function{});
  // The variable is emitted through GenerateGlobalVariable, so its OpVariable
  // appears in b.types() in the expectation below.
  ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();

  EXPECT_EQ(b.GenerateBinaryExpression(expr), 8u) << b.error();
  EXPECT_EQ(DumpInstructions(b.types()),
            R"(%5 = OpTypeFloat 32
%4 = OpTypeVector %5 3
%3 = OpTypeMatrix %4 3
%2 = OpTypePointer Function %3
%1 = OpVariable %2 Function
)");
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            R"(%6 = OpLoad %3 %1
%7 = OpLoad %3 %1
%8 = OpMatrixTimesMatrix %3 %6 %7
)");
}
684 
// `(1 == 2) && (3 == 4)` is lowered to a selection:
// - the RHS is evaluated only in block %8, which is entered when the LHS (%5)
//   is true;
// - the result is merged in %7 by an OpPhi over the LHS (from %1) and the
//   RHS (from %8).
TEST_F(BuilderTest, Binary_LogicalAnd) {
  auto* lhs =
      create<ast::BinaryExpression>(ast::BinaryOp::kEqual, Expr(1), Expr(2));

  auto* rhs =
      create<ast::BinaryExpression>(ast::BinaryOp::kEqual, Expr(3), Expr(4));

  auto* expr =
      create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, lhs, rhs);

  WrapInFunction(expr);

  spirv::Builder& b = Build();

  b.push_function(Function{});
  // Emit an entry label (%1) so the OpPhi has a predecessor block to name.
  b.GenerateLabel(b.next_id());

  EXPECT_EQ(b.GenerateBinaryExpression(expr), 12u) << b.error();
  EXPECT_EQ(DumpInstructions(b.types()),
            R"(%2 = OpTypeInt 32 1
%3 = OpConstant %2 1
%4 = OpConstant %2 2
%6 = OpTypeBool
%9 = OpConstant %2 3
%10 = OpConstant %2 4
)");
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            R"(%1 = OpLabel
%5 = OpIEqual %6 %3 %4
OpSelectionMerge %7 None
OpBranchConditional %5 %8 %7
%8 = OpLabel
%11 = OpIEqual %6 %9 %10
OpBranch %7
%7 = OpLabel
%12 = OpPhi %6 %5 %1 %11 %8
)");
}
723 
// `a && b` on two private bool globals. The load of `b` (%11) must be
// emitted inside the conditionally executed block %10, not before the branch.
TEST_F(BuilderTest, Binary_LogicalAnd_WithLoads) {
  auto* a_var =
      Global("a", ty.bool_(), ast::StorageClass::kPrivate, Expr(true));
  auto* b_var =
      Global("b", ty.bool_(), ast::StorageClass::kPrivate, Expr(false));

  auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                             Expr("a"), Expr("b"));

  WrapInFunction(expr);

  spirv::Builder& b = Build();

  b.push_function(Function{});
  // Emit an entry label (%1) so the OpPhi has a predecessor block to name.
  b.GenerateLabel(b.next_id());

  ASSERT_TRUE(b.GenerateGlobalVariable(a_var)) << b.error();
  ASSERT_TRUE(b.GenerateGlobalVariable(b_var)) << b.error();

  EXPECT_EQ(b.GenerateBinaryExpression(expr), 12u) << b.error();
  EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeBool
%3 = OpConstantTrue %2
%5 = OpTypePointer Private %2
%4 = OpVariable %5 Private %3
%6 = OpConstantFalse %2
%7 = OpVariable %5 Private %6
)");
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            R"(%1 = OpLabel
%8 = OpLoad %2 %4
OpSelectionMerge %9 None
OpBranchConditional %8 %10 %9
%10 = OpLabel
%11 = OpLoad %2 %7
OpBranch %9
%9 = OpLabel
%12 = OpPhi %2 %8 %1 %11 %10
)");
}
763 
TEST_F(BuilderTest, Binary_logicalOr_Nested_LogicalAnd) {
  // Test an expression like
  //    a || (b && c)
  // From: crbug.com/tint/355
  //
  // The inner `&&` creates its own selection whose merge block is %6, and %6
  // is the block that branches to the outer merge %4. The outer OpPhi must
  // therefore name %6 (not %5) as the predecessor for the inner result %9.

  auto* logical_and_expr = create<ast::BinaryExpression>(
      ast::BinaryOp::kLogicalAnd, Expr(true), Expr(false));

  auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                             Expr(true), logical_and_expr);

  WrapInFunction(expr);

  spirv::Builder& b = Build();

  b.push_function(Function{});
  // Emit an entry label (%1) so the outer OpPhi has a predecessor block.
  b.GenerateLabel(b.next_id());

  EXPECT_EQ(b.GenerateBinaryExpression(expr), 10u) << b.error();
  EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeBool
%3 = OpConstantTrue %2
%8 = OpConstantFalse %2
)");
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            R"(%1 = OpLabel
OpSelectionMerge %4 None
OpBranchConditional %3 %4 %5
%5 = OpLabel
OpSelectionMerge %6 None
OpBranchConditional %3 %7 %6
%7 = OpLabel
OpBranch %6
%6 = OpLabel
%9 = OpPhi %2 %3 %5 %8 %7
OpBranch %4
%4 = OpLabel
%10 = OpPhi %2 %3 %1 %9 %6
)");
}
803 
TEST_F(BuilderTest, Binary_logicalAnd_Nested_LogicalOr) {
  // Test an expression like
  //    a && (b || c)
  // From: crbug.com/tint/355
  //
  // Mirror of the nested-`||` case: the outer OpPhi must name the inner merge
  // block %6 (the block that branches to %4) as the predecessor for %9.

  auto* logical_or_expr = create<ast::BinaryExpression>(
      ast::BinaryOp::kLogicalOr, Expr(true), Expr(false));

  auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                             Expr(true), logical_or_expr);

  WrapInFunction(expr);

  spirv::Builder& b = Build();

  b.push_function(Function{});
  // Emit an entry label (%1) so the outer OpPhi has a predecessor block.
  b.GenerateLabel(b.next_id());

  EXPECT_EQ(b.GenerateBinaryExpression(expr), 10u) << b.error();
  EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeBool
%3 = OpConstantTrue %2
%8 = OpConstantFalse %2
)");
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            R"(%1 = OpLabel
OpSelectionMerge %4 None
OpBranchConditional %3 %5 %4
%5 = OpLabel
OpSelectionMerge %6 None
OpBranchConditional %3 %6 %7
%7 = OpLabel
OpBranch %6
%6 = OpLabel
%9 = OpPhi %2 %3 %5 %8 %7
OpBranch %4
%4 = OpLabel
%10 = OpPhi %2 %3 %1 %9 %6
)");
}
843 
// `(1 == 2) || (3 == 4)` is lowered like `&&`, but with the branch targets
// swapped: the RHS block %8 is entered only when the LHS (%5) is false.
TEST_F(BuilderTest, Binary_LogicalOr) {
  auto* lhs =
      create<ast::BinaryExpression>(ast::BinaryOp::kEqual, Expr(1), Expr(2));

  auto* rhs =
      create<ast::BinaryExpression>(ast::BinaryOp::kEqual, Expr(3), Expr(4));

  auto* expr =
      create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr, lhs, rhs);

  WrapInFunction(expr);

  spirv::Builder& b = Build();

  b.push_function(Function{});
  // Emit an entry label (%1) so the OpPhi has a predecessor block to name.
  b.GenerateLabel(b.next_id());

  EXPECT_EQ(b.GenerateBinaryExpression(expr), 12u) << b.error();
  EXPECT_EQ(DumpInstructions(b.types()),
            R"(%2 = OpTypeInt 32 1
%3 = OpConstant %2 1
%4 = OpConstant %2 2
%6 = OpTypeBool
%9 = OpConstant %2 3
%10 = OpConstant %2 4
)");
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            R"(%1 = OpLabel
%5 = OpIEqual %6 %3 %4
OpSelectionMerge %7 None
OpBranchConditional %5 %7 %8
%8 = OpLabel
%11 = OpIEqual %6 %9 %10
OpBranch %7
%7 = OpLabel
%12 = OpPhi %6 %5 %1 %11 %8
)");
}
882 
// `a || b` on two private bool globals. The load of `b` (%11) must be
// emitted inside block %10, which runs only when `a` is false.
TEST_F(BuilderTest, Binary_LogicalOr_WithLoads) {
  auto* a_var =
      Global("a", ty.bool_(), ast::StorageClass::kPrivate, Expr(true));
  auto* b_var =
      Global("b", ty.bool_(), ast::StorageClass::kPrivate, Expr(false));

  auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                             Expr("a"), Expr("b"));

  WrapInFunction(expr);

  spirv::Builder& b = Build();

  b.push_function(Function{});
  // Emit an entry label (%1) so the OpPhi has a predecessor block to name.
  b.GenerateLabel(b.next_id());

  ASSERT_TRUE(b.GenerateGlobalVariable(a_var)) << b.error();
  ASSERT_TRUE(b.GenerateGlobalVariable(b_var)) << b.error();

  EXPECT_EQ(b.GenerateBinaryExpression(expr), 12u) << b.error();
  EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeBool
%3 = OpConstantTrue %2
%5 = OpTypePointer Private %2
%4 = OpVariable %5 Private %3
%6 = OpConstantFalse %2
%7 = OpVariable %5 Private %6
)");
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            R"(%1 = OpLabel
%8 = OpLoad %2 %4
OpSelectionMerge %9 None
OpBranchConditional %8 %9 %10
%10 = OpLabel
%11 = OpLoad %2 %7
OpBranch %9
%9 = OpLabel
%12 = OpPhi %2 %8 %1 %11 %10
)");
}
922 
923 namespace BinaryArithVectorScalar {
924 
925 enum class Type { f32, i32, u32 };
MakeVectorExpr(ProgramBuilder * builder,Type type)926 static const ast::Expression* MakeVectorExpr(ProgramBuilder* builder,
927                                              Type type) {
928   switch (type) {
929     case Type::f32:
930       return builder->vec3<ProgramBuilder::f32>(1.f, 1.f, 1.f);
931     case Type::i32:
932       return builder->vec3<ProgramBuilder::i32>(1, 1, 1);
933     case Type::u32:
934       return builder->vec3<ProgramBuilder::u32>(1u, 1u, 1u);
935   }
936   return nullptr;
937 }
MakeScalarExpr(ProgramBuilder * builder,Type type)938 static const ast::Expression* MakeScalarExpr(ProgramBuilder* builder,
939                                              Type type) {
940   switch (type) {
941     case Type::f32:
942       return builder->Expr(1.f);
943     case Type::i32:
944       return builder->Expr(1);
945     case Type::u32:
946       return builder->Expr(1u);
947   }
948   return nullptr;
949 }
OpTypeDecl(Type type)950 static std::string OpTypeDecl(Type type) {
951   switch (type) {
952     case Type::f32:
953       return "OpTypeFloat 32";
954     case Type::i32:
955       return "OpTypeInt 32 1";
956     case Type::u32:
957       return "OpTypeInt 32 0";
958   }
959   return {};
960 }
961 
962 struct Param {
963   Type type;
964   ast::BinaryOp op;
965   std::string name;
966 };
967 
968 using BinaryArithVectorScalarTest = TestParamHelper<Param>;
TEST_P(BinaryArithVectorScalarTest,VectorScalar)969 TEST_P(BinaryArithVectorScalarTest, VectorScalar) {
970   auto& param = GetParam();
971 
972   const ast::Expression* lhs = MakeVectorExpr(this, param.type);
973   const ast::Expression* rhs = MakeScalarExpr(this, param.type);
974   std::string op_type_decl = OpTypeDecl(param.type);
975 
976   auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
977 
978   WrapInFunction(expr);
979 
980   spirv::Builder& b = Build();
981   ASSERT_TRUE(b.Build()) << b.error();
982 
983   EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
984 OpMemoryModel Logical GLSL450
985 OpEntryPoint GLCompute %3 "test_function"
986 OpExecutionMode %3 LocalSize 1 1 1
987 OpName %3 "test_function"
988 %2 = OpTypeVoid
989 %1 = OpTypeFunction %2
990 %6 = )" + op_type_decl + R"(
991 %5 = OpTypeVector %6 3
992 %7 = OpConstant %6 1
993 %8 = OpConstantComposite %5 %7 %7 %7
994 %11 = OpTypePointer Function %5
995 %12 = OpConstantNull %5
996 %3 = OpFunction %2 None %1
997 %4 = OpLabel
998 %10 = OpVariable %11 Function %12
999 %13 = OpCompositeConstruct %5 %7 %7 %7
1000 %9 = )" + param.name + R"( %5 %8 %13
1001 OpReturn
1002 OpFunctionEnd
1003 )");
1004 
1005   Validate(b);
1006 }
TEST_P(BinaryArithVectorScalarTest,ScalarVector)1007 TEST_P(BinaryArithVectorScalarTest, ScalarVector) {
1008   auto& param = GetParam();
1009 
1010   const ast::Expression* lhs = MakeScalarExpr(this, param.type);
1011   const ast::Expression* rhs = MakeVectorExpr(this, param.type);
1012   std::string op_type_decl = OpTypeDecl(param.type);
1013 
1014   auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
1015 
1016   WrapInFunction(expr);
1017 
1018   spirv::Builder& b = Build();
1019   ASSERT_TRUE(b.Build()) << b.error();
1020 
1021   EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
1022 OpMemoryModel Logical GLSL450
1023 OpEntryPoint GLCompute %3 "test_function"
1024 OpExecutionMode %3 LocalSize 1 1 1
1025 OpName %3 "test_function"
1026 %2 = OpTypeVoid
1027 %1 = OpTypeFunction %2
1028 %5 = )" + op_type_decl + R"(
1029 %6 = OpConstant %5 1
1030 %7 = OpTypeVector %5 3
1031 %8 = OpConstantComposite %7 %6 %6 %6
1032 %11 = OpTypePointer Function %7
1033 %12 = OpConstantNull %7
1034 %3 = OpFunction %2 None %1
1035 %4 = OpLabel
1036 %10 = OpVariable %11 Function %12
1037 %13 = OpCompositeConstruct %7 %6 %6 %6
1038 %9 = )" + param.name + R"( %7 %13 %8
1039 OpReturn
1040 OpFunctionEnd
1041 )");
1042 
1043   Validate(b);
1044 }
1045 INSTANTIATE_TEST_SUITE_P(
1046     BuilderTest,
1047     BinaryArithVectorScalarTest,
1048     testing::Values(Param{Type::f32, ast::BinaryOp::kAdd, "OpFAdd"},
1049                     Param{Type::f32, ast::BinaryOp::kDivide, "OpFDiv"},
1050                     // NOTE: Modulo not allowed on mixed float scalar-vector
1051                     // Param{Type::f32, ast::BinaryOp::kModulo, "OpFMod"},
1052                     // NOTE: We test f32 multiplies separately as we emit
1053                     // OpVectorTimesScalar for this case
1054                     // Param{Type::i32, ast::BinaryOp::kMultiply, "OpIMul"},
1055                     Param{Type::f32, ast::BinaryOp::kSubtract, "OpFSub"},
1056 
1057                     Param{Type::i32, ast::BinaryOp::kAdd, "OpIAdd"},
1058                     Param{Type::i32, ast::BinaryOp::kDivide, "OpSDiv"},
1059                     Param{Type::i32, ast::BinaryOp::kModulo, "OpSMod"},
1060                     Param{Type::i32, ast::BinaryOp::kMultiply, "OpIMul"},
1061                     Param{Type::i32, ast::BinaryOp::kSubtract, "OpISub"},
1062 
1063                     Param{Type::u32, ast::BinaryOp::kAdd, "OpIAdd"},
1064                     Param{Type::u32, ast::BinaryOp::kDivide, "OpUDiv"},
1065                     Param{Type::u32, ast::BinaryOp::kModulo, "OpUMod"},
1066                     Param{Type::u32, ast::BinaryOp::kMultiply, "OpIMul"},
1067                     Param{Type::u32, ast::BinaryOp::kSubtract, "OpISub"}));
1068 
1069 using BinaryArithVectorScalarMultiplyTest = TestParamHelper<Param>;
TEST_P(BinaryArithVectorScalarMultiplyTest,VectorScalar)1070 TEST_P(BinaryArithVectorScalarMultiplyTest, VectorScalar) {
1071   auto& param = GetParam();
1072 
1073   const ast::Expression* lhs = MakeVectorExpr(this, param.type);
1074   const ast::Expression* rhs = MakeScalarExpr(this, param.type);
1075   std::string op_type_decl = OpTypeDecl(param.type);
1076 
1077   auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
1078 
1079   WrapInFunction(expr);
1080 
1081   spirv::Builder& b = Build();
1082   ASSERT_TRUE(b.Build()) << b.error();
1083 
1084   EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
1085 OpMemoryModel Logical GLSL450
1086 OpEntryPoint GLCompute %3 "test_function"
1087 OpExecutionMode %3 LocalSize 1 1 1
1088 OpName %3 "test_function"
1089 %2 = OpTypeVoid
1090 %1 = OpTypeFunction %2
1091 %6 = )" + op_type_decl + R"(
1092 %5 = OpTypeVector %6 3
1093 %7 = OpConstant %6 1
1094 %8 = OpConstantComposite %5 %7 %7 %7
1095 %3 = OpFunction %2 None %1
1096 %4 = OpLabel
1097 %9 = OpVectorTimesScalar %5 %8 %7
1098 OpReturn
1099 OpFunctionEnd
1100 )");
1101 
1102   Validate(b);
1103 }
TEST_P(BinaryArithVectorScalarMultiplyTest,ScalarVector)1104 TEST_P(BinaryArithVectorScalarMultiplyTest, ScalarVector) {
1105   auto& param = GetParam();
1106 
1107   const ast::Expression* lhs = MakeScalarExpr(this, param.type);
1108   const ast::Expression* rhs = MakeVectorExpr(this, param.type);
1109   std::string op_type_decl = OpTypeDecl(param.type);
1110 
1111   auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
1112 
1113   WrapInFunction(expr);
1114 
1115   spirv::Builder& b = Build();
1116   ASSERT_TRUE(b.Build()) << b.error();
1117 
1118   EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
1119 OpMemoryModel Logical GLSL450
1120 OpEntryPoint GLCompute %3 "test_function"
1121 OpExecutionMode %3 LocalSize 1 1 1
1122 OpName %3 "test_function"
1123 %2 = OpTypeVoid
1124 %1 = OpTypeFunction %2
1125 %5 = )" + op_type_decl + R"(
1126 %6 = OpConstant %5 1
1127 %7 = OpTypeVector %5 3
1128 %8 = OpConstantComposite %7 %6 %6 %6
1129 %3 = OpFunction %2 None %1
1130 %4 = OpLabel
1131 %9 = OpVectorTimesScalar %7 %8 %6
1132 OpReturn
1133 OpFunctionEnd
1134 )");
1135 
1136   Validate(b);
1137 }
1138 INSTANTIATE_TEST_SUITE_P(BuilderTest,
1139                          BinaryArithVectorScalarMultiplyTest,
1140                          testing::Values(Param{
1141                              Type::f32, ast::BinaryOp::kMultiply, "OpFMul"}));
1142 
1143 }  // namespace BinaryArithVectorScalar
1144 
1145 namespace BinaryArithMatrixMatrix {
1146 
1147 struct Param {
1148   ast::BinaryOp op;
1149   std::string name;
1150 };
1151 
1152 using BinaryArithMatrixMatrix = TestParamHelper<Param>;
TEST_P(BinaryArithMatrixMatrix,AddOrSubtract)1153 TEST_P(BinaryArithMatrixMatrix, AddOrSubtract) {
1154   auto& param = GetParam();
1155 
1156   const ast::Expression* lhs = mat3x4<f32>();
1157   const ast::Expression* rhs = mat3x4<f32>();
1158 
1159   auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
1160 
1161   WrapInFunction(expr);
1162 
1163   spirv::Builder& b = Build();
1164   ASSERT_TRUE(b.Build()) << b.error();
1165 
1166   EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
1167 OpMemoryModel Logical GLSL450
1168 OpEntryPoint GLCompute %3 "test_function"
1169 OpExecutionMode %3 LocalSize 1 1 1
1170 OpName %3 "test_function"
1171 %2 = OpTypeVoid
1172 %1 = OpTypeFunction %2
1173 %7 = OpTypeFloat 32
1174 %6 = OpTypeVector %7 4
1175 %5 = OpTypeMatrix %6 3
1176 %8 = OpConstantNull %5
1177 %3 = OpFunction %2 None %1
1178 %4 = OpLabel
1179 %10 = OpCompositeExtract %6 %8 0
1180 %11 = OpCompositeExtract %6 %8 0
1181 %12 = )" + param.name + R"( %6 %10 %11
1182 %13 = OpCompositeExtract %6 %8 1
1183 %14 = OpCompositeExtract %6 %8 1
1184 %15 = )" + param.name + R"( %6 %13 %14
1185 %16 = OpCompositeExtract %6 %8 2
1186 %17 = OpCompositeExtract %6 %8 2
1187 %18 = )" + param.name + R"( %6 %16 %17
1188 %19 = OpCompositeConstruct %5 %12 %15 %18
1189 OpReturn
1190 OpFunctionEnd
1191 )");
1192 
1193   Validate(b);
1194 }
1195 INSTANTIATE_TEST_SUITE_P(  //
1196     BuilderTest,
1197     BinaryArithMatrixMatrix,
1198     testing::Values(Param{ast::BinaryOp::kAdd, "OpFAdd"},
1199                     Param{ast::BinaryOp::kSubtract, "OpFSub"}));
1200 
1201 using BinaryArithMatrixMatrixMultiply = TestParamHelper<Param>;
TEST_P(BinaryArithMatrixMatrixMultiply,Multiply)1202 TEST_P(BinaryArithMatrixMatrixMultiply, Multiply) {
1203   auto& param = GetParam();
1204 
1205   const ast::Expression* lhs = mat3x4<f32>();
1206   const ast::Expression* rhs = mat4x3<f32>();
1207 
1208   auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
1209 
1210   WrapInFunction(expr);
1211 
1212   spirv::Builder& b = Build();
1213   ASSERT_TRUE(b.Build()) << b.error();
1214 
1215   EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
1216 OpMemoryModel Logical GLSL450
1217 OpEntryPoint GLCompute %3 "test_function"
1218 OpExecutionMode %3 LocalSize 1 1 1
1219 OpName %3 "test_function"
1220 %2 = OpTypeVoid
1221 %1 = OpTypeFunction %2
1222 %7 = OpTypeFloat 32
1223 %6 = OpTypeVector %7 4
1224 %5 = OpTypeMatrix %6 3
1225 %8 = OpConstantNull %5
1226 %10 = OpTypeVector %7 3
1227 %9 = OpTypeMatrix %10 4
1228 %11 = OpConstantNull %9
1229 %13 = OpTypeMatrix %6 4
1230 %3 = OpFunction %2 None %1
1231 %4 = OpLabel
1232 %12 = OpMatrixTimesMatrix %13 %8 %11
1233 OpReturn
1234 OpFunctionEnd
1235 )");
1236 
1237   Validate(b);
1238 }
1239 INSTANTIATE_TEST_SUITE_P(  //
1240     BuilderTest,
1241     BinaryArithMatrixMatrixMultiply,
1242     testing::Values(Param{ast::BinaryOp::kMultiply, "OpFMul"}));
1243 
1244 }  // namespace BinaryArithMatrixMatrix
1245 
1246 }  // namespace
1247 }  // namespace spirv
1248 }  // namespace writer
1249 }  // namespace tint
1250