1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/writer/spirv/spv_dump.h"
16 #include "src/writer/spirv/test_helper.h"
17
18 namespace tint {
19 namespace writer {
20 namespace spirv {
21 namespace {
22
23 using BuilderTest = TestHelper;
24
25 struct BinaryData {
26 ast::BinaryOp op;
27 std::string name;
28 };
operator <<(std::ostream & out,BinaryData data)29 inline std::ostream& operator<<(std::ostream& out, BinaryData data) {
30 out << data.op;
31 return out;
32 }
33
34 using BinaryArithSignedIntegerTest = TestParamHelper<BinaryData>;
TEST_P(BinaryArithSignedIntegerTest,Scalar)35 TEST_P(BinaryArithSignedIntegerTest, Scalar) {
36 auto param = GetParam();
37
38 auto* lhs = Expr(3);
39 auto* rhs = Expr(4);
40
41 auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
42
43 WrapInFunction(expr);
44
45 spirv::Builder& b = Build();
46
47 b.push_function(Function{});
48
49 EXPECT_EQ(b.GenerateBinaryExpression(expr), 4u) << b.error();
50 EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1
51 %2 = OpConstant %1 3
52 %3 = OpConstant %1 4
53 )");
54 EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
55 "%4 = " + param.name + " %1 %2 %3\n");
56 }
57
// vec3<i32>(1, 1, 1) on both sides: the identical constructors share one
// OpConstantComposite (%4), and the result is typed as the vector (%1).
TEST_P(BinaryArithSignedIntegerTest, Vector) {
  const auto& data = GetParam();

  // The bitwise ops are treated as illegal for this type and are skipped.
  const bool is_bitwise = data.op == ast::BinaryOp::kAnd ||
                          data.op == ast::BinaryOp::kOr ||
                          data.op == ast::BinaryOp::kXor;
  if (is_bitwise) {
    return;
  }

  auto* left = vec3<i32>(1, 1, 1);
  auto* right = vec3<i32>(1, 1, 1);
  auto* binary = create<ast::BinaryExpression>(data.op, left, right);
  WrapInFunction(binary);

  auto& builder = Build();
  builder.push_function(Function{});

  EXPECT_EQ(builder.GenerateBinaryExpression(binary), 5u) << builder.error();
  EXPECT_EQ(DumpInstructions(builder.types()),
            "%2 = OpTypeInt 32 1\n"
            "%1 = OpTypeVector %2 3\n"
            "%3 = OpConstant %2 1\n"
            "%4 = OpConstantComposite %1 %3 %3 %3\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].instructions()),
            "%5 = " + data.name + " %1 %4 %4\n");
}
// Both operands read the same function-scope i32 variable, so the binary
// instruction consumes two separate OpLoad results (%5 and %6).
TEST_P(BinaryArithSignedIntegerTest, Scalar_Loads) {
  const auto& data = GetParam();

  auto* var = Var("param", ty.i32());
  auto* binary =
      create<ast::BinaryExpression>(data.op, Expr("param"), Expr("param"));
  WrapInFunction(var, binary);

  auto& builder = Build();
  builder.push_function(Function{});

  EXPECT_TRUE(builder.GenerateFunctionVariable(var)) << builder.error();
  EXPECT_EQ(builder.GenerateBinaryExpression(binary), 7u) << builder.error();
  ASSERT_FALSE(builder.has_error()) << builder.error();

  EXPECT_EQ(DumpInstructions(builder.types()),
            "%3 = OpTypeInt 32 1\n"
            "%2 = OpTypePointer Function %3\n"
            "%4 = OpConstantNull %3\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].variables()),
            "%1 = OpVariable %2 Function %4\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].instructions()),
            "%5 = OpLoad %3 %1\n"
            "%6 = OpLoad %3 %1\n"
            "%7 = " +
                data.name + " %3 %5 %6\n");
}
117 INSTANTIATE_TEST_SUITE_P(
118 BuilderTest,
119 BinaryArithSignedIntegerTest,
120 // NOTE: No left and right shift as they require u32 for rhs operand
121 testing::Values(BinaryData{ast::BinaryOp::kAdd, "OpIAdd"},
122 BinaryData{ast::BinaryOp::kAnd, "OpBitwiseAnd"},
123 BinaryData{ast::BinaryOp::kDivide, "OpSDiv"},
124 BinaryData{ast::BinaryOp::kModulo, "OpSMod"},
125 BinaryData{ast::BinaryOp::kMultiply, "OpIMul"},
126 BinaryData{ast::BinaryOp::kOr, "OpBitwiseOr"},
127 BinaryData{ast::BinaryOp::kSubtract, "OpISub"},
128 BinaryData{ast::BinaryOp::kXor, "OpBitwiseXor"}));
129
130 using BinaryArithUnsignedIntegerTest = TestParamHelper<BinaryData>;
TEST_P(BinaryArithUnsignedIntegerTest,Scalar)131 TEST_P(BinaryArithUnsignedIntegerTest, Scalar) {
132 auto param = GetParam();
133
134 auto* lhs = Expr(3u);
135 auto* rhs = Expr(4u);
136
137 auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
138
139 WrapInFunction(expr);
140
141 spirv::Builder& b = Build();
142
143 b.push_function(Function{});
144
145 EXPECT_EQ(b.GenerateBinaryExpression(expr), 4u) << b.error();
146 EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0
147 %2 = OpConstant %1 3
148 %3 = OpConstant %1 4
149 )");
150 EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
151 "%4 = " + param.name + " %1 %2 %3\n");
152 }
// vec3<u32>(1u, 1u, 1u) on both sides: the identical constructors share one
// OpConstantComposite (%4), and the result is typed as the vector (%1).
TEST_P(BinaryArithUnsignedIntegerTest, Vector) {
  const auto& data = GetParam();

  // The bitwise ops are treated as illegal for this type and are skipped.
  const bool is_bitwise = data.op == ast::BinaryOp::kAnd ||
                          data.op == ast::BinaryOp::kOr ||
                          data.op == ast::BinaryOp::kXor;
  if (is_bitwise) {
    return;
  }

  auto* left = vec3<u32>(1u, 1u, 1u);
  auto* right = vec3<u32>(1u, 1u, 1u);
  auto* binary = create<ast::BinaryExpression>(data.op, left, right);
  WrapInFunction(binary);

  auto& builder = Build();
  builder.push_function(Function{});

  EXPECT_EQ(builder.GenerateBinaryExpression(binary), 5u) << builder.error();
  EXPECT_EQ(DumpInstructions(builder.types()),
            "%2 = OpTypeInt 32 0\n"
            "%1 = OpTypeVector %2 3\n"
            "%3 = OpConstant %2 1\n"
            "%4 = OpConstantComposite %1 %3 %3 %3\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].instructions()),
            "%5 = " + data.name + " %1 %4 %4\n");
}
182 INSTANTIATE_TEST_SUITE_P(
183 BuilderTest,
184 BinaryArithUnsignedIntegerTest,
185 testing::Values(BinaryData{ast::BinaryOp::kAdd, "OpIAdd"},
186 BinaryData{ast::BinaryOp::kAnd, "OpBitwiseAnd"},
187 BinaryData{ast::BinaryOp::kDivide, "OpUDiv"},
188 BinaryData{ast::BinaryOp::kModulo, "OpUMod"},
189 BinaryData{ast::BinaryOp::kMultiply, "OpIMul"},
190 BinaryData{ast::BinaryOp::kOr, "OpBitwiseOr"},
191 BinaryData{ast::BinaryOp::kShiftLeft, "OpShiftLeftLogical"},
192 BinaryData{ast::BinaryOp::kShiftRight,
193 "OpShiftRightLogical"},
194 BinaryData{ast::BinaryOp::kSubtract, "OpISub"},
195 BinaryData{ast::BinaryOp::kXor, "OpBitwiseXor"}));
196
197 using BinaryArithFloatTest = TestParamHelper<BinaryData>;
TEST_P(BinaryArithFloatTest,Scalar)198 TEST_P(BinaryArithFloatTest, Scalar) {
199 auto param = GetParam();
200
201 auto* lhs = Expr(3.2f);
202 auto* rhs = Expr(4.5f);
203
204 auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
205
206 WrapInFunction(expr);
207
208 spirv::Builder& b = Build();
209
210 b.push_function(Function{});
211
212 EXPECT_EQ(b.GenerateBinaryExpression(expr), 4u) << b.error();
213 EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
214 %2 = OpConstant %1 3.20000005
215 %3 = OpConstant %1 4.5
216 )");
217 EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
218 "%4 = " + param.name + " %1 %2 %3\n");
219 }
220
// vec3<f32>(1, 1, 1) on both sides: the identical constructors share one
// OpConstantComposite (%4), and the result is typed as the vector (%1).
TEST_P(BinaryArithFloatTest, Vector) {
  const auto& data = GetParam();

  auto* left = vec3<f32>(1.f, 1.f, 1.f);
  auto* right = vec3<f32>(1.f, 1.f, 1.f);
  auto* binary = create<ast::BinaryExpression>(data.op, left, right);
  WrapInFunction(binary);

  auto& builder = Build();
  builder.push_function(Function{});

  EXPECT_EQ(builder.GenerateBinaryExpression(binary), 5u) << builder.error();
  EXPECT_EQ(DumpInstructions(builder.types()),
            "%2 = OpTypeFloat 32\n"
            "%1 = OpTypeVector %2 3\n"
            "%3 = OpConstant %2 1\n"
            "%4 = OpConstantComposite %1 %3 %3 %3\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].instructions()),
            "%5 = " + data.name + " %1 %4 %4\n");
}
244 INSTANTIATE_TEST_SUITE_P(
245 BuilderTest,
246 BinaryArithFloatTest,
247 testing::Values(BinaryData{ast::BinaryOp::kAdd, "OpFAdd"},
248 BinaryData{ast::BinaryOp::kDivide, "OpFDiv"},
249 BinaryData{ast::BinaryOp::kModulo, "OpFRem"},
250 BinaryData{ast::BinaryOp::kMultiply, "OpFMul"},
251 BinaryData{ast::BinaryOp::kSubtract, "OpFSub"}));
252
253 using BinaryOperatorBoolTest = TestParamHelper<BinaryData>;
TEST_P(BinaryOperatorBoolTest,Scalar)254 TEST_P(BinaryOperatorBoolTest, Scalar) {
255 auto param = GetParam();
256
257 auto* lhs = Expr(true);
258 auto* rhs = Expr(false);
259
260 auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
261
262 WrapInFunction(expr);
263
264 spirv::Builder& b = Build();
265
266 b.push_function(Function{});
267
268 EXPECT_EQ(b.GenerateBinaryExpression(expr), 4u) << b.error();
269 EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeBool
270 %2 = OpConstantTrue %1
271 %3 = OpConstantFalse %1
272 )");
273 EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
274 "%4 = " + param.name + " %1 %2 %3\n");
275 }
276
// Two different vec3<bool> constants: each becomes its own
// OpConstantComposite (%5 and %6), and the result is typed as the vector.
TEST_P(BinaryOperatorBoolTest, Vector) {
  const auto& data = GetParam();

  auto* left = vec3<bool>(false, true, false);
  auto* right = vec3<bool>(true, false, true);
  auto* binary = create<ast::BinaryExpression>(data.op, left, right);
  WrapInFunction(binary);

  auto& builder = Build();
  builder.push_function(Function{});

  EXPECT_EQ(builder.GenerateBinaryExpression(binary), 7u) << builder.error();
  EXPECT_EQ(DumpInstructions(builder.types()),
            "%2 = OpTypeBool\n"
            "%1 = OpTypeVector %2 3\n"
            "%3 = OpConstantFalse %2\n"
            "%4 = OpConstantTrue %2\n"
            "%5 = OpConstantComposite %1 %3 %4 %3\n"
            "%6 = OpConstantComposite %1 %4 %3 %4\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].instructions()),
            "%7 = " + data.name + " %1 %5 %6\n");
}
302 INSTANTIATE_TEST_SUITE_P(
303 BuilderTest,
304 BinaryOperatorBoolTest,
305 testing::Values(BinaryData{ast::BinaryOp::kEqual, "OpLogicalEqual"},
306 BinaryData{ast::BinaryOp::kNotEqual, "OpLogicalNotEqual"},
307 BinaryData{ast::BinaryOp::kAnd, "OpLogicalAnd"},
308 BinaryData{ast::BinaryOp::kOr, "OpLogicalOr"}));
309
310 using BinaryCompareUnsignedIntegerTest = TestParamHelper<BinaryData>;
TEST_P(BinaryCompareUnsignedIntegerTest,Scalar)311 TEST_P(BinaryCompareUnsignedIntegerTest, Scalar) {
312 auto param = GetParam();
313
314 auto* lhs = Expr(3u);
315 auto* rhs = Expr(4u);
316
317 auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
318
319 WrapInFunction(expr);
320
321 spirv::Builder& b = Build();
322
323 b.push_function(Function{});
324
325 EXPECT_EQ(b.GenerateBinaryExpression(expr), 4u) << b.error();
326 EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0
327 %2 = OpConstant %1 3
328 %3 = OpConstant %1 4
329 %5 = OpTypeBool
330 )");
331 EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
332 "%4 = " + param.name + " %5 %2 %3\n");
333 }
334
// Comparing vec3<u32> values: the result type is vec3<bool> (%6).
TEST_P(BinaryCompareUnsignedIntegerTest, Vector) {
  const auto& data = GetParam();

  auto* left = vec3<u32>(1u, 1u, 1u);
  auto* right = vec3<u32>(1u, 1u, 1u);
  auto* binary = create<ast::BinaryExpression>(data.op, left, right);
  WrapInFunction(binary);

  auto& builder = Build();
  builder.push_function(Function{});

  EXPECT_EQ(builder.GenerateBinaryExpression(binary), 5u) << builder.error();
  EXPECT_EQ(DumpInstructions(builder.types()),
            "%2 = OpTypeInt 32 0\n"
            "%1 = OpTypeVector %2 3\n"
            "%3 = OpConstant %2 1\n"
            "%4 = OpConstantComposite %1 %3 %3 %3\n"
            "%7 = OpTypeBool\n"
            "%6 = OpTypeVector %7 3\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].instructions()),
            "%5 = " + data.name + " %6 %4 %4\n");
}
360 INSTANTIATE_TEST_SUITE_P(
361 BuilderTest,
362 BinaryCompareUnsignedIntegerTest,
363 testing::Values(
364 BinaryData{ast::BinaryOp::kEqual, "OpIEqual"},
365 BinaryData{ast::BinaryOp::kGreaterThan, "OpUGreaterThan"},
366 BinaryData{ast::BinaryOp::kGreaterThanEqual, "OpUGreaterThanEqual"},
367 BinaryData{ast::BinaryOp::kLessThan, "OpULessThan"},
368 BinaryData{ast::BinaryOp::kLessThanEqual, "OpULessThanEqual"},
369 BinaryData{ast::BinaryOp::kNotEqual, "OpINotEqual"}));
370
371 using BinaryCompareSignedIntegerTest = TestParamHelper<BinaryData>;
TEST_P(BinaryCompareSignedIntegerTest,Scalar)372 TEST_P(BinaryCompareSignedIntegerTest, Scalar) {
373 auto param = GetParam();
374
375 auto* lhs = Expr(3);
376 auto* rhs = Expr(4);
377
378 auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
379
380 WrapInFunction(expr);
381
382 spirv::Builder& b = Build();
383
384 b.push_function(Function{});
385
386 EXPECT_EQ(b.GenerateBinaryExpression(expr), 4u) << b.error();
387 EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1
388 %2 = OpConstant %1 3
389 %3 = OpConstant %1 4
390 %5 = OpTypeBool
391 )");
392 EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
393 "%4 = " + param.name + " %5 %2 %3\n");
394 }
395
// Comparing vec3<i32> values: the result type is vec3<bool> (%6).
TEST_P(BinaryCompareSignedIntegerTest, Vector) {
  const auto& data = GetParam();

  auto* left = vec3<i32>(1, 1, 1);
  auto* right = vec3<i32>(1, 1, 1);
  auto* binary = create<ast::BinaryExpression>(data.op, left, right);
  WrapInFunction(binary);

  auto& builder = Build();
  builder.push_function(Function{});

  EXPECT_EQ(builder.GenerateBinaryExpression(binary), 5u) << builder.error();
  EXPECT_EQ(DumpInstructions(builder.types()),
            "%2 = OpTypeInt 32 1\n"
            "%1 = OpTypeVector %2 3\n"
            "%3 = OpConstant %2 1\n"
            "%4 = OpConstantComposite %1 %3 %3 %3\n"
            "%7 = OpTypeBool\n"
            "%6 = OpTypeVector %7 3\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].instructions()),
            "%5 = " + data.name + " %6 %4 %4\n");
}
421 INSTANTIATE_TEST_SUITE_P(
422 BuilderTest,
423 BinaryCompareSignedIntegerTest,
424 testing::Values(
425 BinaryData{ast::BinaryOp::kEqual, "OpIEqual"},
426 BinaryData{ast::BinaryOp::kGreaterThan, "OpSGreaterThan"},
427 BinaryData{ast::BinaryOp::kGreaterThanEqual, "OpSGreaterThanEqual"},
428 BinaryData{ast::BinaryOp::kLessThan, "OpSLessThan"},
429 BinaryData{ast::BinaryOp::kLessThanEqual, "OpSLessThanEqual"},
430 BinaryData{ast::BinaryOp::kNotEqual, "OpINotEqual"}));
431
432 using BinaryCompareFloatTest = TestParamHelper<BinaryData>;
TEST_P(BinaryCompareFloatTest,Scalar)433 TEST_P(BinaryCompareFloatTest, Scalar) {
434 auto param = GetParam();
435
436 auto* lhs = Expr(3.2f);
437 auto* rhs = Expr(4.5f);
438
439 auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
440
441 WrapInFunction(expr);
442
443 spirv::Builder& b = Build();
444
445 b.push_function(Function{});
446
447 EXPECT_EQ(b.GenerateBinaryExpression(expr), 4u) << b.error();
448 EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
449 %2 = OpConstant %1 3.20000005
450 %3 = OpConstant %1 4.5
451 %5 = OpTypeBool
452 )");
453 EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
454 "%4 = " + param.name + " %5 %2 %3\n");
455 }
456
// Comparing vec3<f32> values: the result type is vec3<bool> (%6).
TEST_P(BinaryCompareFloatTest, Vector) {
  const auto& data = GetParam();

  auto* left = vec3<f32>(1.f, 1.f, 1.f);
  auto* right = vec3<f32>(1.f, 1.f, 1.f);
  auto* binary = create<ast::BinaryExpression>(data.op, left, right);
  WrapInFunction(binary);

  auto& builder = Build();
  builder.push_function(Function{});

  EXPECT_EQ(builder.GenerateBinaryExpression(binary), 5u) << builder.error();
  EXPECT_EQ(DumpInstructions(builder.types()),
            "%2 = OpTypeFloat 32\n"
            "%1 = OpTypeVector %2 3\n"
            "%3 = OpConstant %2 1\n"
            "%4 = OpConstantComposite %1 %3 %3 %3\n"
            "%7 = OpTypeBool\n"
            "%6 = OpTypeVector %7 3\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].instructions()),
            "%5 = " + data.name + " %6 %4 %4\n");
}
482 INSTANTIATE_TEST_SUITE_P(
483 BuilderTest,
484 BinaryCompareFloatTest,
485 testing::Values(
486 BinaryData{ast::BinaryOp::kEqual, "OpFOrdEqual"},
487 BinaryData{ast::BinaryOp::kGreaterThan, "OpFOrdGreaterThan"},
488 BinaryData{ast::BinaryOp::kGreaterThanEqual, "OpFOrdGreaterThanEqual"},
489 BinaryData{ast::BinaryOp::kLessThan, "OpFOrdLessThan"},
490 BinaryData{ast::BinaryOp::kLessThanEqual, "OpFOrdLessThanEqual"},
491 BinaryData{ast::BinaryOp::kNotEqual, "OpFOrdNotEqual"}));
492
// vec3<f32> * f32 lowers to OpVectorTimesScalar with the vector first.
TEST_F(BuilderTest, Binary_Multiply_VectorScalar) {
  auto* vector_operand = vec3<f32>(1.f, 1.f, 1.f);
  auto* scalar_operand = Expr(1.f);
  auto* binary = create<ast::BinaryExpression>(
      ast::BinaryOp::kMultiply, vector_operand, scalar_operand);
  WrapInFunction(binary);

  auto& builder = Build();
  builder.push_function(Function{});

  EXPECT_EQ(builder.GenerateBinaryExpression(binary), 5u) << builder.error();
  EXPECT_EQ(DumpInstructions(builder.types()),
            "%2 = OpTypeFloat 32\n"
            "%1 = OpTypeVector %2 3\n"
            "%3 = OpConstant %2 1\n"
            "%4 = OpConstantComposite %1 %3 %3 %3\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].instructions()),
            "%5 = OpVectorTimesScalar %1 %4 %3\n");
}
516
// f32 * vec3<f32> also lowers to OpVectorTimesScalar, with the operands
// swapped so that the vector (%4) comes first.
TEST_F(BuilderTest, Binary_Multiply_ScalarVector) {
  auto* scalar_operand = Expr(1.f);
  auto* vector_operand = vec3<f32>(1.f, 1.f, 1.f);
  auto* binary = create<ast::BinaryExpression>(
      ast::BinaryOp::kMultiply, scalar_operand, vector_operand);
  WrapInFunction(binary);

  auto& builder = Build();
  builder.push_function(Function{});

  EXPECT_EQ(builder.GenerateBinaryExpression(binary), 5u) << builder.error();
  EXPECT_EQ(DumpInstructions(builder.types()),
            "%1 = OpTypeFloat 32\n"
            "%2 = OpConstant %1 1\n"
            "%3 = OpTypeVector %1 3\n"
            "%4 = OpConstantComposite %3 %2 %2 %2\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].instructions()),
            "%5 = OpVectorTimesScalar %3 %4 %2\n");
}
540
// mat3x3<f32> * f32: the matrix variable is loaded, then passed to
// OpMatrixTimesScalar with the matrix first.
TEST_F(BuilderTest, Binary_Multiply_MatrixScalar) {
  auto* mat_var = Var("mat", ty.mat3x3<f32>());
  auto* binary = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply,
                                               Expr("mat"), Expr(1.f));
  WrapInFunction(mat_var, binary);

  auto& builder = Build();
  builder.push_function(Function{});
  ASSERT_TRUE(builder.GenerateGlobalVariable(mat_var)) << builder.error();

  EXPECT_EQ(builder.GenerateBinaryExpression(binary), 8u) << builder.error();
  EXPECT_EQ(DumpInstructions(builder.types()),
            "%5 = OpTypeFloat 32\n"
            "%4 = OpTypeVector %5 3\n"
            "%3 = OpTypeMatrix %4 3\n"
            "%2 = OpTypePointer Function %3\n"
            "%1 = OpVariable %2 Function\n"
            "%7 = OpConstant %5 1\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].instructions()),
            "%6 = OpLoad %3 %1\n"
            "%8 = OpMatrixTimesScalar %3 %6 %7\n");
}
568
// f32 * mat3x3<f32>: still OpMatrixTimesScalar, with the loaded matrix (%7)
// placed before the scalar (%6).
TEST_F(BuilderTest, Binary_Multiply_ScalarMatrix) {
  auto* mat_var = Var("mat", ty.mat3x3<f32>());
  auto* binary = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply,
                                               Expr(1.f), Expr("mat"));
  WrapInFunction(mat_var, binary);

  auto& builder = Build();
  builder.push_function(Function{});
  ASSERT_TRUE(builder.GenerateGlobalVariable(mat_var)) << builder.error();

  EXPECT_EQ(builder.GenerateBinaryExpression(binary), 8u) << builder.error();
  EXPECT_EQ(DumpInstructions(builder.types()),
            "%5 = OpTypeFloat 32\n"
            "%4 = OpTypeVector %5 3\n"
            "%3 = OpTypeMatrix %4 3\n"
            "%2 = OpTypePointer Function %3\n"
            "%1 = OpVariable %2 Function\n"
            "%6 = OpConstant %5 1\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].instructions()),
            "%7 = OpLoad %3 %1\n"
            "%8 = OpMatrixTimesScalar %3 %7 %6\n");
}
596
// mat3x3<f32> * vec3<f32> lowers to OpMatrixTimesVector, and the result is
// typed as the column vector (%4).
TEST_F(BuilderTest, Binary_Multiply_MatrixVector) {
  auto* mat_var = Var("mat", ty.mat3x3<f32>());
  auto* vector_operand = vec3<f32>(1.f, 1.f, 1.f);
  auto* binary = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply,
                                               Expr("mat"), vector_operand);
  WrapInFunction(mat_var, binary);

  auto& builder = Build();
  builder.push_function(Function{});
  ASSERT_TRUE(builder.GenerateGlobalVariable(mat_var)) << builder.error();

  EXPECT_EQ(builder.GenerateBinaryExpression(binary), 9u) << builder.error();
  EXPECT_EQ(DumpInstructions(builder.types()),
            "%5 = OpTypeFloat 32\n"
            "%4 = OpTypeVector %5 3\n"
            "%3 = OpTypeMatrix %4 3\n"
            "%2 = OpTypePointer Function %3\n"
            "%1 = OpVariable %2 Function\n"
            "%7 = OpConstant %5 1\n"
            "%8 = OpConstantComposite %4 %7 %7 %7\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].instructions()),
            "%6 = OpLoad %3 %1\n"
            "%9 = OpMatrixTimesVector %4 %6 %8\n");
}
626
// vec3<f32> * mat3x3<f32> lowers to OpVectorTimesMatrix with the vector
// first, and the result is typed as the vector (%4).
TEST_F(BuilderTest, Binary_Multiply_VectorMatrix) {
  auto* mat_var = Var("mat", ty.mat3x3<f32>());
  auto* vector_operand = vec3<f32>(1.f, 1.f, 1.f);
  auto* binary = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply,
                                               vector_operand, Expr("mat"));
  WrapInFunction(mat_var, binary);

  auto& builder = Build();
  builder.push_function(Function{});
  ASSERT_TRUE(builder.GenerateGlobalVariable(mat_var)) << builder.error();

  EXPECT_EQ(builder.GenerateBinaryExpression(binary), 9u) << builder.error();
  EXPECT_EQ(DumpInstructions(builder.types()),
            "%5 = OpTypeFloat 32\n"
            "%4 = OpTypeVector %5 3\n"
            "%3 = OpTypeMatrix %4 3\n"
            "%2 = OpTypePointer Function %3\n"
            "%1 = OpVariable %2 Function\n"
            "%6 = OpConstant %5 1\n"
            "%7 = OpConstantComposite %4 %6 %6 %6\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].instructions()),
            "%8 = OpLoad %3 %1\n"
            "%9 = OpVectorTimesMatrix %4 %7 %8\n");
}
656
// mat3x3<f32> * mat3x3<f32> with both sides loading the same variable:
// two OpLoads feed OpMatrixTimesMatrix.
TEST_F(BuilderTest, Binary_Multiply_MatrixMatrix) {
  auto* mat_var = Var("mat", ty.mat3x3<f32>());
  auto* binary = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply,
                                               Expr("mat"), Expr("mat"));
  WrapInFunction(mat_var, binary);

  auto& builder = Build();
  builder.push_function(Function{});
  ASSERT_TRUE(builder.GenerateGlobalVariable(mat_var)) << builder.error();

  EXPECT_EQ(builder.GenerateBinaryExpression(binary), 8u) << builder.error();
  EXPECT_EQ(DumpInstructions(builder.types()),
            "%5 = OpTypeFloat 32\n"
            "%4 = OpTypeVector %5 3\n"
            "%3 = OpTypeMatrix %4 3\n"
            "%2 = OpTypePointer Function %3\n"
            "%1 = OpVariable %2 Function\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].instructions()),
            "%6 = OpLoad %3 %1\n"
            "%7 = OpLoad %3 %1\n"
            "%8 = OpMatrixTimesMatrix %3 %6 %7\n");
}
684
// (1 == 2) && (3 == 4): the right-hand comparison is emitted only in the
// conditional block %8, and the result is merged with an OpPhi in %7.
TEST_F(BuilderTest, Binary_LogicalAnd) {
  auto* lhs_cmp =
      create<ast::BinaryExpression>(ast::BinaryOp::kEqual, Expr(1), Expr(2));
  auto* rhs_cmp =
      create<ast::BinaryExpression>(ast::BinaryOp::kEqual, Expr(3), Expr(4));
  auto* binary = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                               lhs_cmp, rhs_cmp);
  WrapInFunction(binary);

  auto& builder = Build();
  builder.push_function(Function{});
  builder.GenerateLabel(builder.next_id());

  EXPECT_EQ(builder.GenerateBinaryExpression(binary), 12u) << builder.error();
  EXPECT_EQ(DumpInstructions(builder.types()),
            "%2 = OpTypeInt 32 1\n"
            "%3 = OpConstant %2 1\n"
            "%4 = OpConstant %2 2\n"
            "%6 = OpTypeBool\n"
            "%9 = OpConstant %2 3\n"
            "%10 = OpConstant %2 4\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].instructions()),
            "%1 = OpLabel\n"
            "%5 = OpIEqual %6 %3 %4\n"
            "OpSelectionMerge %7 None\n"
            "OpBranchConditional %5 %8 %7\n"
            "%8 = OpLabel\n"
            "%11 = OpIEqual %6 %9 %10\n"
            "OpBranch %7\n"
            "%7 = OpLabel\n"
            "%12 = OpPhi %6 %5 %1 %11 %8\n");
}
723
// a && b over two private bool globals: `b` is loaded only in the
// conditional block %10, and the result is merged with an OpPhi in %9.
TEST_F(BuilderTest, Binary_LogicalAnd_WithLoads) {
  auto* global_a =
      Global("a", ty.bool_(), ast::StorageClass::kPrivate, Expr(true));
  auto* global_b =
      Global("b", ty.bool_(), ast::StorageClass::kPrivate, Expr(false));

  auto* binary = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                               Expr("a"), Expr("b"));
  WrapInFunction(binary);

  auto& builder = Build();
  builder.push_function(Function{});
  builder.GenerateLabel(builder.next_id());

  ASSERT_TRUE(builder.GenerateGlobalVariable(global_a)) << builder.error();
  ASSERT_TRUE(builder.GenerateGlobalVariable(global_b)) << builder.error();

  EXPECT_EQ(builder.GenerateBinaryExpression(binary), 12u) << builder.error();
  EXPECT_EQ(DumpInstructions(builder.types()),
            "%2 = OpTypeBool\n"
            "%3 = OpConstantTrue %2\n"
            "%5 = OpTypePointer Private %2\n"
            "%4 = OpVariable %5 Private %3\n"
            "%6 = OpConstantFalse %2\n"
            "%7 = OpVariable %5 Private %6\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].instructions()),
            "%1 = OpLabel\n"
            "%8 = OpLoad %2 %4\n"
            "OpSelectionMerge %9 None\n"
            "OpBranchConditional %8 %10 %9\n"
            "%10 = OpLabel\n"
            "%11 = OpLoad %2 %7\n"
            "OpBranch %9\n"
            "%9 = OpLabel\n"
            "%12 = OpPhi %2 %8 %1 %11 %10\n");
}
763
// Regression test for crbug.com/tint/355, for an expression shaped like
//   a || (b && c)
// The inner && is lowered inside the false branch of the outer ||. The
// outer OpPhi must name the inner merge block (%6) as the predecessor, not
// the block where the inner expression started (%5).
TEST_F(BuilderTest, Binary_logicalOr_Nested_LogicalAnd) {
  auto* inner = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                              Expr(true), Expr(false));
  auto* outer = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                              Expr(true), inner);
  WrapInFunction(outer);

  auto& builder = Build();
  builder.push_function(Function{});
  builder.GenerateLabel(builder.next_id());

  EXPECT_EQ(builder.GenerateBinaryExpression(outer), 10u) << builder.error();
  EXPECT_EQ(DumpInstructions(builder.types()),
            "%2 = OpTypeBool\n"
            "%3 = OpConstantTrue %2\n"
            "%8 = OpConstantFalse %2\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].instructions()),
            "%1 = OpLabel\n"
            "OpSelectionMerge %4 None\n"
            "OpBranchConditional %3 %4 %5\n"
            "%5 = OpLabel\n"
            "OpSelectionMerge %6 None\n"
            "OpBranchConditional %3 %7 %6\n"
            "%7 = OpLabel\n"
            "OpBranch %6\n"
            "%6 = OpLabel\n"
            "%9 = OpPhi %2 %3 %5 %8 %7\n"
            "OpBranch %4\n"
            "%4 = OpLabel\n"
            "%10 = OpPhi %2 %3 %1 %9 %6\n");
}
803
// Regression test for crbug.com/tint/355, for an expression shaped like
//   a && (b || c)
// The inner || is lowered inside the true branch of the outer &&. The outer
// OpPhi must name the inner merge block (%6) as the predecessor.
TEST_F(BuilderTest, Binary_logicalAnd_Nested_LogicalOr) {
  auto* inner = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                              Expr(true), Expr(false));
  auto* outer = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                              Expr(true), inner);
  WrapInFunction(outer);

  auto& builder = Build();
  builder.push_function(Function{});
  builder.GenerateLabel(builder.next_id());

  EXPECT_EQ(builder.GenerateBinaryExpression(outer), 10u) << builder.error();
  EXPECT_EQ(DumpInstructions(builder.types()),
            "%2 = OpTypeBool\n"
            "%3 = OpConstantTrue %2\n"
            "%8 = OpConstantFalse %2\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].instructions()),
            "%1 = OpLabel\n"
            "OpSelectionMerge %4 None\n"
            "OpBranchConditional %3 %5 %4\n"
            "%5 = OpLabel\n"
            "OpSelectionMerge %6 None\n"
            "OpBranchConditional %3 %6 %7\n"
            "%7 = OpLabel\n"
            "OpBranch %6\n"
            "%6 = OpLabel\n"
            "%9 = OpPhi %2 %3 %5 %8 %7\n"
            "OpBranch %4\n"
            "%4 = OpLabel\n"
            "%10 = OpPhi %2 %3 %1 %9 %6\n");
}
843
// (1 == 2) || (3 == 4): the right-hand comparison runs in block %8 only when
// the left one is false, and the result is merged with an OpPhi in %7.
TEST_F(BuilderTest, Binary_LogicalOr) {
  auto* lhs_cmp =
      create<ast::BinaryExpression>(ast::BinaryOp::kEqual, Expr(1), Expr(2));
  auto* rhs_cmp =
      create<ast::BinaryExpression>(ast::BinaryOp::kEqual, Expr(3), Expr(4));
  auto* binary = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                               lhs_cmp, rhs_cmp);
  WrapInFunction(binary);

  auto& builder = Build();
  builder.push_function(Function{});
  builder.GenerateLabel(builder.next_id());

  EXPECT_EQ(builder.GenerateBinaryExpression(binary), 12u) << builder.error();
  EXPECT_EQ(DumpInstructions(builder.types()),
            "%2 = OpTypeInt 32 1\n"
            "%3 = OpConstant %2 1\n"
            "%4 = OpConstant %2 2\n"
            "%6 = OpTypeBool\n"
            "%9 = OpConstant %2 3\n"
            "%10 = OpConstant %2 4\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].instructions()),
            "%1 = OpLabel\n"
            "%5 = OpIEqual %6 %3 %4\n"
            "OpSelectionMerge %7 None\n"
            "OpBranchConditional %5 %7 %8\n"
            "%8 = OpLabel\n"
            "%11 = OpIEqual %6 %9 %10\n"
            "OpBranch %7\n"
            "%7 = OpLabel\n"
            "%12 = OpPhi %6 %5 %1 %11 %8\n");
}
882
// a || b over two private bool globals: `b` is loaded in block %10 only when
// `a` is false, and the result is merged with an OpPhi in %9.
TEST_F(BuilderTest, Binary_LogicalOr_WithLoads) {
  auto* global_a =
      Global("a", ty.bool_(), ast::StorageClass::kPrivate, Expr(true));
  auto* global_b =
      Global("b", ty.bool_(), ast::StorageClass::kPrivate, Expr(false));

  auto* binary = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                               Expr("a"), Expr("b"));
  WrapInFunction(binary);

  auto& builder = Build();
  builder.push_function(Function{});
  builder.GenerateLabel(builder.next_id());

  ASSERT_TRUE(builder.GenerateGlobalVariable(global_a)) << builder.error();
  ASSERT_TRUE(builder.GenerateGlobalVariable(global_b)) << builder.error();

  EXPECT_EQ(builder.GenerateBinaryExpression(binary), 12u) << builder.error();
  EXPECT_EQ(DumpInstructions(builder.types()),
            "%2 = OpTypeBool\n"
            "%3 = OpConstantTrue %2\n"
            "%5 = OpTypePointer Private %2\n"
            "%4 = OpVariable %5 Private %3\n"
            "%6 = OpConstantFalse %2\n"
            "%7 = OpVariable %5 Private %6\n");
  EXPECT_EQ(DumpInstructions(builder.functions()[0].instructions()),
            "%1 = OpLabel\n"
            "%8 = OpLoad %2 %4\n"
            "OpSelectionMerge %9 None\n"
            "OpBranchConditional %8 %9 %10\n"
            "%10 = OpLabel\n"
            "%11 = OpLoad %2 %7\n"
            "OpBranch %9\n"
            "%9 = OpLabel\n"
            "%12 = OpPhi %2 %8 %1 %11 %10\n");
}
922
923 namespace BinaryArithVectorScalar {
924
/// Element type used to build the scalar and vector operands in the
/// vector/scalar arithmetic tests.
enum class Type {
  f32,
  i32,
  u32,
};
MakeVectorExpr(ProgramBuilder * builder,Type type)926 static const ast::Expression* MakeVectorExpr(ProgramBuilder* builder,
927 Type type) {
928 switch (type) {
929 case Type::f32:
930 return builder->vec3<ProgramBuilder::f32>(1.f, 1.f, 1.f);
931 case Type::i32:
932 return builder->vec3<ProgramBuilder::i32>(1, 1, 1);
933 case Type::u32:
934 return builder->vec3<ProgramBuilder::u32>(1u, 1u, 1u);
935 }
936 return nullptr;
937 }
MakeScalarExpr(ProgramBuilder * builder,Type type)938 static const ast::Expression* MakeScalarExpr(ProgramBuilder* builder,
939 Type type) {
940 switch (type) {
941 case Type::f32:
942 return builder->Expr(1.f);
943 case Type::i32:
944 return builder->Expr(1);
945 case Type::u32:
946 return builder->Expr(1u);
947 }
948 return nullptr;
949 }
OpTypeDecl(Type type)950 static std::string OpTypeDecl(Type type) {
951 switch (type) {
952 case Type::f32:
953 return "OpTypeFloat 32";
954 case Type::i32:
955 return "OpTypeInt 32 1";
956 case Type::u32:
957 return "OpTypeInt 32 0";
958 }
959 return {};
960 }
961
962 struct Param {
963 Type type;
964 ast::BinaryOp op;
965 std::string name;
966 };
967
968 using BinaryArithVectorScalarTest = TestParamHelper<Param>;
TEST_P(BinaryArithVectorScalarTest,VectorScalar)969 TEST_P(BinaryArithVectorScalarTest, VectorScalar) {
970 auto& param = GetParam();
971
972 const ast::Expression* lhs = MakeVectorExpr(this, param.type);
973 const ast::Expression* rhs = MakeScalarExpr(this, param.type);
974 std::string op_type_decl = OpTypeDecl(param.type);
975
976 auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
977
978 WrapInFunction(expr);
979
980 spirv::Builder& b = Build();
981 ASSERT_TRUE(b.Build()) << b.error();
982
983 EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
984 OpMemoryModel Logical GLSL450
985 OpEntryPoint GLCompute %3 "test_function"
986 OpExecutionMode %3 LocalSize 1 1 1
987 OpName %3 "test_function"
988 %2 = OpTypeVoid
989 %1 = OpTypeFunction %2
990 %6 = )" + op_type_decl + R"(
991 %5 = OpTypeVector %6 3
992 %7 = OpConstant %6 1
993 %8 = OpConstantComposite %5 %7 %7 %7
994 %11 = OpTypePointer Function %5
995 %12 = OpConstantNull %5
996 %3 = OpFunction %2 None %1
997 %4 = OpLabel
998 %10 = OpVariable %11 Function %12
999 %13 = OpCompositeConstruct %5 %7 %7 %7
1000 %9 = )" + param.name + R"( %5 %8 %13
1001 OpReturn
1002 OpFunctionEnd
1003 )");
1004
1005 Validate(b);
1006 }
// 1 <op> vec3(1, 1, 1) for the parameterized element type. The scalar is
// splatted with OpCompositeConstruct (%13) and used as the left operand.
// The whole module is checked and then validated.
TEST_P(BinaryArithVectorScalarTest, ScalarVector) {
  const Param& data = GetParam();

  const ast::Expression* scalar_operand = MakeScalarExpr(this, data.type);
  const ast::Expression* vector_operand = MakeVectorExpr(this, data.type);
  const std::string elem_type_decl = OpTypeDecl(data.type);

  auto* binary =
      create<ast::BinaryExpression>(data.op, scalar_operand, vector_operand);
  WrapInFunction(binary);

  auto& builder = Build();
  ASSERT_TRUE(builder.Build()) << builder.error();

  const std::string expected = R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %3 "test_function"
OpExecutionMode %3 LocalSize 1 1 1
OpName %3 "test_function"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%5 = )" + elem_type_decl + R"(
%6 = OpConstant %5 1
%7 = OpTypeVector %5 3
%8 = OpConstantComposite %7 %6 %6 %6
%11 = OpTypePointer Function %7
%12 = OpConstantNull %7
%3 = OpFunction %2 None %1
%4 = OpLabel
%10 = OpVariable %11 Function %12
%13 = OpCompositeConstruct %7 %6 %6 %6
%9 = )" + data.name + R"( %7 %13 %8
OpReturn
OpFunctionEnd
)";
  EXPECT_EQ(DumpBuilder(builder), expected);

  Validate(builder);
}
1045 INSTANTIATE_TEST_SUITE_P(
1046 BuilderTest,
1047 BinaryArithVectorScalarTest,
1048 testing::Values(Param{Type::f32, ast::BinaryOp::kAdd, "OpFAdd"},
1049 Param{Type::f32, ast::BinaryOp::kDivide, "OpFDiv"},
1050 // NOTE: Modulo not allowed on mixed float scalar-vector
1051 // Param{Type::f32, ast::BinaryOp::kModulo, "OpFMod"},
1052 // NOTE: We test f32 multiplies separately as we emit
1053 // OpVectorTimesScalar for this case
1054 // Param{Type::i32, ast::BinaryOp::kMultiply, "OpIMul"},
1055 Param{Type::f32, ast::BinaryOp::kSubtract, "OpFSub"},
1056
1057 Param{Type::i32, ast::BinaryOp::kAdd, "OpIAdd"},
1058 Param{Type::i32, ast::BinaryOp::kDivide, "OpSDiv"},
1059 Param{Type::i32, ast::BinaryOp::kModulo, "OpSMod"},
1060 Param{Type::i32, ast::BinaryOp::kMultiply, "OpIMul"},
1061 Param{Type::i32, ast::BinaryOp::kSubtract, "OpISub"},
1062
1063 Param{Type::u32, ast::BinaryOp::kAdd, "OpIAdd"},
1064 Param{Type::u32, ast::BinaryOp::kDivide, "OpUDiv"},
1065 Param{Type::u32, ast::BinaryOp::kModulo, "OpUMod"},
1066 Param{Type::u32, ast::BinaryOp::kMultiply, "OpIMul"},
1067 Param{Type::u32, ast::BinaryOp::kSubtract, "OpISub"}));
1068
1069 using BinaryArithVectorScalarMultiplyTest = TestParamHelper<Param>;
TEST_P(BinaryArithVectorScalarMultiplyTest,VectorScalar)1070 TEST_P(BinaryArithVectorScalarMultiplyTest, VectorScalar) {
1071 auto& param = GetParam();
1072
1073 const ast::Expression* lhs = MakeVectorExpr(this, param.type);
1074 const ast::Expression* rhs = MakeScalarExpr(this, param.type);
1075 std::string op_type_decl = OpTypeDecl(param.type);
1076
1077 auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
1078
1079 WrapInFunction(expr);
1080
1081 spirv::Builder& b = Build();
1082 ASSERT_TRUE(b.Build()) << b.error();
1083
1084 EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
1085 OpMemoryModel Logical GLSL450
1086 OpEntryPoint GLCompute %3 "test_function"
1087 OpExecutionMode %3 LocalSize 1 1 1
1088 OpName %3 "test_function"
1089 %2 = OpTypeVoid
1090 %1 = OpTypeFunction %2
1091 %6 = )" + op_type_decl + R"(
1092 %5 = OpTypeVector %6 3
1093 %7 = OpConstant %6 1
1094 %8 = OpConstantComposite %5 %7 %7 %7
1095 %3 = OpFunction %2 None %1
1096 %4 = OpLabel
1097 %9 = OpVectorTimesScalar %5 %8 %7
1098 OpReturn
1099 OpFunctionEnd
1100 )");
1101
1102 Validate(b);
1103 }
// `scalar * vector`: expected to emit OpVectorTimesScalar with the operands
// swapped, so the vector comes first and the scalar second.
TEST_P(BinaryArithVectorScalarMultiplyTest, ScalarVector) {
  const auto& p = GetParam();

  const ast::Expression* scalar_operand = MakeScalarExpr(this, p.type);
  const ast::Expression* vec_operand = MakeVectorExpr(this, p.type);
  auto* binary =
      create<ast::BinaryExpression>(p.op, scalar_operand, vec_operand);
  WrapInFunction(binary);

  spirv::Builder& builder = Build();
  ASSERT_TRUE(builder.Build()) << builder.error();

  std::string expected = R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %3 "test_function"
OpExecutionMode %3 LocalSize 1 1 1
OpName %3 "test_function"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%5 = )";
  expected += OpTypeDecl(p.type);
  expected += R"(
%6 = OpConstant %5 1
%7 = OpTypeVector %5 3
%8 = OpConstantComposite %7 %6 %6 %6
%3 = OpFunction %2 None %1
%4 = OpLabel
%9 = OpVectorTimesScalar %7 %8 %6
OpReturn
OpFunctionEnd
)";
  EXPECT_EQ(DumpBuilder(builder), expected);

  Validate(builder);
}
1138 INSTANTIATE_TEST_SUITE_P(BuilderTest,
1139 BinaryArithVectorScalarMultiplyTest,
1140 testing::Values(Param{
1141 Type::f32, ast::BinaryOp::kMultiply, "OpFMul"}));
1142
1143 } // namespace BinaryArithVectorScalar
1144
1145 namespace BinaryArithMatrixMatrix {
1146
1147 struct Param {
1148 ast::BinaryOp op;
1149 std::string name;
1150 };
1151
1152 using BinaryArithMatrixMatrix = TestParamHelper<Param>;
TEST_P(BinaryArithMatrixMatrix,AddOrSubtract)1153 TEST_P(BinaryArithMatrixMatrix, AddOrSubtract) {
1154 auto& param = GetParam();
1155
1156 const ast::Expression* lhs = mat3x4<f32>();
1157 const ast::Expression* rhs = mat3x4<f32>();
1158
1159 auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
1160
1161 WrapInFunction(expr);
1162
1163 spirv::Builder& b = Build();
1164 ASSERT_TRUE(b.Build()) << b.error();
1165
1166 EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
1167 OpMemoryModel Logical GLSL450
1168 OpEntryPoint GLCompute %3 "test_function"
1169 OpExecutionMode %3 LocalSize 1 1 1
1170 OpName %3 "test_function"
1171 %2 = OpTypeVoid
1172 %1 = OpTypeFunction %2
1173 %7 = OpTypeFloat 32
1174 %6 = OpTypeVector %7 4
1175 %5 = OpTypeMatrix %6 3
1176 %8 = OpConstantNull %5
1177 %3 = OpFunction %2 None %1
1178 %4 = OpLabel
1179 %10 = OpCompositeExtract %6 %8 0
1180 %11 = OpCompositeExtract %6 %8 0
1181 %12 = )" + param.name + R"( %6 %10 %11
1182 %13 = OpCompositeExtract %6 %8 1
1183 %14 = OpCompositeExtract %6 %8 1
1184 %15 = )" + param.name + R"( %6 %13 %14
1185 %16 = OpCompositeExtract %6 %8 2
1186 %17 = OpCompositeExtract %6 %8 2
1187 %18 = )" + param.name + R"( %6 %16 %17
1188 %19 = OpCompositeConstruct %5 %12 %15 %18
1189 OpReturn
1190 OpFunctionEnd
1191 )");
1192
1193 Validate(b);
1194 }
1195 INSTANTIATE_TEST_SUITE_P( //
1196 BuilderTest,
1197 BinaryArithMatrixMatrix,
1198 testing::Values(Param{ast::BinaryOp::kAdd, "OpFAdd"},
1199 Param{ast::BinaryOp::kSubtract, "OpFSub"}));
1200
1201 using BinaryArithMatrixMatrixMultiply = TestParamHelper<Param>;
TEST_P(BinaryArithMatrixMatrixMultiply,Multiply)1202 TEST_P(BinaryArithMatrixMatrixMultiply, Multiply) {
1203 auto& param = GetParam();
1204
1205 const ast::Expression* lhs = mat3x4<f32>();
1206 const ast::Expression* rhs = mat4x3<f32>();
1207
1208 auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
1209
1210 WrapInFunction(expr);
1211
1212 spirv::Builder& b = Build();
1213 ASSERT_TRUE(b.Build()) << b.error();
1214
1215 EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
1216 OpMemoryModel Logical GLSL450
1217 OpEntryPoint GLCompute %3 "test_function"
1218 OpExecutionMode %3 LocalSize 1 1 1
1219 OpName %3 "test_function"
1220 %2 = OpTypeVoid
1221 %1 = OpTypeFunction %2
1222 %7 = OpTypeFloat 32
1223 %6 = OpTypeVector %7 4
1224 %5 = OpTypeMatrix %6 3
1225 %8 = OpConstantNull %5
1226 %10 = OpTypeVector %7 3
1227 %9 = OpTypeMatrix %10 4
1228 %11 = OpConstantNull %9
1229 %13 = OpTypeMatrix %6 4
1230 %3 = OpFunction %2 None %1
1231 %4 = OpLabel
1232 %12 = OpMatrixTimesMatrix %13 %8 %11
1233 OpReturn
1234 OpFunctionEnd
1235 )");
1236
1237 Validate(b);
1238 }
1239 INSTANTIATE_TEST_SUITE_P( //
1240 BuilderTest,
1241 BinaryArithMatrixMatrixMultiply,
1242 testing::Values(Param{ast::BinaryOp::kMultiply, "OpFMul"}));
1243
1244 } // namespace BinaryArithMatrixMatrix
1245
1246 } // namespace
1247 } // namespace spirv
1248 } // namespace writer
1249 } // namespace tint
1250