• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef SRC_RESOLVER_RESOLVER_TEST_HELPER_H_
16 #define SRC_RESOLVER_RESOLVER_TEST_HELPER_H_
17 
18 #include <memory>
19 #include <string>
20 #include <vector>
21 
22 #include "gtest/gtest.h"
23 #include "src/program_builder.h"
24 #include "src/resolver/resolver.h"
25 #include "src/sem/expression.h"
26 #include "src/sem/statement.h"
27 #include "src/sem/variable.h"
28 
29 namespace tint {
30 namespace resolver {
31 
32 /// Helper class for testing
33 class TestHelper : public ProgramBuilder {
34  public:
35   /// Constructor
36   TestHelper();
37 
38   /// Destructor
39   ~TestHelper() override;
40 
41   /// @return a pointer to the Resolver
r()42   Resolver* r() const { return resolver_.get(); }
43 
44   /// Returns the statement that holds the given expression.
45   /// @param expr the ast::Expression
46   /// @return the ast::Statement of the ast::Expression, or nullptr if the
47   /// expression is not owned by a statement.
StmtOf(const ast::Expression * expr)48   const ast::Statement* StmtOf(const ast::Expression* expr) {
49     auto* sem_stmt = Sem().Get(expr)->Stmt();
50     return sem_stmt ? sem_stmt->Declaration() : nullptr;
51   }
52 
53   /// Returns the BlockStatement that holds the given statement.
54   /// @param stmt the ast::Statement
55   /// @return the ast::BlockStatement that holds the ast::Statement, or nullptr
56   /// if the statement is not owned by a BlockStatement.
BlockOf(const ast::Statement * stmt)57   const ast::BlockStatement* BlockOf(const ast::Statement* stmt) {
58     auto* sem_stmt = Sem().Get(stmt);
59     return sem_stmt ? sem_stmt->Block()->Declaration() : nullptr;
60   }
61 
62   /// Returns the BlockStatement that holds the given expression.
63   /// @param expr the ast::Expression
64   /// @return the ast::Statement of the ast::Expression, or nullptr if the
65   /// expression is not indirectly owned by a BlockStatement.
BlockOf(const ast::Expression * expr)66   const ast::BlockStatement* BlockOf(const ast::Expression* expr) {
67     auto* sem_stmt = Sem().Get(expr)->Stmt();
68     return sem_stmt ? sem_stmt->Block()->Declaration() : nullptr;
69   }
70 
71   /// Returns the semantic variable for the given identifier expression.
72   /// @param expr the identifier expression
73   /// @return the resolved sem::Variable of the identifier, or nullptr if
74   /// the expression did not resolve to a variable.
VarOf(const ast::Expression * expr)75   const sem::Variable* VarOf(const ast::Expression* expr) {
76     auto* sem_ident = Sem().Get(expr);
77     auto* var_user = sem_ident ? sem_ident->As<sem::VariableUser>() : nullptr;
78     return var_user ? var_user->Variable() : nullptr;
79   }
80 
81   /// Checks that all the users of the given variable are as expected
82   /// @param var the variable to check
83   /// @param expected_users the expected users of the variable
84   /// @return true if all users are as expected
CheckVarUsers(const ast::Variable * var,std::vector<const ast::Expression * > && expected_users)85   bool CheckVarUsers(const ast::Variable* var,
86                      std::vector<const ast::Expression*>&& expected_users) {
87     auto& var_users = Sem().Get(var)->Users();
88     if (var_users.size() != expected_users.size()) {
89       return false;
90     }
91     for (size_t i = 0; i < var_users.size(); i++) {
92       if (var_users[i]->Declaration() != expected_users[i]) {
93         return false;
94       }
95     }
96     return true;
97   }
98 
99   /// @param type a type
100   /// @returns the name for `type` that closely resembles how it would be
101   /// declared in WGSL.
FriendlyName(const ast::Type * type)102   std::string FriendlyName(const ast::Type* type) {
103     return type->FriendlyName(Symbols());
104   }
105 
106   /// @param type a type
107   /// @returns the name for `type` that closely resembles how it would be
108   /// declared in WGSL.
FriendlyName(const sem::Type * type)109   std::string FriendlyName(const sem::Type* type) {
110     return type->FriendlyName(Symbols());
111   }
112 
113  private:
114   std::unique_ptr<Resolver> resolver_;
115 };
116 
117 class ResolverTest : public TestHelper, public testing::Test {};
118 
119 template <typename T>
120 class ResolverTestWithParam : public TestHelper,
121                               public testing::TestWithParam<T> {};
122 
123 namespace builder {
124 
125 using i32 = ProgramBuilder::i32;
126 using u32 = ProgramBuilder::u32;
127 using f32 = ProgramBuilder::f32;
128 
// Empty tag types. They are never instantiated as values; they only select a
// DataType<> specialization.

/// Tag naming a vector of N elements of type T
template <int N, class T>
struct vec {};

/// Tag for a 2-element vector of T
template <class T>
using vec2 = vec<2, T>;

/// Tag for a 3-element vector of T
template <class T>
using vec3 = vec<3, T>;

/// Tag for a 4-element vector of T
template <class T>
using vec4 = vec<4, T>;

/// Tag naming a matrix of N columns, each an M-element vector of T
template <int N, int M, class T>
struct mat {};

/// Tag for a 2-column, 2-row matrix of T
template <class T>
using mat2x2 = mat<2, 2, T>;

/// Tag for a 2-column, 3-row matrix of T
template <class T>
using mat2x3 = mat<2, 3, T>;

/// Tag for a 3-column, 2-row matrix of T
template <class T>
using mat3x2 = mat<3, 2, T>;

/// Tag for a 3-column, 3-row matrix of T
template <class T>
using mat3x3 = mat<3, 3, T>;

/// Tag for a 4-column, 4-row matrix of T
template <class T>
using mat4x4 = mat<4, 4, T>;

/// Tag naming an array of N elements of type T
template <int N, class T>
struct array {};

/// Tag naming a type alias of TO; ID distinguishes separate aliases
template <class TO, int ID = 0>
struct alias {};

/// Tag for an alias of TO with ID 1
template <class TO>
using alias1 = alias<TO, 1>;

/// Tag for an alias of TO with ID 2
template <class TO>
using alias2 = alias<TO, 2>;

/// Tag for an alias of TO with ID 3
template <class TO>
using alias3 = alias<TO, 3>;

/// Tag naming a pointer to TO
template <class TO>
struct ptr {};
176 
177 using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b);
178 using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b,
179                                                      int elem_value);
180 using sem_type_func_ptr = const sem::Type* (*)(ProgramBuilder& b);
181 
/// Primary template for the type builders. It is intentionally empty; each
/// supported tag / scalar type provides an explicit or partial specialization.
template <class T>
struct DataType {};
184 
185 /// Helper for building bool types and expressions
186 template <>
187 struct DataType<bool> {
188   /// false as bool is not a composite type
189   static constexpr bool is_composite = false;
190 
191   /// @param b the ProgramBuilder
192   /// @return a new AST bool type
193   static inline const ast::Type* AST(ProgramBuilder& b) { return b.ty.bool_(); }
194   /// @param b the ProgramBuilder
195   /// @return the semantic bool type
196   static inline const sem::Type* Sem(ProgramBuilder& b) {
197     return b.create<sem::Bool>();
198   }
199   /// @param b the ProgramBuilder
200   /// @param elem_value the b
201   /// @return a new AST expression of the bool type
202   static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
203     return b.Expr(elem_value == 0);
204   }
205 };
206 
207 /// Helper for building i32 types and expressions
208 template <>
209 struct DataType<i32> {
210   /// false as i32 is not a composite type
211   static constexpr bool is_composite = false;
212 
213   /// @param b the ProgramBuilder
214   /// @return a new AST i32 type
215   static inline const ast::Type* AST(ProgramBuilder& b) { return b.ty.i32(); }
216   /// @param b the ProgramBuilder
217   /// @return the semantic i32 type
218   static inline const sem::Type* Sem(ProgramBuilder& b) {
219     return b.create<sem::I32>();
220   }
221   /// @param b the ProgramBuilder
222   /// @param elem_value the value i32 will be initialized with
223   /// @return a new AST i32 literal value expression
224   static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
225     return b.Expr(static_cast<i32>(elem_value));
226   }
227 };
228 
229 /// Helper for building u32 types and expressions
230 template <>
231 struct DataType<u32> {
232   /// false as u32 is not a composite type
233   static constexpr bool is_composite = false;
234 
235   /// @param b the ProgramBuilder
236   /// @return a new AST u32 type
237   static inline const ast::Type* AST(ProgramBuilder& b) { return b.ty.u32(); }
238   /// @param b the ProgramBuilder
239   /// @return the semantic u32 type
240   static inline const sem::Type* Sem(ProgramBuilder& b) {
241     return b.create<sem::U32>();
242   }
243   /// @param b the ProgramBuilder
244   /// @param elem_value the value u32 will be initialized with
245   /// @return a new AST u32 literal value expression
246   static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
247     return b.Expr(static_cast<u32>(elem_value));
248   }
249 };
250 
251 /// Helper for building f32 types and expressions
252 template <>
253 struct DataType<f32> {
254   /// false as f32 is not a composite type
255   static constexpr bool is_composite = false;
256 
257   /// @param b the ProgramBuilder
258   /// @return a new AST f32 type
259   static inline const ast::Type* AST(ProgramBuilder& b) { return b.ty.f32(); }
260   /// @param b the ProgramBuilder
261   /// @return the semantic f32 type
262   static inline const sem::Type* Sem(ProgramBuilder& b) {
263     return b.create<sem::F32>();
264   }
265   /// @param b the ProgramBuilder
266   /// @param elem_value the value f32 will be initialized with
267   /// @return a new AST f32 literal value expression
268   static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
269     return b.Expr(static_cast<f32>(elem_value));
270   }
271 };
272 
273 /// Helper for building vector types and expressions
274 template <int N, typename T>
275 struct DataType<vec<N, T>> {
276   /// true as vectors are a composite type
277   static constexpr bool is_composite = true;
278 
279   /// @param b the ProgramBuilder
280   /// @return a new AST vector type
281   static inline const ast::Type* AST(ProgramBuilder& b) {
282     return b.ty.vec(DataType<T>::AST(b), N);
283   }
284   /// @param b the ProgramBuilder
285   /// @return the semantic vector type
286   static inline const sem::Type* Sem(ProgramBuilder& b) {
287     return b.create<sem::Vector>(DataType<T>::Sem(b), N);
288   }
289   /// @param b the ProgramBuilder
290   /// @param elem_value the value each element in the vector will be initialized
291   /// with
292   /// @return a new AST vector value expression
293   static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
294     return b.Construct(AST(b), ExprArgs(b, elem_value));
295   }
296 
297   /// @param b the ProgramBuilder
298   /// @param elem_value the value each element will be initialized with
299   /// @return the list of expressions that are used to construct the vector
300   static inline ast::ExpressionList ExprArgs(ProgramBuilder& b,
301                                              int elem_value) {
302     ast::ExpressionList args;
303     for (int i = 0; i < N; i++) {
304       args.emplace_back(DataType<T>::Expr(b, elem_value));
305     }
306     return args;
307   }
308 };
309 
310 /// Helper for building matrix types and expressions
311 template <int N, int M, typename T>
312 struct DataType<mat<N, M, T>> {
313   /// true as matrices are a composite type
314   static constexpr bool is_composite = true;
315 
316   /// @param b the ProgramBuilder
317   /// @return a new AST matrix type
318   static inline const ast::Type* AST(ProgramBuilder& b) {
319     return b.ty.mat(DataType<T>::AST(b), N, M);
320   }
321   /// @param b the ProgramBuilder
322   /// @return the semantic matrix type
323   static inline const sem::Type* Sem(ProgramBuilder& b) {
324     auto* column_type = b.create<sem::Vector>(DataType<T>::Sem(b), M);
325     return b.create<sem::Matrix>(column_type, N);
326   }
327   /// @param b the ProgramBuilder
328   /// @param elem_value the value each element in the matrix will be initialized
329   /// with
330   /// @return a new AST matrix value expression
331   static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
332     return b.Construct(AST(b), ExprArgs(b, elem_value));
333   }
334 
335   /// @param b the ProgramBuilder
336   /// @param elem_value the value each element will be initialized with
337   /// @return the list of expressions that are used to construct the matrix
338   static inline ast::ExpressionList ExprArgs(ProgramBuilder& b,
339                                              int elem_value) {
340     ast::ExpressionList args;
341     for (int i = 0; i < N; i++) {
342       args.emplace_back(DataType<vec<M, T>>::Expr(b, elem_value));
343     }
344     return args;
345   }
346 };
347 
348 /// Helper for building alias types and expressions
349 template <typename T, int ID>
350 struct DataType<alias<T, ID>> {
351   /// true if the aliased type is a composite type
352   static constexpr bool is_composite = DataType<T>::is_composite;
353 
354   /// @param b the ProgramBuilder
355   /// @return a new AST alias type
356   static inline const ast::Type* AST(ProgramBuilder& b) {
357     auto name = b.Symbols().Register("alias_" + std::to_string(ID));
358     if (!b.AST().LookupType(name)) {
359       auto* type = DataType<T>::AST(b);
360       b.AST().AddTypeDecl(b.ty.alias(name, type));
361     }
362     return b.create<ast::TypeName>(name);
363   }
364   /// @param b the ProgramBuilder
365   /// @return the semantic aliased type
366   static inline const sem::Type* Sem(ProgramBuilder& b) {
367     return DataType<T>::Sem(b);
368   }
369 
370   /// @param b the ProgramBuilder
371   /// @param elem_value the value nested elements will be initialized with
372   /// @return a new AST expression of the alias type
373   template <bool IS_COMPOSITE = is_composite>
374   static inline traits::EnableIf<!IS_COMPOSITE, const ast::Expression*> Expr(
375       ProgramBuilder& b,
376       int elem_value) {
377     // Cast
378     return b.Construct(AST(b), DataType<T>::Expr(b, elem_value));
379   }
380 
381   /// @param b the ProgramBuilder
382   /// @param elem_value the value nested elements will be initialized with
383   /// @return a new AST expression of the alias type
384   template <bool IS_COMPOSITE = is_composite>
385   static inline traits::EnableIf<IS_COMPOSITE, const ast::Expression*> Expr(
386       ProgramBuilder& b,
387       int elem_value) {
388     // Construct
389     return b.Construct(AST(b), DataType<T>::ExprArgs(b, elem_value));
390   }
391 };
392 
393 /// Helper for building pointer types and expressions
394 template <typename T>
395 struct DataType<ptr<T>> {
396   /// true if the pointer type is a composite type
397   static constexpr bool is_composite = false;
398 
399   /// @param b the ProgramBuilder
400   /// @return a new AST alias type
401   static inline const ast::Type* AST(ProgramBuilder& b) {
402     return b.create<ast::Pointer>(DataType<T>::AST(b),
403                                   ast::StorageClass::kPrivate,
404                                   ast::Access::kReadWrite);
405   }
406   /// @param b the ProgramBuilder
407   /// @return the semantic aliased type
408   static inline const sem::Type* Sem(ProgramBuilder& b) {
409     return b.create<sem::Pointer>(DataType<T>::Sem(b),
410                                   ast::StorageClass::kPrivate,
411                                   ast::Access::kReadWrite);
412   }
413 
414   /// @param b the ProgramBuilder
415   /// @return a new AST expression of the alias type
416   static inline const ast::Expression* Expr(ProgramBuilder& b, int /*unused*/) {
417     auto sym = b.Symbols().New("global_for_ptr");
418     b.Global(sym, DataType<T>::AST(b), ast::StorageClass::kPrivate);
419     return b.AddressOf(sym);
420   }
421 };
422 
423 /// Helper for building array types and expressions
424 template <int N, typename T>
425 struct DataType<array<N, T>> {
426   /// true as arrays are a composite type
427   static constexpr bool is_composite = true;
428 
429   /// @param b the ProgramBuilder
430   /// @return a new AST array type
431   static inline const ast::Type* AST(ProgramBuilder& b) {
432     return b.ty.array(DataType<T>::AST(b), N);
433   }
434   /// @param b the ProgramBuilder
435   /// @return the semantic array type
436   static inline const sem::Type* Sem(ProgramBuilder& b) {
437     auto* el = DataType<T>::Sem(b);
438     return b.create<sem::Array>(
439         /* element */ el,
440         /* count */ N,
441         /* align */ el->Align(),
442         /* size */ el->Size(),
443         /* stride */ el->Align(),
444         /* implicit_stride */ el->Align());
445   }
446   /// @param b the ProgramBuilder
447   /// @param elem_value the value each element in the array will be initialized
448   /// with
449   /// @return a new AST array value expression
450   static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
451     return b.Construct(AST(b), ExprArgs(b, elem_value));
452   }
453 
454   /// @param b the ProgramBuilder
455   /// @param elem_value the value each element will be initialized with
456   /// @return the list of expressions that are used to construct the array
457   static inline ast::ExpressionList ExprArgs(ProgramBuilder& b,
458                                              int elem_value) {
459     ast::ExpressionList args;
460     for (int i = 0; i < N; i++) {
461       args.emplace_back(DataType<T>::Expr(b, elem_value));
462     }
463     return args;
464   }
465 };
466 
467 /// Struct of all creation pointer types
468 struct CreatePtrs {
469   /// ast node type create function
470   ast_type_func_ptr ast;
471   /// ast expression type create function
472   ast_expr_func_ptr expr;
473   /// sem type create function
474   sem_type_func_ptr sem;
475 };
476 
477 /// Returns a CreatePtrs struct instance with all creation pointer types for
478 /// type `T`
479 template <typename T>
480 constexpr CreatePtrs CreatePtrsFor() {
481   return {DataType<T>::AST, DataType<T>::Expr, DataType<T>::Sem};
482 }
483 
484 }  // namespace builder
485 
486 }  // namespace resolver
487 }  // namespace tint
488 
489 #endif  // SRC_RESOLVER_RESOLVER_TEST_HELPER_H_
490