• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/ast/bitcast_expression.h"
16 #include "src/resolver/resolver.h"
17 #include "src/resolver/resolver_test_helper.h"
18 
19 #include "gmock/gmock.h"
20 
21 namespace tint {
22 namespace resolver {
23 namespace {
24 
25 struct Type {
26   template <typename T>
Createtint::resolver::__anon4f7462e50111::Type27   static constexpr Type Create() {
28     return Type{builder::DataType<T>::AST, builder::DataType<T>::Sem,
29                 builder::DataType<T>::Expr};
30   }
31 
32   builder::ast_type_func_ptr ast;
33   builder::sem_type_func_ptr sem;
34   builder::ast_expr_func_ptr expr;
35 };
36 
37 static constexpr Type kNumericScalars[] = {
38     Type::Create<builder::f32>(),
39     Type::Create<builder::i32>(),
40     Type::Create<builder::u32>(),
41 };
42 static constexpr Type kVec2NumericScalars[] = {
43     Type::Create<builder::vec2<builder::f32>>(),
44     Type::Create<builder::vec2<builder::i32>>(),
45     Type::Create<builder::vec2<builder::u32>>(),
46 };
47 static constexpr Type kVec3NumericScalars[] = {
48     Type::Create<builder::vec3<builder::f32>>(),
49     Type::Create<builder::vec3<builder::i32>>(),
50     Type::Create<builder::vec3<builder::u32>>(),
51 };
52 static constexpr Type kVec4NumericScalars[] = {
53     Type::Create<builder::vec4<builder::f32>>(),
54     Type::Create<builder::vec4<builder::i32>>(),
55     Type::Create<builder::vec4<builder::u32>>(),
56 };
57 static constexpr Type kInvalid[] = {
58     // A non-exhaustive selection of uncastable types
59     Type::Create<bool>(),
60     Type::Create<builder::vec2<bool>>(),
61     Type::Create<builder::vec3<bool>>(),
62     Type::Create<builder::vec4<bool>>(),
63     Type::Create<builder::array<2, builder::i32>>(),
64     Type::Create<builder::array<3, builder::u32>>(),
65     Type::Create<builder::array<4, builder::f32>>(),
66     Type::Create<builder::array<5, bool>>(),
67     Type::Create<builder::mat2x2<builder::f32>>(),
68     Type::Create<builder::mat3x3<builder::f32>>(),
69     Type::Create<builder::mat4x4<builder::f32>>(),
70     Type::Create<builder::ptr<builder::i32>>(),
71     Type::Create<builder::ptr<builder::array<2, builder::i32>>>(),
72     Type::Create<builder::ptr<builder::mat2x2<builder::f32>>>(),
73 };
74 
75 using ResolverBitcastValidationTest =
76     ResolverTestWithParam<std::tuple<Type, Type>>;
77 
78 ////////////////////////////////////////////////////////////////////////////////
79 // Valid bitcasts
80 ////////////////////////////////////////////////////////////////////////////////
81 using ResolverBitcastValidationTestPass = ResolverBitcastValidationTest;
TEST_P(ResolverBitcastValidationTestPass,Test)82 TEST_P(ResolverBitcastValidationTestPass, Test) {
83   auto src = std::get<0>(GetParam());
84   auto dst = std::get<1>(GetParam());
85 
86   auto* cast = Bitcast(dst.ast(*this), src.expr(*this, 0));
87   WrapInFunction(cast);
88 
89   ASSERT_TRUE(r()->Resolve()) << r()->error();
90   EXPECT_EQ(TypeOf(cast), dst.sem(*this));
91 }
92 INSTANTIATE_TEST_SUITE_P(Scalars,
93                          ResolverBitcastValidationTestPass,
94                          testing::Combine(testing::ValuesIn(kNumericScalars),
95                                           testing::ValuesIn(kNumericScalars)));
96 INSTANTIATE_TEST_SUITE_P(
97     Vec2,
98     ResolverBitcastValidationTestPass,
99     testing::Combine(testing::ValuesIn(kVec2NumericScalars),
100                      testing::ValuesIn(kVec2NumericScalars)));
101 INSTANTIATE_TEST_SUITE_P(
102     Vec3,
103     ResolverBitcastValidationTestPass,
104     testing::Combine(testing::ValuesIn(kVec3NumericScalars),
105                      testing::ValuesIn(kVec3NumericScalars)));
106 INSTANTIATE_TEST_SUITE_P(
107     Vec4,
108     ResolverBitcastValidationTestPass,
109     testing::Combine(testing::ValuesIn(kVec4NumericScalars),
110                      testing::ValuesIn(kVec4NumericScalars)));
111 
112 ////////////////////////////////////////////////////////////////////////////////
113 // Invalid source type for bitcasts
114 ////////////////////////////////////////////////////////////////////////////////
115 using ResolverBitcastValidationTestInvalidSrcTy = ResolverBitcastValidationTest;
TEST_P(ResolverBitcastValidationTestInvalidSrcTy,Test)116 TEST_P(ResolverBitcastValidationTestInvalidSrcTy, Test) {
117   auto src = std::get<0>(GetParam());
118   auto dst = std::get<1>(GetParam());
119 
120   auto* cast = Bitcast(dst.ast(*this), Expr(Source{{12, 34}}, "src"));
121   WrapInFunction(Const("src", nullptr, src.expr(*this, 0)), cast);
122 
123   auto expected = "12:34 error: '" + src.sem(*this)->FriendlyName(Symbols()) +
124                   "' cannot be bitcast";
125 
126   EXPECT_FALSE(r()->Resolve());
127   EXPECT_EQ(r()->error(), expected);
128 }
129 INSTANTIATE_TEST_SUITE_P(Scalars,
130                          ResolverBitcastValidationTestInvalidSrcTy,
131                          testing::Combine(testing::ValuesIn(kInvalid),
132                                           testing::ValuesIn(kNumericScalars)));
133 INSTANTIATE_TEST_SUITE_P(
134     Vec2,
135     ResolverBitcastValidationTestInvalidSrcTy,
136     testing::Combine(testing::ValuesIn(kInvalid),
137                      testing::ValuesIn(kVec2NumericScalars)));
138 INSTANTIATE_TEST_SUITE_P(
139     Vec3,
140     ResolverBitcastValidationTestInvalidSrcTy,
141     testing::Combine(testing::ValuesIn(kInvalid),
142                      testing::ValuesIn(kVec3NumericScalars)));
143 INSTANTIATE_TEST_SUITE_P(
144     Vec4,
145     ResolverBitcastValidationTestInvalidSrcTy,
146     testing::Combine(testing::ValuesIn(kInvalid),
147                      testing::ValuesIn(kVec4NumericScalars)));
148 
149 ////////////////////////////////////////////////////////////////////////////////
150 // Invalid target type for bitcasts
151 ////////////////////////////////////////////////////////////////////////////////
152 using ResolverBitcastValidationTestInvalidDstTy = ResolverBitcastValidationTest;
TEST_P(ResolverBitcastValidationTestInvalidDstTy,Test)153 TEST_P(ResolverBitcastValidationTestInvalidDstTy, Test) {
154   auto src = std::get<0>(GetParam());
155   auto dst = std::get<1>(GetParam());
156 
157   // Use an alias so we can put a Source on the bitcast type
158   Alias("T", dst.ast(*this));
159   WrapInFunction(
160       Bitcast(ty.type_name(Source{{12, 34}}, "T"), src.expr(*this, 0)));
161 
162   auto expected = "12:34 error: cannot bitcast to '" +
163                   dst.sem(*this)->FriendlyName(Symbols()) + "'";
164 
165   EXPECT_FALSE(r()->Resolve());
166   EXPECT_EQ(r()->error(), expected);
167 }
168 INSTANTIATE_TEST_SUITE_P(Scalars,
169                          ResolverBitcastValidationTestInvalidDstTy,
170                          testing::Combine(testing::ValuesIn(kNumericScalars),
171                                           testing::ValuesIn(kInvalid)));
172 INSTANTIATE_TEST_SUITE_P(
173     Vec2,
174     ResolverBitcastValidationTestInvalidDstTy,
175     testing::Combine(testing::ValuesIn(kVec2NumericScalars),
176                      testing::ValuesIn(kInvalid)));
177 INSTANTIATE_TEST_SUITE_P(
178     Vec3,
179     ResolverBitcastValidationTestInvalidDstTy,
180     testing::Combine(testing::ValuesIn(kVec3NumericScalars),
181                      testing::ValuesIn(kInvalid)));
182 INSTANTIATE_TEST_SUITE_P(
183     Vec4,
184     ResolverBitcastValidationTestInvalidDstTy,
185     testing::Combine(testing::ValuesIn(kVec4NumericScalars),
186                      testing::ValuesIn(kInvalid)));
187 
188 ////////////////////////////////////////////////////////////////////////////////
189 // Incompatible bitcast, but both src and dst types are valid
190 ////////////////////////////////////////////////////////////////////////////////
191 using ResolverBitcastValidationTestIncompatible = ResolverBitcastValidationTest;
TEST_P(ResolverBitcastValidationTestIncompatible,Test)192 TEST_P(ResolverBitcastValidationTestIncompatible, Test) {
193   auto src = std::get<0>(GetParam());
194   auto dst = std::get<1>(GetParam());
195 
196   WrapInFunction(Bitcast(Source{{12, 34}}, dst.ast(*this), src.expr(*this, 0)));
197 
198   auto expected = "12:34 error: cannot bitcast from '" +
199                   src.sem(*this)->FriendlyName(Symbols()) + "' to '" +
200                   dst.sem(*this)->FriendlyName(Symbols()) + "'";
201 
202   EXPECT_FALSE(r()->Resolve());
203   EXPECT_EQ(r()->error(), expected);
204 }
205 INSTANTIATE_TEST_SUITE_P(
206     ScalarToVec2,
207     ResolverBitcastValidationTestIncompatible,
208     testing::Combine(testing::ValuesIn(kNumericScalars),
209                      testing::ValuesIn(kVec2NumericScalars)));
210 INSTANTIATE_TEST_SUITE_P(
211     Vec2ToVec3,
212     ResolverBitcastValidationTestIncompatible,
213     testing::Combine(testing::ValuesIn(kVec2NumericScalars),
214                      testing::ValuesIn(kVec3NumericScalars)));
215 INSTANTIATE_TEST_SUITE_P(
216     Vec3ToVec4,
217     ResolverBitcastValidationTestIncompatible,
218     testing::Combine(testing::ValuesIn(kVec3NumericScalars),
219                      testing::ValuesIn(kVec4NumericScalars)));
220 INSTANTIATE_TEST_SUITE_P(
221     Vec4ToScalar,
222     ResolverBitcastValidationTestIncompatible,
223     testing::Combine(testing::ValuesIn(kVec4NumericScalars),
224                      testing::ValuesIn(kNumericScalars)));
225 
226 }  // namespace
227 }  // namespace resolver
228 }  // namespace tint
229