• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/ast/call_statement.h"
16 #include "src/ast/variable_decl_statement.h"
17 #include "src/writer/hlsl/test_helper.h"
18 
19 namespace tint {
20 namespace writer {
21 namespace hlsl {
22 namespace {
23 
24 using HlslGeneratorImplTest_Binary = TestHelper;
25 
26 struct BinaryData {
27   const char* result;
28   ast::BinaryOp op;
29 };
operator <<(std::ostream & out,BinaryData data)30 inline std::ostream& operator<<(std::ostream& out, BinaryData data) {
31   out << data.op;
32   return out;
33 }
34 
35 using HlslBinaryTest = TestParamHelper<BinaryData>;
TEST_P(HlslBinaryTest, Emit_f32) {
  const BinaryData param = GetParam();

  // Bitwise and shift rows are illegal for f32 operands, so they are skipped.
  const bool illegal_for_f32 = param.op == ast::BinaryOp::kAnd ||
                               param.op == ast::BinaryOp::kOr ||
                               param.op == ast::BinaryOp::kXor ||
                               param.op == ast::BinaryOp::kShiftLeft ||
                               param.op == ast::BinaryOp::kShiftRight;
  if (illegal_for_f32) {
    return;
  }

  Global("left", ty.f32(), ast::StorageClass::kPrivate);
  Global("right", ty.f32(), ast::StorageClass::kPrivate);

  auto* lhs = Expr("left");
  auto* rhs = Expr("right");
  auto* binary = create<ast::BinaryExpression>(param.op, lhs, rhs);
  WrapInFunction(binary);

  GeneratorImpl& gen = Build();

  std::stringstream hlsl;
  ASSERT_TRUE(gen.EmitExpression(hlsl, binary)) << gen.error();
  EXPECT_EQ(hlsl.str(), param.result);
}
TEST_P(HlslBinaryTest, Emit_u32) {
  // No rows are skipped for u32: every operator in the table is exercised.
  const BinaryData param = GetParam();

  Global("left", ty.u32(), ast::StorageClass::kPrivate);
  Global("right", ty.u32(), ast::StorageClass::kPrivate);

  auto* lhs = Expr("left");
  auto* rhs = Expr("right");
  auto* binary = create<ast::BinaryExpression>(param.op, lhs, rhs);
  WrapInFunction(binary);

  GeneratorImpl& gen = Build();

  std::stringstream hlsl;
  ASSERT_TRUE(gen.EmitExpression(hlsl, binary)) << gen.error();
  EXPECT_EQ(hlsl.str(), param.result);
}
TEST_P(HlslBinaryTest, Emit_i32) {
  const BinaryData param = GetParam();

  // The shift rows are treated as illegal for i32 operands and skipped.
  const bool is_shift = param.op == ast::BinaryOp::kShiftLeft ||
                        param.op == ast::BinaryOp::kShiftRight;
  if (is_shift) {
    return;
  }

  Global("left", ty.i32(), ast::StorageClass::kPrivate);
  Global("right", ty.i32(), ast::StorageClass::kPrivate);

  auto* lhs = Expr("left");
  auto* rhs = Expr("right");
  auto* binary = create<ast::BinaryExpression>(param.op, lhs, rhs);
  WrapInFunction(binary);

  GeneratorImpl& gen = Build();

  std::stringstream hlsl;
  ASSERT_TRUE(gen.EmitExpression(hlsl, binary)) << gen.error();
  EXPECT_EQ(hlsl.str(), param.result);
}
107 INSTANTIATE_TEST_SUITE_P(
108     HlslGeneratorImplTest,
109     HlslBinaryTest,
110     testing::Values(
111         BinaryData{"(left & right)", ast::BinaryOp::kAnd},
112         BinaryData{"(left | right)", ast::BinaryOp::kOr},
113         BinaryData{"(left ^ right)", ast::BinaryOp::kXor},
114         BinaryData{"(left == right)", ast::BinaryOp::kEqual},
115         BinaryData{"(left != right)", ast::BinaryOp::kNotEqual},
116         BinaryData{"(left < right)", ast::BinaryOp::kLessThan},
117         BinaryData{"(left > right)", ast::BinaryOp::kGreaterThan},
118         BinaryData{"(left <= right)", ast::BinaryOp::kLessThanEqual},
119         BinaryData{"(left >= right)", ast::BinaryOp::kGreaterThanEqual},
120         BinaryData{"(left << right)", ast::BinaryOp::kShiftLeft},
121         BinaryData{"(left >> right)", ast::BinaryOp::kShiftRight},
122         BinaryData{"(left + right)", ast::BinaryOp::kAdd},
123         BinaryData{"(left - right)", ast::BinaryOp::kSubtract},
124         BinaryData{"(left * right)", ast::BinaryOp::kMultiply},
125         BinaryData{"(left / right)", ast::BinaryOp::kDivide},
126         BinaryData{"(left % right)", ast::BinaryOp::kModulo}));
127 
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar) {
  // vec3<f32>(1.0, 1.0, 1.0) * 1.0 is emitted as a plain infix multiply.
  auto* lhs = vec3<f32>(1.f, 1.f, 1.f);
  auto* rhs = Expr(1.f);

  auto* expr =
      create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);

  WrapInFunction(expr);

  GeneratorImpl& gen = Build();

  std::stringstream out;
  // ASSERT rather than EXPECT: if emission fails, comparing the partial output
  // below would only add a misleading second failure.
  ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
  EXPECT_EQ(out.str(), "(float3(1.0f, 1.0f, 1.0f) * 1.0f)");
}
145 
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector) {
  // 1.0 * vec3<f32>(1.0, 1.0, 1.0) is emitted as a plain infix multiply.
  auto* lhs = Expr(1.f);
  auto* rhs = vec3<f32>(1.f, 1.f, 1.f);

  auto* expr =
      create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);

  WrapInFunction(expr);

  GeneratorImpl& gen = Build();

  std::stringstream out;
  // ASSERT rather than EXPECT: a failed emit makes the output check meaningless.
  ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
  EXPECT_EQ(out.str(), "(1.0f * float3(1.0f, 1.0f, 1.0f))");
}
163 
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixScalar) {
  // mat3x3<f32> * 1.0 stays an infix multiply (no `mul` intrinsic).
  Global("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
  auto* lhs = Expr("mat");
  auto* rhs = Expr(1.f);

  auto* expr =
      create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
  WrapInFunction(expr);

  GeneratorImpl& gen = Build();

  std::stringstream out;
  // ASSERT rather than EXPECT: a failed emit makes the output check meaningless.
  ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
  EXPECT_EQ(out.str(), "(mat * 1.0f)");
}
179 
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarMatrix) {
  // 1.0 * mat3x3<f32> stays an infix multiply (no `mul` intrinsic).
  Global("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
  auto* lhs = Expr(1.f);
  auto* rhs = Expr("mat");

  auto* expr =
      create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
  WrapInFunction(expr);

  GeneratorImpl& gen = Build();

  std::stringstream out;
  // ASSERT rather than EXPECT: a failed emit makes the output check meaningless.
  ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
  EXPECT_EQ(out.str(), "(1.0f * mat)");
}
195 
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixVector) {
  // mat * vec is emitted as the HLSL `mul` intrinsic with the operands
  // swapped: mul(vec, mat). NOTE(review): presumably this compensates for a
  // matrix-layout convention difference between WGSL and HLSL; confirm.
  Global("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
  auto* lhs = Expr("mat");
  auto* rhs = vec3<f32>(1.f, 1.f, 1.f);

  auto* expr =
      create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
  WrapInFunction(expr);

  GeneratorImpl& gen = Build();

  std::stringstream out;
  // ASSERT rather than EXPECT: a failed emit makes the output check meaningless.
  ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
  EXPECT_EQ(out.str(), "mul(float3(1.0f, 1.0f, 1.0f), mat)");
}
211 
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorMatrix) {
  // vec * mat is emitted as the HLSL `mul` intrinsic with the operands
  // swapped: mul(mat, vec). NOTE(review): presumably this compensates for a
  // matrix-layout convention difference between WGSL and HLSL; confirm.
  Global("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
  auto* lhs = vec3<f32>(1.f, 1.f, 1.f);
  auto* rhs = Expr("mat");

  auto* expr =
      create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
  WrapInFunction(expr);

  GeneratorImpl& gen = Build();

  std::stringstream out;
  // ASSERT rather than EXPECT: a failed emit makes the output check meaningless.
  ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
  EXPECT_EQ(out.str(), "mul(mat, float3(1.0f, 1.0f, 1.0f))");
}
227 
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixMatrix) {
  // lhs * rhs (both mat3x3<f32>) is emitted as the HLSL `mul` intrinsic with
  // the operands swapped: mul(rhs, lhs). NOTE(review): presumably this
  // compensates for a matrix-layout convention difference; confirm.
  Global("lhs", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
  Global("rhs", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);

  auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply,
                                             Expr("lhs"), Expr("rhs"));
  WrapInFunction(expr);

  GeneratorImpl& gen = Build();

  std::stringstream out;
  // ASSERT rather than EXPECT: a failed emit makes the output check meaningless.
  ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
  EXPECT_EQ(out.str(), "mul(rhs, lhs)");
}
242 
TEST_F(HlslGeneratorImplTest_Binary, Logical_And) {
  // a && b
  Global("a", ty.bool_(), ast::StorageClass::kPrivate);
  Global("b", ty.bool_(), ast::StorageClass::kPrivate);

  auto* a = Expr("a");
  auto* b = Expr("b");
  auto* a_and_b =
      create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, a, b);
  WrapInFunction(a_and_b);

  GeneratorImpl& gen = Build();

  // The expression text is just the temporary; the short-circuit evaluation
  // that fills the temporary is written to gen.result().
  std::stringstream hlsl;
  ASSERT_TRUE(gen.EmitExpression(hlsl, a_and_b)) << gen.error();
  EXPECT_EQ(hlsl.str(), "(tint_tmp)");

  const char* const kExpectedPreamble = R"(bool tint_tmp = a;
if (tint_tmp) {
  tint_tmp = b;
}
)";
  EXPECT_EQ(gen.result(), kExpectedPreamble);
}
262 
TEST_F(HlslGeneratorImplTest_Binary, Logical_Multi) {
  // (a && b) || (c || d)
  Global("a", ty.bool_(), ast::StorageClass::kPrivate);
  Global("b", ty.bool_(), ast::StorageClass::kPrivate);
  Global("c", ty.bool_(), ast::StorageClass::kPrivate);
  Global("d", ty.bool_(), ast::StorageClass::kPrivate);

  auto* a_and_b = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                                Expr("a"), Expr("b"));
  auto* c_or_d = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                               Expr("c"), Expr("d"));
  auto* outer = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                              a_and_b, c_or_d);
  WrapInFunction(outer);

  GeneratorImpl& gen = Build();

  // The left operand is hoisted before the outer temporary is declared. The
  // right operand is evaluated inside the `if (!tint_tmp)` branch, so it
  // short-circuits.
  std::stringstream hlsl;
  ASSERT_TRUE(gen.EmitExpression(hlsl, outer)) << gen.error();
  EXPECT_EQ(hlsl.str(), "(tint_tmp)");

  const char* const kExpectedPreamble = R"(bool tint_tmp_1 = a;
if (tint_tmp_1) {
  tint_tmp_1 = b;
}
bool tint_tmp = (tint_tmp_1);
if (!tint_tmp) {
  bool tint_tmp_2 = c;
  if (!tint_tmp_2) {
    tint_tmp_2 = d;
  }
  tint_tmp = (tint_tmp_2);
}
)";
  EXPECT_EQ(gen.result(), kExpectedPreamble);
}
297 
TEST_F(HlslGeneratorImplTest_Binary, Logical_Or) {
  // a || b
  Global("a", ty.bool_(), ast::StorageClass::kPrivate);
  Global("b", ty.bool_(), ast::StorageClass::kPrivate);

  auto* a = Expr("a");
  auto* b = Expr("b");
  auto* a_or_b =
      create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr, a, b);
  WrapInFunction(a_or_b);

  GeneratorImpl& gen = Build();

  // For ||, the right operand is only evaluated when the temporary is false.
  std::stringstream hlsl;
  ASSERT_TRUE(gen.EmitExpression(hlsl, a_or_b)) << gen.error();
  EXPECT_EQ(hlsl.str(), "(tint_tmp)");

  const char* const kExpectedPreamble = R"(bool tint_tmp = a;
if (!tint_tmp) {
  tint_tmp = b;
}
)";
  EXPECT_EQ(gen.result(), kExpectedPreamble);
}
317 
TEST_F(HlslGeneratorImplTest_Binary, If_WithLogical) {
  // if (a && b) {
  //   return 1;
  // } else if (b || c) {
  //   return 2;
  // } else {
  //   return 3;
  // }
  Global("a", ty.bool_(), ast::StorageClass::kPrivate);
  Global("b", ty.bool_(), ast::StorageClass::kPrivate);
  Global("c", ty.bool_(), ast::StorageClass::kPrivate);

  auto* a_and_b = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                                Expr("a"), Expr("b"));
  auto* b_or_c = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                               Expr("b"), Expr("c"));
  auto* if_stmt = If(a_and_b, Block(Return(1)),
                     Else(b_or_c, Block(Return(2))),
                     Else(Block(Return(3))));
  Func("func", {}, ty.i32(), {WrapInStatement(if_stmt)});

  GeneratorImpl& gen = Build();

  // The `else if` becomes a nested if inside `else`, so the (b || c)
  // preamble is only evaluated when the first condition fails.
  ASSERT_TRUE(gen.EmitStatement(if_stmt)) << gen.error();

  const char* const kExpected = R"(bool tint_tmp = a;
if (tint_tmp) {
  tint_tmp = b;
}
if ((tint_tmp)) {
  return 1;
} else {
  bool tint_tmp_1 = b;
  if (!tint_tmp_1) {
    tint_tmp_1 = c;
  }
  if ((tint_tmp_1)) {
    return 2;
  } else {
    return 3;
  }
}
)";
  EXPECT_EQ(gen.result(), kExpected);
}
362 
TEST_F(HlslGeneratorImplTest_Binary, Return_WithLogical) {
  // return (a && b) || c;
  Global("a", ty.bool_(), ast::StorageClass::kPrivate);
  Global("b", ty.bool_(), ast::StorageClass::kPrivate);
  Global("c", ty.bool_(), ast::StorageClass::kPrivate);

  auto* a_and_b = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                                Expr("a"), Expr("b"));
  auto* value = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                              a_and_b, Expr("c"));
  auto* ret = Return(value);
  Func("func", {}, ty.bool_(), {WrapInStatement(ret)});

  GeneratorImpl& gen = Build();

  ASSERT_TRUE(gen.EmitStatement(ret)) << gen.error();

  const char* const kExpected = R"(bool tint_tmp_1 = a;
if (tint_tmp_1) {
  tint_tmp_1 = b;
}
bool tint_tmp = (tint_tmp_1);
if (!tint_tmp) {
  tint_tmp = c;
}
return (tint_tmp);
)";
  EXPECT_EQ(gen.result(), kExpected);
}
391 
TEST_F(HlslGeneratorImplTest_Binary, Assign_WithLogical) {
  // a = (b || c) && d;
  Global("a", ty.bool_(), ast::StorageClass::kPrivate);
  Global("b", ty.bool_(), ast::StorageClass::kPrivate);
  Global("c", ty.bool_(), ast::StorageClass::kPrivate);
  Global("d", ty.bool_(), ast::StorageClass::kPrivate);

  auto* b_or_c = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                               Expr("b"), Expr("c"));
  auto* rhs = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                            b_or_c, Expr("d"));
  auto* assign = Assign(Expr("a"), rhs);
  WrapInFunction(assign);

  GeneratorImpl& gen = Build();

  ASSERT_TRUE(gen.EmitStatement(assign)) << gen.error();

  const char* const kExpected = R"(bool tint_tmp_1 = b;
if (!tint_tmp_1) {
  tint_tmp_1 = c;
}
bool tint_tmp = (tint_tmp_1);
if (tint_tmp) {
  tint_tmp = d;
}
a = (tint_tmp);
)";
  EXPECT_EQ(gen.result(), kExpected);
}
422 
TEST_F(HlslGeneratorImplTest_Binary, Decl_WithLogical) {
  // var a : bool = (b && c) || d;
  Global("b", ty.bool_(), ast::StorageClass::kPrivate);
  Global("c", ty.bool_(), ast::StorageClass::kPrivate);
  Global("d", ty.bool_(), ast::StorageClass::kPrivate);

  auto* b_and_c = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                                Expr("b"), Expr("c"));
  auto* init = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                             b_and_c, Expr("d"));
  auto* var = Var("a", ty.bool_(), ast::StorageClass::kNone, init);

  auto* decl = Decl(var);
  WrapInFunction(decl);

  GeneratorImpl& gen = Build();

  ASSERT_TRUE(gen.EmitStatement(decl)) << gen.error();

  const char* const kExpected = R"(bool tint_tmp_1 = b;
if (tint_tmp_1) {
  tint_tmp_1 = c;
}
bool tint_tmp = (tint_tmp_1);
if (!tint_tmp) {
  tint_tmp = d;
}
bool a = (tint_tmp);
)";
  EXPECT_EQ(gen.result(), kExpected);
}
454 
TEST_F(HlslGeneratorImplTest_Binary, Call_WithLogical) {
  // foo(a && b, c || d, (a || c) && (b || d))
  Func("foo",
       {
           Param(Sym(), ty.bool_()),
           Param(Sym(), ty.bool_()),
           Param(Sym(), ty.bool_()),
       },
       ty.void_(), ast::StatementList{}, ast::DecorationList{});
  Global("a", ty.bool_(), ast::StorageClass::kPrivate);
  Global("b", ty.bool_(), ast::StorageClass::kPrivate);
  Global("c", ty.bool_(), ast::StorageClass::kPrivate);
  Global("d", ty.bool_(), ast::StorageClass::kPrivate);

  auto* first = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                              Expr("a"), Expr("b"));
  auto* second = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                               Expr("c"), Expr("d"));
  auto* a_or_c = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                               Expr("a"), Expr("c"));
  auto* b_or_d = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
                                               Expr("b"), Expr("d"));
  auto* third = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                              a_or_c, b_or_d);

  ast::ExpressionList args{first, second, third};

  auto* call_stmt = CallStmt(Call("foo", args));
  WrapInFunction(call_stmt);

  GeneratorImpl& gen = Build();

  // Each argument's short-circuit preamble is emitted in argument order
  // before the call, which then references the temporaries.
  ASSERT_TRUE(gen.EmitStatement(call_stmt)) << gen.error();

  const char* const kExpected = R"(bool tint_tmp = a;
if (tint_tmp) {
  tint_tmp = b;
}
bool tint_tmp_1 = c;
if (!tint_tmp_1) {
  tint_tmp_1 = d;
}
bool tint_tmp_3 = a;
if (!tint_tmp_3) {
  tint_tmp_3 = c;
}
bool tint_tmp_2 = (tint_tmp_3);
if (tint_tmp_2) {
  bool tint_tmp_4 = b;
  if (!tint_tmp_4) {
    tint_tmp_4 = d;
  }
  tint_tmp_2 = (tint_tmp_4);
}
foo((tint_tmp), (tint_tmp_1), (tint_tmp_2));
)";
  EXPECT_EQ(gen.result(), kExpected);
}
511 
TEST_F(HlslGeneratorImplTest_Binary, DivideByLiteralZero_i32) {
  // a / 0 with an i32 `a`. The expected output shows the literal zero divisor
  // being replaced by 1. NOTE(review): presumably this avoids an integer
  // divide-by-zero in the generated HLSL; confirm.
  Global("a", ty.i32(), ast::StorageClass::kPrivate);

  auto* division = Div("a", 0);
  WrapInFunction(division);

  GeneratorImpl& gen = Build();

  std::stringstream hlsl;
  ASSERT_TRUE(gen.EmitExpression(hlsl, division)) << gen.error();
  EXPECT_EQ(hlsl.str(), "(a / 1)");
}
524 
TEST_F(HlslGeneratorImplTest_Binary, DivideByLiteralZero_u32) {
  // a / 0u with a u32 `a`. The expected output shows the literal zero divisor
  // being replaced by 1u. NOTE(review): presumably this avoids an integer
  // divide-by-zero in the generated HLSL; confirm.
  Global("a", ty.u32(), ast::StorageClass::kPrivate);

  auto* division = Div("a", 0u);
  WrapInFunction(division);

  GeneratorImpl& gen = Build();

  std::stringstream hlsl;
  ASSERT_TRUE(gen.EmitExpression(hlsl, division)) << gen.error();
  EXPECT_EQ(hlsl.str(), "(a / 1u)");
}
537 
538 }  // namespace
539 }  // namespace hlsl
540 }  // namespace writer
541 }  // namespace tint
542