1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/ast/bitcast_expression.h"
16 #include "src/resolver/resolver.h"
17 #include "src/resolver/resolver_test_helper.h"
18
19 #include "gmock/gmock.h"
20
21 namespace tint {
22 namespace resolver {
23 namespace {
24
25 struct Type {
26 template <typename T>
Createtint::resolver::__anon4f7462e50111::Type27 static constexpr Type Create() {
28 return Type{builder::DataType<T>::AST, builder::DataType<T>::Sem,
29 builder::DataType<T>::Expr};
30 }
31
32 builder::ast_type_func_ptr ast;
33 builder::sem_type_func_ptr sem;
34 builder::ast_expr_func_ptr expr;
35 };
36
37 static constexpr Type kNumericScalars[] = {
38 Type::Create<builder::f32>(),
39 Type::Create<builder::i32>(),
40 Type::Create<builder::u32>(),
41 };
42 static constexpr Type kVec2NumericScalars[] = {
43 Type::Create<builder::vec2<builder::f32>>(),
44 Type::Create<builder::vec2<builder::i32>>(),
45 Type::Create<builder::vec2<builder::u32>>(),
46 };
47 static constexpr Type kVec3NumericScalars[] = {
48 Type::Create<builder::vec3<builder::f32>>(),
49 Type::Create<builder::vec3<builder::i32>>(),
50 Type::Create<builder::vec3<builder::u32>>(),
51 };
52 static constexpr Type kVec4NumericScalars[] = {
53 Type::Create<builder::vec4<builder::f32>>(),
54 Type::Create<builder::vec4<builder::i32>>(),
55 Type::Create<builder::vec4<builder::u32>>(),
56 };
57 static constexpr Type kInvalid[] = {
58 // A non-exhaustive selection of uncastable types
59 Type::Create<bool>(),
60 Type::Create<builder::vec2<bool>>(),
61 Type::Create<builder::vec3<bool>>(),
62 Type::Create<builder::vec4<bool>>(),
63 Type::Create<builder::array<2, builder::i32>>(),
64 Type::Create<builder::array<3, builder::u32>>(),
65 Type::Create<builder::array<4, builder::f32>>(),
66 Type::Create<builder::array<5, bool>>(),
67 Type::Create<builder::mat2x2<builder::f32>>(),
68 Type::Create<builder::mat3x3<builder::f32>>(),
69 Type::Create<builder::mat4x4<builder::f32>>(),
70 Type::Create<builder::ptr<builder::i32>>(),
71 Type::Create<builder::ptr<builder::array<2, builder::i32>>>(),
72 Type::Create<builder::ptr<builder::mat2x2<builder::f32>>>(),
73 };
74
75 using ResolverBitcastValidationTest =
76 ResolverTestWithParam<std::tuple<Type, Type>>;
77
78 ////////////////////////////////////////////////////////////////////////////////
79 // Valid bitcasts
80 ////////////////////////////////////////////////////////////////////////////////
81 using ResolverBitcastValidationTestPass = ResolverBitcastValidationTest;
TEST_P(ResolverBitcastValidationTestPass,Test)82 TEST_P(ResolverBitcastValidationTestPass, Test) {
83 auto src = std::get<0>(GetParam());
84 auto dst = std::get<1>(GetParam());
85
86 auto* cast = Bitcast(dst.ast(*this), src.expr(*this, 0));
87 WrapInFunction(cast);
88
89 ASSERT_TRUE(r()->Resolve()) << r()->error();
90 EXPECT_EQ(TypeOf(cast), dst.sem(*this));
91 }
92 INSTANTIATE_TEST_SUITE_P(Scalars,
93 ResolverBitcastValidationTestPass,
94 testing::Combine(testing::ValuesIn(kNumericScalars),
95 testing::ValuesIn(kNumericScalars)));
96 INSTANTIATE_TEST_SUITE_P(
97 Vec2,
98 ResolverBitcastValidationTestPass,
99 testing::Combine(testing::ValuesIn(kVec2NumericScalars),
100 testing::ValuesIn(kVec2NumericScalars)));
101 INSTANTIATE_TEST_SUITE_P(
102 Vec3,
103 ResolverBitcastValidationTestPass,
104 testing::Combine(testing::ValuesIn(kVec3NumericScalars),
105 testing::ValuesIn(kVec3NumericScalars)));
106 INSTANTIATE_TEST_SUITE_P(
107 Vec4,
108 ResolverBitcastValidationTestPass,
109 testing::Combine(testing::ValuesIn(kVec4NumericScalars),
110 testing::ValuesIn(kVec4NumericScalars)));
111
112 ////////////////////////////////////////////////////////////////////////////////
113 // Invalid source type for bitcasts
114 ////////////////////////////////////////////////////////////////////////////////
115 using ResolverBitcastValidationTestInvalidSrcTy = ResolverBitcastValidationTest;
TEST_P(ResolverBitcastValidationTestInvalidSrcTy,Test)116 TEST_P(ResolverBitcastValidationTestInvalidSrcTy, Test) {
117 auto src = std::get<0>(GetParam());
118 auto dst = std::get<1>(GetParam());
119
120 auto* cast = Bitcast(dst.ast(*this), Expr(Source{{12, 34}}, "src"));
121 WrapInFunction(Const("src", nullptr, src.expr(*this, 0)), cast);
122
123 auto expected = "12:34 error: '" + src.sem(*this)->FriendlyName(Symbols()) +
124 "' cannot be bitcast";
125
126 EXPECT_FALSE(r()->Resolve());
127 EXPECT_EQ(r()->error(), expected);
128 }
129 INSTANTIATE_TEST_SUITE_P(Scalars,
130 ResolverBitcastValidationTestInvalidSrcTy,
131 testing::Combine(testing::ValuesIn(kInvalid),
132 testing::ValuesIn(kNumericScalars)));
133 INSTANTIATE_TEST_SUITE_P(
134 Vec2,
135 ResolverBitcastValidationTestInvalidSrcTy,
136 testing::Combine(testing::ValuesIn(kInvalid),
137 testing::ValuesIn(kVec2NumericScalars)));
138 INSTANTIATE_TEST_SUITE_P(
139 Vec3,
140 ResolverBitcastValidationTestInvalidSrcTy,
141 testing::Combine(testing::ValuesIn(kInvalid),
142 testing::ValuesIn(kVec3NumericScalars)));
143 INSTANTIATE_TEST_SUITE_P(
144 Vec4,
145 ResolverBitcastValidationTestInvalidSrcTy,
146 testing::Combine(testing::ValuesIn(kInvalid),
147 testing::ValuesIn(kVec4NumericScalars)));
148
149 ////////////////////////////////////////////////////////////////////////////////
150 // Invalid target type for bitcasts
151 ////////////////////////////////////////////////////////////////////////////////
152 using ResolverBitcastValidationTestInvalidDstTy = ResolverBitcastValidationTest;
TEST_P(ResolverBitcastValidationTestInvalidDstTy,Test)153 TEST_P(ResolverBitcastValidationTestInvalidDstTy, Test) {
154 auto src = std::get<0>(GetParam());
155 auto dst = std::get<1>(GetParam());
156
157 // Use an alias so we can put a Source on the bitcast type
158 Alias("T", dst.ast(*this));
159 WrapInFunction(
160 Bitcast(ty.type_name(Source{{12, 34}}, "T"), src.expr(*this, 0)));
161
162 auto expected = "12:34 error: cannot bitcast to '" +
163 dst.sem(*this)->FriendlyName(Symbols()) + "'";
164
165 EXPECT_FALSE(r()->Resolve());
166 EXPECT_EQ(r()->error(), expected);
167 }
168 INSTANTIATE_TEST_SUITE_P(Scalars,
169 ResolverBitcastValidationTestInvalidDstTy,
170 testing::Combine(testing::ValuesIn(kNumericScalars),
171 testing::ValuesIn(kInvalid)));
172 INSTANTIATE_TEST_SUITE_P(
173 Vec2,
174 ResolverBitcastValidationTestInvalidDstTy,
175 testing::Combine(testing::ValuesIn(kVec2NumericScalars),
176 testing::ValuesIn(kInvalid)));
177 INSTANTIATE_TEST_SUITE_P(
178 Vec3,
179 ResolverBitcastValidationTestInvalidDstTy,
180 testing::Combine(testing::ValuesIn(kVec3NumericScalars),
181 testing::ValuesIn(kInvalid)));
182 INSTANTIATE_TEST_SUITE_P(
183 Vec4,
184 ResolverBitcastValidationTestInvalidDstTy,
185 testing::Combine(testing::ValuesIn(kVec4NumericScalars),
186 testing::ValuesIn(kInvalid)));
187
188 ////////////////////////////////////////////////////////////////////////////////
189 // Incompatible bitcast, but both src and dst types are valid
190 ////////////////////////////////////////////////////////////////////////////////
191 using ResolverBitcastValidationTestIncompatible = ResolverBitcastValidationTest;
TEST_P(ResolverBitcastValidationTestIncompatible,Test)192 TEST_P(ResolverBitcastValidationTestIncompatible, Test) {
193 auto src = std::get<0>(GetParam());
194 auto dst = std::get<1>(GetParam());
195
196 WrapInFunction(Bitcast(Source{{12, 34}}, dst.ast(*this), src.expr(*this, 0)));
197
198 auto expected = "12:34 error: cannot bitcast from '" +
199 src.sem(*this)->FriendlyName(Symbols()) + "' to '" +
200 dst.sem(*this)->FriendlyName(Symbols()) + "'";
201
202 EXPECT_FALSE(r()->Resolve());
203 EXPECT_EQ(r()->error(), expected);
204 }
205 INSTANTIATE_TEST_SUITE_P(
206 ScalarToVec2,
207 ResolverBitcastValidationTestIncompatible,
208 testing::Combine(testing::ValuesIn(kNumericScalars),
209 testing::ValuesIn(kVec2NumericScalars)));
210 INSTANTIATE_TEST_SUITE_P(
211 Vec2ToVec3,
212 ResolverBitcastValidationTestIncompatible,
213 testing::Combine(testing::ValuesIn(kVec2NumericScalars),
214 testing::ValuesIn(kVec3NumericScalars)));
215 INSTANTIATE_TEST_SUITE_P(
216 Vec3ToVec4,
217 ResolverBitcastValidationTestIncompatible,
218 testing::Combine(testing::ValuesIn(kVec3NumericScalars),
219 testing::ValuesIn(kVec4NumericScalars)));
220 INSTANTIATE_TEST_SUITE_P(
221 Vec4ToScalar,
222 ResolverBitcastValidationTestIncompatible,
223 testing::Combine(testing::ValuesIn(kVec4NumericScalars),
224 testing::ValuesIn(kNumericScalars)));
225
226 } // namespace
227 } // namespace resolver
228 } // namespace tint
229