1 // Copyright 2021 The Tint Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef SRC_RESOLVER_RESOLVER_TEST_HELPER_H_ 16 #define SRC_RESOLVER_RESOLVER_TEST_HELPER_H_ 17 18 #include <memory> 19 #include <string> 20 #include <vector> 21 22 #include "gtest/gtest.h" 23 #include "src/program_builder.h" 24 #include "src/resolver/resolver.h" 25 #include "src/sem/expression.h" 26 #include "src/sem/statement.h" 27 #include "src/sem/variable.h" 28 29 namespace tint { 30 namespace resolver { 31 32 /// Helper class for testing 33 class TestHelper : public ProgramBuilder { 34 public: 35 /// Constructor 36 TestHelper(); 37 38 /// Destructor 39 ~TestHelper() override; 40 41 /// @return a pointer to the Resolver r()42 Resolver* r() const { return resolver_.get(); } 43 44 /// Returns the statement that holds the given expression. 45 /// @param expr the ast::Expression 46 /// @return the ast::Statement of the ast::Expression, or nullptr if the 47 /// expression is not owned by a statement. StmtOf(const ast::Expression * expr)48 const ast::Statement* StmtOf(const ast::Expression* expr) { 49 auto* sem_stmt = Sem().Get(expr)->Stmt(); 50 return sem_stmt ? sem_stmt->Declaration() : nullptr; 51 } 52 53 /// Returns the BlockStatement that holds the given statement. 54 /// @param stmt the ast::Statement 55 /// @return the ast::BlockStatement that holds the ast::Statement, or nullptr 56 /// if the statement is not owned by a BlockStatement. BlockOf(const ast::Statement * stmt)57 const ast::BlockStatement* BlockOf(const ast::Statement* stmt) { 58 auto* sem_stmt = Sem().Get(stmt); 59 return sem_stmt ? sem_stmt->Block()->Declaration() : nullptr; 60 } 61 62 /// Returns the BlockStatement that holds the given expression. 63 /// @param expr the ast::Expression 64 /// @return the ast::Statement of the ast::Expression, or nullptr if the 65 /// expression is not indirectly owned by a BlockStatement. BlockOf(const ast::Expression * expr)66 const ast::BlockStatement* BlockOf(const ast::Expression* expr) { 67 auto* sem_stmt = Sem().Get(expr)->Stmt(); 68 return sem_stmt ? sem_stmt->Block()->Declaration() : nullptr; 69 } 70 71 /// Returns the semantic variable for the given identifier expression. 72 /// @param expr the identifier expression 73 /// @return the resolved sem::Variable of the identifier, or nullptr if 74 /// the expression did not resolve to a variable. VarOf(const ast::Expression * expr)75 const sem::Variable* VarOf(const ast::Expression* expr) { 76 auto* sem_ident = Sem().Get(expr); 77 auto* var_user = sem_ident ? sem_ident->As<sem::VariableUser>() : nullptr; 78 return var_user ? var_user->Variable() : nullptr; 79 } 80 81 /// Checks that all the users of the given variable are as expected 82 /// @param var the variable to check 83 /// @param expected_users the expected users of the variable 84 /// @return true if all users are as expected CheckVarUsers(const ast::Variable * var,std::vector<const ast::Expression * > && expected_users)85 bool CheckVarUsers(const ast::Variable* var, 86 std::vector<const ast::Expression*>&& expected_users) { 87 auto& var_users = Sem().Get(var)->Users(); 88 if (var_users.size() != expected_users.size()) { 89 return false; 90 } 91 for (size_t i = 0; i < var_users.size(); i++) { 92 if (var_users[i]->Declaration() != expected_users[i]) { 93 return false; 94 } 95 } 96 return true; 97 } 98 99 /// @param type a type 100 /// @returns the name for `type` that closely resembles how it would be 101 /// declared in WGSL. FriendlyName(const ast::Type * type)102 std::string FriendlyName(const ast::Type* type) { 103 return type->FriendlyName(Symbols()); 104 } 105 106 /// @param type a type 107 /// @returns the name for `type` that closely resembles how it would be 108 /// declared in WGSL. FriendlyName(const sem::Type * type)109 std::string FriendlyName(const sem::Type* type) { 110 return type->FriendlyName(Symbols()); 111 } 112 113 private: 114 std::unique_ptr<Resolver> resolver_; 115 }; 116 117 class ResolverTest : public TestHelper, public testing::Test {}; 118 119 template <typename T> 120 class ResolverTestWithParam : public TestHelper, 121 public testing::TestWithParam<T> {}; 122 123 namespace builder { 124 125 using i32 = ProgramBuilder::i32; 126 using u32 = ProgramBuilder::u32; 127 using f32 = ProgramBuilder::f32; 128 129 template <int N, typename T> 130 struct vec {}; 131 132 template <typename T> 133 using vec2 = vec<2, T>; 134 135 template <typename T> 136 using vec3 = vec<3, T>; 137 138 template <typename T> 139 using vec4 = vec<4, T>; 140 141 template <int N, int M, typename T> 142 struct mat {}; 143 144 template <typename T> 145 using mat2x2 = mat<2, 2, T>; 146 147 template <typename T> 148 using mat2x3 = mat<2, 3, T>; 149 150 template <typename T> 151 using mat3x2 = mat<3, 2, T>; 152 153 template <typename T> 154 using mat3x3 = mat<3, 3, T>; 155 156 template <typename T> 157 using mat4x4 = mat<4, 4, T>; 158 159 template <int N, typename T> 160 struct array {}; 161 162 template <typename TO, int ID = 0> 163 struct alias {}; 164 165 template <typename TO> 166 using alias1 = alias<TO, 1>; 167 168 template <typename TO> 169 using alias2 = alias<TO, 2>; 170 171 template <typename TO> 172 using alias3 = alias<TO, 3>; 173 174 template <typename TO> 175 struct ptr {}; 176 177 using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b); 178 using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, 179 int elem_value); 180 using sem_type_func_ptr = const sem::Type* (*)(ProgramBuilder& b); 181 182 template <typename T> 183 struct DataType {}; 184 185 /// Helper for building bool types and expressions 186 template <> 187 struct DataType<bool> { 188 /// false as bool is not a composite type 189 static constexpr bool is_composite = false; 190 191 /// @param b the ProgramBuilder 192 /// @return a new AST bool type 193 static inline const ast::Type* AST(ProgramBuilder& b) { return b.ty.bool_(); } 194 /// @param b the ProgramBuilder 195 /// @return the semantic bool type 196 static inline const sem::Type* Sem(ProgramBuilder& b) { 197 return b.create<sem::Bool>(); 198 } 199 /// @param b the ProgramBuilder 200 /// @param elem_value the b 201 /// @return a new AST expression of the bool type 202 static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) { 203 return b.Expr(elem_value == 0); 204 } 205 }; 206 207 /// Helper for building i32 types and expressions 208 template <> 209 struct DataType<i32> { 210 /// false as i32 is not a composite type 211 static constexpr bool is_composite = false; 212 213 /// @param b the ProgramBuilder 214 /// @return a new AST i32 type 215 static inline const ast::Type* AST(ProgramBuilder& b) { return b.ty.i32(); } 216 /// @param b the ProgramBuilder 217 /// @return the semantic i32 type 218 static inline const sem::Type* Sem(ProgramBuilder& b) { 219 return b.create<sem::I32>(); 220 } 221 /// @param b the ProgramBuilder 222 /// @param elem_value the value i32 will be initialized with 223 /// @return a new AST i32 literal value expression 224 static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) { 225 return b.Expr(static_cast<i32>(elem_value)); 226 } 227 }; 228 229 /// Helper for building u32 types and expressions 230 template <> 231 struct DataType<u32> { 232 /// false as u32 is not a composite type 233 static constexpr bool is_composite = false; 234 235 /// @param b the ProgramBuilder 236 /// @return a new AST u32 type 237 static inline const ast::Type* AST(ProgramBuilder& b) { return b.ty.u32(); } 238 /// @param b the ProgramBuilder 239 /// @return the semantic u32 type 240 static inline const sem::Type* Sem(ProgramBuilder& b) { 241 return b.create<sem::U32>(); 242 } 243 /// @param b the ProgramBuilder 244 /// @param elem_value the value u32 will be initialized with 245 /// @return a new AST u32 literal value expression 246 static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) { 247 return b.Expr(static_cast<u32>(elem_value)); 248 } 249 }; 250 251 /// Helper for building f32 types and expressions 252 template <> 253 struct DataType<f32> { 254 /// false as f32 is not a composite type 255 static constexpr bool is_composite = false; 256 257 /// @param b the ProgramBuilder 258 /// @return a new AST f32 type 259 static inline const ast::Type* AST(ProgramBuilder& b) { return b.ty.f32(); } 260 /// @param b the ProgramBuilder 261 /// @return the semantic f32 type 262 static inline const sem::Type* Sem(ProgramBuilder& b) { 263 return b.create<sem::F32>(); 264 } 265 /// @param b the ProgramBuilder 266 /// @param elem_value the value f32 will be initialized with 267 /// @return a new AST f32 literal value expression 268 static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) { 269 return b.Expr(static_cast<f32>(elem_value)); 270 } 271 }; 272 273 /// Helper for building vector types and expressions 274 template <int N, typename T> 275 struct DataType<vec<N, T>> { 276 /// true as vectors are a composite type 277 static constexpr bool is_composite = true; 278 279 /// @param b the ProgramBuilder 280 /// @return a new AST vector type 281 static inline const ast::Type* AST(ProgramBuilder& b) { 282 return b.ty.vec(DataType<T>::AST(b), N); 283 } 284 /// @param b the ProgramBuilder 285 /// @return the semantic vector type 286 static inline const sem::Type* Sem(ProgramBuilder& b) { 287 return b.create<sem::Vector>(DataType<T>::Sem(b), N); 288 } 289 /// @param b the ProgramBuilder 290 /// @param elem_value the value each element in the vector will be initialized 291 /// with 292 /// @return a new AST vector value expression 293 static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) { 294 return b.Construct(AST(b), ExprArgs(b, elem_value)); 295 } 296 297 /// @param b the ProgramBuilder 298 /// @param elem_value the value each element will be initialized with 299 /// @return the list of expressions that are used to construct the vector 300 static inline ast::ExpressionList ExprArgs(ProgramBuilder& b, 301 int elem_value) { 302 ast::ExpressionList args; 303 for (int i = 0; i < N; i++) { 304 args.emplace_back(DataType<T>::Expr(b, elem_value)); 305 } 306 return args; 307 } 308 }; 309 310 /// Helper for building matrix types and expressions 311 template <int N, int M, typename T> 312 struct DataType<mat<N, M, T>> { 313 /// true as matrices are a composite type 314 static constexpr bool is_composite = true; 315 316 /// @param b the ProgramBuilder 317 /// @return a new AST matrix type 318 static inline const ast::Type* AST(ProgramBuilder& b) { 319 return b.ty.mat(DataType<T>::AST(b), N, M); 320 } 321 /// @param b the ProgramBuilder 322 /// @return the semantic matrix type 323 static inline const sem::Type* Sem(ProgramBuilder& b) { 324 auto* column_type = b.create<sem::Vector>(DataType<T>::Sem(b), M); 325 return b.create<sem::Matrix>(column_type, N); 326 } 327 /// @param b the ProgramBuilder 328 /// @param elem_value the value each element in the matrix will be initialized 329 /// with 330 /// @return a new AST matrix value expression 331 static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) { 332 return b.Construct(AST(b), ExprArgs(b, elem_value)); 333 } 334 335 /// @param b the ProgramBuilder 336 /// @param elem_value the value each element will be initialized with 337 /// @return the list of expressions that are used to construct the matrix 338 static inline ast::ExpressionList ExprArgs(ProgramBuilder& b, 339 int elem_value) { 340 ast::ExpressionList args; 341 for (int i = 0; i < N; i++) { 342 args.emplace_back(DataType<vec<M, T>>::Expr(b, elem_value)); 343 } 344 return args; 345 } 346 }; 347 348 /// Helper for building alias types and expressions 349 template <typename T, int ID> 350 struct DataType<alias<T, ID>> { 351 /// true if the aliased type is a composite type 352 static constexpr bool is_composite = DataType<T>::is_composite; 353 354 /// @param b the ProgramBuilder 355 /// @return a new AST alias type 356 static inline const ast::Type* AST(ProgramBuilder& b) { 357 auto name = b.Symbols().Register("alias_" + std::to_string(ID)); 358 if (!b.AST().LookupType(name)) { 359 auto* type = DataType<T>::AST(b); 360 b.AST().AddTypeDecl(b.ty.alias(name, type)); 361 } 362 return b.create<ast::TypeName>(name); 363 } 364 /// @param b the ProgramBuilder 365 /// @return the semantic aliased type 366 static inline const sem::Type* Sem(ProgramBuilder& b) { 367 return DataType<T>::Sem(b); 368 } 369 370 /// @param b the ProgramBuilder 371 /// @param elem_value the value nested elements will be initialized with 372 /// @return a new AST expression of the alias type 373 template <bool IS_COMPOSITE = is_composite> 374 static inline traits::EnableIf<!IS_COMPOSITE, const ast::Expression*> Expr( 375 ProgramBuilder& b, 376 int elem_value) { 377 // Cast 378 return b.Construct(AST(b), DataType<T>::Expr(b, elem_value)); 379 } 380 381 /// @param b the ProgramBuilder 382 /// @param elem_value the value nested elements will be initialized with 383 /// @return a new AST expression of the alias type 384 template <bool IS_COMPOSITE = is_composite> 385 static inline traits::EnableIf<IS_COMPOSITE, const ast::Expression*> Expr( 386 ProgramBuilder& b, 387 int elem_value) { 388 // Construct 389 return b.Construct(AST(b), DataType<T>::ExprArgs(b, elem_value)); 390 } 391 }; 392 393 /// Helper for building pointer types and expressions 394 template <typename T> 395 struct DataType<ptr<T>> { 396 /// true if the pointer type is a composite type 397 static constexpr bool is_composite = false; 398 399 /// @param b the ProgramBuilder 400 /// @return a new AST alias type 401 static inline const ast::Type* AST(ProgramBuilder& b) { 402 return b.create<ast::Pointer>(DataType<T>::AST(b), 403 ast::StorageClass::kPrivate, 404 ast::Access::kReadWrite); 405 } 406 /// @param b the ProgramBuilder 407 /// @return the semantic aliased type 408 static inline const sem::Type* Sem(ProgramBuilder& b) { 409 return b.create<sem::Pointer>(DataType<T>::Sem(b), 410 ast::StorageClass::kPrivate, 411 ast::Access::kReadWrite); 412 } 413 414 /// @param b the ProgramBuilder 415 /// @return a new AST expression of the alias type 416 static inline const ast::Expression* Expr(ProgramBuilder& b, int /*unused*/) { 417 auto sym = b.Symbols().New("global_for_ptr"); 418 b.Global(sym, DataType<T>::AST(b), ast::StorageClass::kPrivate); 419 return b.AddressOf(sym); 420 } 421 }; 422 423 /// Helper for building array types and expressions 424 template <int N, typename T> 425 struct DataType<array<N, T>> { 426 /// true as arrays are a composite type 427 static constexpr bool is_composite = true; 428 429 /// @param b the ProgramBuilder 430 /// @return a new AST array type 431 static inline const ast::Type* AST(ProgramBuilder& b) { 432 return b.ty.array(DataType<T>::AST(b), N); 433 } 434 /// @param b the ProgramBuilder 435 /// @return the semantic array type 436 static inline const sem::Type* Sem(ProgramBuilder& b) { 437 auto* el = DataType<T>::Sem(b); 438 return b.create<sem::Array>( 439 /* element */ el, 440 /* count */ N, 441 /* align */ el->Align(), 442 /* size */ el->Size(), 443 /* stride */ el->Align(), 444 /* implicit_stride */ el->Align()); 445 } 446 /// @param b the ProgramBuilder 447 /// @param elem_value the value each element in the array will be initialized 448 /// with 449 /// @return a new AST array value expression 450 static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) { 451 return b.Construct(AST(b), ExprArgs(b, elem_value)); 452 } 453 454 /// @param b the ProgramBuilder 455 /// @param elem_value the value each element will be initialized with 456 /// @return the list of expressions that are used to construct the array 457 static inline ast::ExpressionList ExprArgs(ProgramBuilder& b, 458 int elem_value) { 459 ast::ExpressionList args; 460 for (int i = 0; i < N; i++) { 461 args.emplace_back(DataType<T>::Expr(b, elem_value)); 462 } 463 return args; 464 } 465 }; 466 467 /// Struct of all creation pointer types 468 struct CreatePtrs { 469 /// ast node type create function 470 ast_type_func_ptr ast; 471 /// ast expression type create function 472 ast_expr_func_ptr expr; 473 /// sem type create function 474 sem_type_func_ptr sem; 475 }; 476 477 /// Returns a CreatePtrs struct instance with all creation pointer types for 478 /// type `T` 479 template <typename T> 480 constexpr CreatePtrs CreatePtrsFor() { 481 return {DataType<T>::AST, DataType<T>::Expr, DataType<T>::Sem}; 482 } 483 484 } // namespace builder 485 486 } // namespace resolver 487 } // namespace tint 488 489 #endif // SRC_RESOLVER_RESOLVER_TEST_HELPER_H_ 490