1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/ast/call_statement.h"
16 #include "src/ast/variable_decl_statement.h"
17 #include "src/writer/hlsl/test_helper.h"
18
19 namespace tint {
20 namespace writer {
21 namespace hlsl {
22 namespace {
23
// Fixture used by the non-parameterized binary-expression tests in this file.
using HlslGeneratorImplTest_Binary = TestHelper;
25
/// One parameterized case for the HlslBinaryTest suite.
/// Field order matters: cases are brace-initialized as {result, op}.
struct BinaryData {
  /// The exact HLSL text expected from emitting `left <op> right`.
  const char* result;
  /// The binary operator under test.
  ast::BinaryOp op;
};
operator <<(std::ostream & out,BinaryData data)30 inline std::ostream& operator<<(std::ostream& out, BinaryData data) {
31 out << data.op;
32 return out;
33 }
34
// Parameterized fixture; each instantiation receives one BinaryData case.
using HlslBinaryTest = TestParamHelper<BinaryData>;
// Emits `left <op> right` for two private f32 globals and checks the emitted
// HLSL against the case's expected text.
TEST_P(HlslBinaryTest, Emit_f32) {
  auto params = GetParam();

  // Bitwise and shift operators are illegal for f32 operands. Mark those cases
  // as skipped instead of silently returning, so the test report does not
  // count them as passing checks.
  if (params.op == ast::BinaryOp::kAnd || params.op == ast::BinaryOp::kOr ||
      params.op == ast::BinaryOp::kXor ||
      params.op == ast::BinaryOp::kShiftLeft ||
      params.op == ast::BinaryOp::kShiftRight) {
    GTEST_SKIP() << "operator is not valid for f32 operands";
  }

  Global("left", ty.f32(), ast::StorageClass::kPrivate);
  Global("right", ty.f32(), ast::StorageClass::kPrivate);

  auto* left = Expr("left");
  auto* right = Expr("right");

  auto* expr = create<ast::BinaryExpression>(params.op, left, right);

  WrapInFunction(expr);

  GeneratorImpl& gen = Build();

  std::stringstream out;
  ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
  EXPECT_EQ(out.str(), params.result);
}
// Emits `left <op> right` for two private u32 globals. Every operator in the
// suite is exercised for u32, so no case is filtered out here.
TEST_P(HlslBinaryTest, Emit_u32) {
  const BinaryData data = GetParam();

  Global("left", ty.u32(), ast::StorageClass::kPrivate);
  Global("right", ty.u32(), ast::StorageClass::kPrivate);

  auto* lhs = Expr("left");
  auto* rhs = Expr("right");
  auto* binary = create<ast::BinaryExpression>(data.op, lhs, rhs);
  WrapInFunction(binary);

  auto& gen = Build();

  std::stringstream ss;
  ASSERT_TRUE(gen.EmitExpression(ss, binary)) << gen.error();
  EXPECT_EQ(ss.str(), data.result);
}
// Emits `left <op> right` for two private i32 globals and checks the emitted
// HLSL against the case's expected text.
TEST_P(HlslBinaryTest, Emit_i32) {
  auto params = GetParam();

  // Shift operators are illegal for these i32 operands. Mark those cases as
  // skipped instead of silently returning, so the test report does not count
  // them as passing checks.
  if (params.op == ast::BinaryOp::kShiftLeft ||
      params.op == ast::BinaryOp::kShiftRight) {
    GTEST_SKIP() << "operator is not valid for i32 operands";
  }

  Global("left", ty.i32(), ast::StorageClass::kPrivate);
  Global("right", ty.i32(), ast::StorageClass::kPrivate);

  auto* left = Expr("left");
  auto* right = Expr("right");

  auto* expr = create<ast::BinaryExpression>(params.op, left, right);

  WrapInFunction(expr);

  GeneratorImpl& gen = Build();

  std::stringstream out;
  ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
  EXPECT_EQ(out.str(), params.result);
}
// Each case below runs through the Emit_f32, Emit_u32 and Emit_i32 tests. The
// f32 and i32 tests filter out the operators that are illegal for their
// operand type. Do not reorder the cases: gtest names the instantiations by
// their index.
INSTANTIATE_TEST_SUITE_P(
    HlslGeneratorImplTest,
    HlslBinaryTest,
    testing::Values(
        // Bitwise operators.
        BinaryData{"(left & right)", ast::BinaryOp::kAnd},
        BinaryData{"(left | right)", ast::BinaryOp::kOr},
        BinaryData{"(left ^ right)", ast::BinaryOp::kXor},
        // Comparison operators.
        BinaryData{"(left == right)", ast::BinaryOp::kEqual},
        BinaryData{"(left != right)", ast::BinaryOp::kNotEqual},
        BinaryData{"(left < right)", ast::BinaryOp::kLessThan},
        BinaryData{"(left > right)", ast::BinaryOp::kGreaterThan},
        BinaryData{"(left <= right)", ast::BinaryOp::kLessThanEqual},
        BinaryData{"(left >= right)", ast::BinaryOp::kGreaterThanEqual},
        // Shift operators.
        BinaryData{"(left << right)", ast::BinaryOp::kShiftLeft},
        BinaryData{"(left >> right)", ast::BinaryOp::kShiftRight},
        // Arithmetic operators.
        BinaryData{"(left + right)", ast::BinaryOp::kAdd},
        BinaryData{"(left - right)", ast::BinaryOp::kSubtract},
        BinaryData{"(left * right)", ast::BinaryOp::kMultiply},
        BinaryData{"(left / right)", ast::BinaryOp::kDivide},
        BinaryData{"(left % right)", ast::BinaryOp::kModulo}));
127
// `vector * scalar` is emitted with the infix `*` operator, with no mul() call.
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar) {
  auto* lhs = vec3<f32>(1.f, 1.f, 1.f);
  auto* rhs = Expr(1.f);

  auto* expr =
      create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);

  WrapInFunction(expr);

  GeneratorImpl& gen = Build();

  std::stringstream out;
  // ASSERT (not EXPECT): a failed emit leaves `out` meaningless, so stop here
  // rather than also reporting a spurious string mismatch.
  ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
  EXPECT_EQ(out.str(), "(float3(1.0f, 1.0f, 1.0f) * 1.0f)");
}
145
// `scalar * vector` is emitted with the infix `*` operator, with no mul() call.
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector) {
  auto* lhs = Expr(1.f);
  auto* rhs = vec3<f32>(1.f, 1.f, 1.f);

  auto* expr =
      create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);

  WrapInFunction(expr);

  GeneratorImpl& gen = Build();

  std::stringstream out;
  // ASSERT (not EXPECT): a failed emit leaves `out` meaningless.
  ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
  EXPECT_EQ(out.str(), "(1.0f * float3(1.0f, 1.0f, 1.0f))");
}
163
// `matrix * scalar` is emitted with the infix `*` operator, with no mul() call.
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixScalar) {
  Global("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
  auto* lhs = Expr("mat");
  auto* rhs = Expr(1.f);

  auto* expr =
      create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
  WrapInFunction(expr);

  GeneratorImpl& gen = Build();

  std::stringstream out;
  // ASSERT (not EXPECT): a failed emit leaves `out` meaningless.
  ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
  EXPECT_EQ(out.str(), "(mat * 1.0f)");
}
179
// `scalar * matrix` is emitted with the infix `*` operator, with no mul() call.
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarMatrix) {
  Global("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
  auto* lhs = Expr(1.f);
  auto* rhs = Expr("mat");

  auto* expr =
      create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
  WrapInFunction(expr);

  GeneratorImpl& gen = Build();

  std::stringstream out;
  // ASSERT (not EXPECT): a failed emit leaves `out` meaningless.
  ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
  EXPECT_EQ(out.str(), "(1.0f * mat)");
}
195
// `matrix * vector` is emitted through HLSL's mul() with the operands swapped:
// `mat * v` becomes `mul(v, mat)`.
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixVector) {
  Global("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
  auto* lhs = Expr("mat");
  auto* rhs = vec3<f32>(1.f, 1.f, 1.f);

  auto* expr =
      create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
  WrapInFunction(expr);

  GeneratorImpl& gen = Build();

  std::stringstream out;
  // ASSERT (not EXPECT): a failed emit leaves `out` meaningless.
  ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
  EXPECT_EQ(out.str(), "mul(float3(1.0f, 1.0f, 1.0f), mat)");
}
211
// `vector * matrix` is emitted through HLSL's mul() with the operands swapped:
// `v * mat` becomes `mul(mat, v)`.
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorMatrix) {
  Global("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
  auto* lhs = vec3<f32>(1.f, 1.f, 1.f);
  auto* rhs = Expr("mat");

  auto* expr =
      create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
  WrapInFunction(expr);

  GeneratorImpl& gen = Build();

  std::stringstream out;
  // ASSERT (not EXPECT): a failed emit leaves `out` meaningless.
  ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
  EXPECT_EQ(out.str(), "mul(mat, float3(1.0f, 1.0f, 1.0f))");
}
227
// `matrix * matrix` is emitted through HLSL's mul() with the operands swapped:
// `lhs * rhs` becomes `mul(rhs, lhs)`.
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixMatrix) {
  Global("lhs", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
  Global("rhs", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);

  auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply,
                                             Expr("lhs"), Expr("rhs"));
  WrapInFunction(expr);

  GeneratorImpl& gen = Build();

  std::stringstream out;
  // ASSERT (not EXPECT): a failed emit leaves `out` meaningless.
  ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
  EXPECT_EQ(out.str(), "mul(rhs, lhs)");
}
242
// `a && b` is lowered to a temporary: the left operand is stored first and the
// right operand is only assigned inside an `if` when the temporary is true.
// The expression itself then reduces to a reference to that temporary.
TEST_F(HlslGeneratorImplTest_Binary, Logical_And) {
  Global("a", ty.bool_(), ast::StorageClass::kPrivate);
  Global("b", ty.bool_(), ast::StorageClass::kPrivate);

  auto* lhs = Expr("a");
  auto* rhs = Expr("b");
  auto* and_expr =
      create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, lhs, rhs);
  WrapInFunction(and_expr);

  auto& gen = Build();

  std::stringstream ss;
  ASSERT_TRUE(gen.EmitExpression(ss, and_expr)) << gen.error();
  EXPECT_EQ(ss.str(), "(tint_tmp)");
  EXPECT_EQ(gen.result(), R"(bool tint_tmp = a;
if (tint_tmp) {
tint_tmp = b;
}
)");
}
262
// Nested logical operators: (a && b) || (c || d).
// The left operand's temporary (tint_tmp_1) is emitted before the outer
// temporary, and the right operand's temporary (tint_tmp_2) is emitted inside
// the outer short-circuit `if`.
TEST_F(HlslGeneratorImplTest_Binary, Logical_Multi) {
  Global("a", ty.bool_(), ast::StorageClass::kPrivate);
  Global("b", ty.bool_(), ast::StorageClass::kPrivate);
  Global("c", ty.bool_(), ast::StorageClass::kPrivate);
  Global("d", ty.bool_(), ast::StorageClass::kPrivate);

  auto* a_and_b = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                                Expr("a"), Expr("b"));
  auto* c_or_d = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                               Expr("c"), Expr("d"));
  auto* outer =
      create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr, a_and_b, c_or_d);
  WrapInFunction(outer);

  auto& gen = Build();

  std::stringstream ss;
  ASSERT_TRUE(gen.EmitExpression(ss, outer)) << gen.error();
  EXPECT_EQ(ss.str(), "(tint_tmp)");
  EXPECT_EQ(gen.result(), R"(bool tint_tmp_1 = a;
if (tint_tmp_1) {
tint_tmp_1 = b;
}
bool tint_tmp = (tint_tmp_1);
if (!tint_tmp) {
bool tint_tmp_2 = c;
if (!tint_tmp_2) {
tint_tmp_2 = d;
}
tint_tmp = (tint_tmp_2);
}
)");
}
297
// `a || b` is lowered to a temporary: the left operand is stored first and the
// right operand is only assigned inside an `if` when the temporary is false.
TEST_F(HlslGeneratorImplTest_Binary, Logical_Or) {
  Global("a", ty.bool_(), ast::StorageClass::kPrivate);
  Global("b", ty.bool_(), ast::StorageClass::kPrivate);

  auto* lhs = Expr("a");
  auto* rhs = Expr("b");
  auto* or_expr =
      create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr, lhs, rhs);
  WrapInFunction(or_expr);

  auto& gen = Build();

  std::stringstream ss;
  ASSERT_TRUE(gen.EmitExpression(ss, or_expr)) << gen.error();
  EXPECT_EQ(ss.str(), "(tint_tmp)");
  EXPECT_EQ(gen.result(), R"(bool tint_tmp = a;
if (!tint_tmp) {
tint_tmp = b;
}
)");
}
317
// An if / else-if / else chain whose conditions are logical expressions:
//
//   if (a && b) {
//     return 1;
//   } else if (b || c) {
//     return 2;
//   } else {
//     return 3;
//   }
//
// The else-if is emitted as a nested `if` inside a plain `else`, so that its
// condition's temporary (tint_tmp_1) can be declared right before it.
TEST_F(HlslGeneratorImplTest_Binary, If_WithLogical) {
  Global("a", ty.bool_(), ast::StorageClass::kPrivate);
  Global("b", ty.bool_(), ast::StorageClass::kPrivate);
  Global("c", ty.bool_(), ast::StorageClass::kPrivate);

  auto* a_and_b = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                                Expr("a"), Expr("b"));
  auto* b_or_c = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                               Expr("b"), Expr("c"));
  auto* if_stmt = If(a_and_b, Block(Return(1)),
                     Else(b_or_c, Block(Return(2))), Else(Block(Return(3))));
  Func("func", {}, ty.i32(), {WrapInStatement(if_stmt)});

  auto& gen = Build();

  ASSERT_TRUE(gen.EmitStatement(if_stmt)) << gen.error();
  EXPECT_EQ(gen.result(), R"(bool tint_tmp = a;
if (tint_tmp) {
tint_tmp = b;
}
if ((tint_tmp)) {
return 1;
} else {
bool tint_tmp_1 = b;
if (!tint_tmp_1) {
tint_tmp_1 = c;
}
if ((tint_tmp_1)) {
return 2;
} else {
return 3;
}
}
)");
}
362
// `return (a && b) || c;` - the temporaries are emitted before the `return`,
// which then returns the outer temporary.
TEST_F(HlslGeneratorImplTest_Binary, Return_WithLogical) {
  Global("a", ty.bool_(), ast::StorageClass::kPrivate);
  Global("b", ty.bool_(), ast::StorageClass::kPrivate);
  Global("c", ty.bool_(), ast::StorageClass::kPrivate);

  auto* a_and_b = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                                Expr("a"), Expr("b"));
  auto* value = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                              a_and_b, Expr("c"));
  auto* ret = Return(value);
  Func("func", {}, ty.bool_(), {WrapInStatement(ret)});

  auto& gen = Build();

  ASSERT_TRUE(gen.EmitStatement(ret)) << gen.error();
  EXPECT_EQ(gen.result(), R"(bool tint_tmp_1 = a;
if (tint_tmp_1) {
tint_tmp_1 = b;
}
bool tint_tmp = (tint_tmp_1);
if (!tint_tmp) {
tint_tmp = c;
}
return (tint_tmp);
)");
}
391
// `a = (b || c) && d;` - the temporaries are emitted before the assignment,
// which then assigns the outer temporary.
TEST_F(HlslGeneratorImplTest_Binary, Assign_WithLogical) {
  Global("a", ty.bool_(), ast::StorageClass::kPrivate);
  Global("b", ty.bool_(), ast::StorageClass::kPrivate);
  Global("c", ty.bool_(), ast::StorageClass::kPrivate);
  Global("d", ty.bool_(), ast::StorageClass::kPrivate);

  auto* b_or_c = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                               Expr("b"), Expr("c"));
  auto* rhs = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                            b_or_c, Expr("d"));
  auto* assign = Assign(Expr("a"), rhs);
  WrapInFunction(assign);

  auto& gen = Build();

  ASSERT_TRUE(gen.EmitStatement(assign)) << gen.error();
  EXPECT_EQ(gen.result(), R"(bool tint_tmp_1 = b;
if (!tint_tmp_1) {
tint_tmp_1 = c;
}
bool tint_tmp = (tint_tmp_1);
if (tint_tmp) {
tint_tmp = d;
}
a = (tint_tmp);
)");
}
422
// `var a : bool = (b && c) || d;` - the temporaries are emitted before the
// declaration, which is then initialized from the outer temporary.
TEST_F(HlslGeneratorImplTest_Binary, Decl_WithLogical) {
  Global("b", ty.bool_(), ast::StorageClass::kPrivate);
  Global("c", ty.bool_(), ast::StorageClass::kPrivate);
  Global("d", ty.bool_(), ast::StorageClass::kPrivate);

  auto* b_and_c = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                                Expr("b"), Expr("c"));
  auto* init = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                             b_and_c, Expr("d"));
  auto* decl = Decl(Var("a", ty.bool_(), ast::StorageClass::kNone, init));
  WrapInFunction(decl);

  auto& gen = Build();

  ASSERT_TRUE(gen.EmitStatement(decl)) << gen.error();
  EXPECT_EQ(gen.result(), R"(bool tint_tmp_1 = b;
if (tint_tmp_1) {
tint_tmp_1 = c;
}
bool tint_tmp = (tint_tmp_1);
if (!tint_tmp) {
tint_tmp = d;
}
bool a = (tint_tmp);
)");
}
454
// `foo(a && b, c || d, (a || c) && (b || d))` - every logical argument is
// hoisted into its own temporary, emitted in argument order before the call,
// and the call then passes those temporaries.
TEST_F(HlslGeneratorImplTest_Binary, Call_WithLogical) {
  Func("foo",
       {
           Param(Sym(), ty.bool_()),
           Param(Sym(), ty.bool_()),
           Param(Sym(), ty.bool_()),
       },
       ty.void_(), ast::StatementList{}, ast::DecorationList{});
  Global("a", ty.bool_(), ast::StorageClass::kPrivate);
  Global("b", ty.bool_(), ast::StorageClass::kPrivate);
  Global("c", ty.bool_(), ast::StorageClass::kPrivate);
  Global("d", ty.bool_(), ast::StorageClass::kPrivate);

  auto* a_and_b = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                                Expr("a"), Expr("b"));
  auto* c_or_d = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                               Expr("c"), Expr("d"));
  auto* a_or_c = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                               Expr("a"), Expr("c"));
  auto* b_or_d = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                               Expr("b"), Expr("d"));
  auto* both = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                             a_or_c, b_or_d);
  ast::ExpressionList args{a_and_b, c_or_d, both};

  auto* call_stmt = CallStmt(Call("foo", args));
  WrapInFunction(call_stmt);

  auto& gen = Build();

  ASSERT_TRUE(gen.EmitStatement(call_stmt)) << gen.error();
  EXPECT_EQ(gen.result(), R"(bool tint_tmp = a;
if (tint_tmp) {
tint_tmp = b;
}
bool tint_tmp_1 = c;
if (!tint_tmp_1) {
tint_tmp_1 = d;
}
bool tint_tmp_3 = a;
if (!tint_tmp_3) {
tint_tmp_3 = c;
}
bool tint_tmp_2 = (tint_tmp_3);
if (tint_tmp_2) {
bool tint_tmp_4 = b;
if (!tint_tmp_4) {
tint_tmp_4 = d;
}
tint_tmp_2 = (tint_tmp_4);
}
foo((tint_tmp), (tint_tmp_1), (tint_tmp_2));
)");
}
511
// Integer division by a literal i32 zero is emitted with a divisor of 1
// instead of 0.
TEST_F(HlslGeneratorImplTest_Binary, DivideByLiteralZero_i32) {
  Global("a", ty.i32(), ast::StorageClass::kPrivate);

  auto* quotient = Div("a", 0);
  WrapInFunction(quotient);

  auto& gen = Build();

  std::stringstream ss;
  ASSERT_TRUE(gen.EmitExpression(ss, quotient)) << gen.error();
  EXPECT_EQ(ss.str(), "(a / 1)");
}
524
// Integer division by a literal u32 zero is emitted with a divisor of 1u
// instead of 0u.
TEST_F(HlslGeneratorImplTest_Binary, DivideByLiteralZero_u32) {
  Global("a", ty.u32(), ast::StorageClass::kPrivate);

  auto* quotient = Div("a", 0u);
  WrapInFunction(quotient);

  auto& gen = Build();

  std::stringstream ss;
  ASSERT_TRUE(gen.EmitExpression(ss, quotient)) << gen.error();
  EXPECT_EQ(ss.str(), "(a / 1u)");
}
537
538 } // namespace
539 } // namespace hlsl
540 } // namespace writer
541 } // namespace tint
542