• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 #ifndef SRC_RESOLVER_RESOLVER_H_
16 #define SRC_RESOLVER_RESOLVER_H_
17 
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <utility>
24 #include <vector>
25 
26 #include "src/intrinsic_table.h"
27 #include "src/program_builder.h"
28 #include "src/resolver/dependency_graph.h"
29 #include "src/scope_stack.h"
30 #include "src/sem/binding_point.h"
31 #include "src/sem/block_statement.h"
32 #include "src/sem/constant.h"
33 #include "src/sem/function.h"
34 #include "src/sem/struct.h"
35 #include "src/utils/map.h"
36 #include "src/utils/unique_vector.h"
37 
38 namespace tint {
39 
40 // Forward declarations
41 namespace ast {
42 class IndexAccessorExpression;
43 class BinaryExpression;
44 class BitcastExpression;
45 class CallExpression;
46 class CallStatement;
47 class CaseStatement;
48 class ForLoopStatement;
49 class Function;
50 class IdentifierExpression;
51 class LoopStatement;
52 class MemberAccessorExpression;
53 class ReturnStatement;
54 class SwitchStatement;
55 class UnaryOpExpression;
56 class Variable;
57 }  // namespace ast
58 namespace sem {
59 class Array;
60 class Atomic;
61 class BlockStatement;
62 class CaseStatement;
63 class ElseStatement;
64 class ForLoopStatement;
65 class IfStatement;
66 class Intrinsic;
67 class LoopStatement;
68 class Statement;
69 class SwitchStatement;
70 class TypeConstructor;
71 }  // namespace sem
72 
73 namespace resolver {
74 
75 /// Resolves types for all items in the given tint program
76 class Resolver {
77  public:
78   /// Constructor
79   /// @param builder the program builder
80   explicit Resolver(ProgramBuilder* builder);
81 
82   /// Destructor
83   ~Resolver();
84 
85   /// @returns error messages from the resolver
error()86   std::string error() const { return diagnostics_.str(); }
87 
88   /// @returns true if the resolver was successful
89   bool Resolve();
90 
91   /// @param type the given type
92   /// @returns true if the given type is a plain type
93   bool IsPlain(const sem::Type* type) const;
94 
95   /// @param type the given type
96   /// @returns true if the given type is storable
97   bool IsStorable(const sem::Type* type) const;
98 
99   /// @param type the given type
100   /// @returns true if the given type is host-shareable
101   bool IsHostShareable(const sem::Type* type) const;
102 
103  private:
104   /// Describes the context in which a variable is declared
105   enum class VariableKind { kParameter, kLocal, kGlobal };
106 
107   std::set<std::pair<const sem::Struct*, ast::StorageClass>>
108       valid_struct_storage_layouts_;
109 
110   /// Structure holding semantic information about a block (i.e. scope), such as
111   /// parent block and variables declared in the block.
112   /// Used to validate variable scoping rules.
113   struct BlockInfo {
114     enum class Type { kGeneric, kLoop, kLoopContinuing, kSwitchCase };
115 
116     BlockInfo(const ast::BlockStatement* block, Type type, BlockInfo* parent);
117     ~BlockInfo();
118 
119     template <typename Pred>
FindFirstParentBlockInfo120     BlockInfo* FindFirstParent(Pred&& pred) {
121       BlockInfo* curr = this;
122       while (curr && !pred(curr)) {
123         curr = curr->parent;
124       }
125       return curr;
126     }
127 
FindFirstParentBlockInfo128     BlockInfo* FindFirstParent(BlockInfo::Type ty) {
129       return FindFirstParent(
130           [ty](auto* block_info) { return block_info->type == ty; });
131     }
132 
133     ast::BlockStatement const* const block;
134     const Type type;
135     BlockInfo* const parent;
136     std::vector<const ast::Variable*> decls;
137 
138     // first_continue is set to the index of the first variable in decls
139     // declared after the first continue statement in a loop block, if any.
140     constexpr static size_t kNoContinue = size_t(~0);
141     size_t first_continue = kNoContinue;
142   };
143 
144   // Structure holding information for a TypeDecl
145   struct TypeDeclInfo {
146     ast::TypeDecl const* const ast;
147     sem::Type* const sem;
148   };
149 
150   /// Resolves the program, without creating final the semantic nodes.
151   /// @returns true on success, false on error
152   bool ResolveInternal();
153 
154   bool ValidatePipelineStages();
155 
156   /// Creates the nodes and adds them to the sem::Info mappings of the
157   /// ProgramBuilder.
158   void CreateSemanticNodes() const;
159 
160   /// Retrieves information for the requested import.
161   /// @param src the source of the import
162   /// @param path the import path
163   /// @param name the method name to get information on
164   /// @param params the parameters to the method call
165   /// @param id out parameter for the external call ID. Must not be a nullptr.
166   /// @returns the return type of `name` in `path` or nullptr on error.
167   sem::Type* GetImportData(const Source& src,
168                            const std::string& path,
169                            const std::string& name,
170                            const ast::ExpressionList& params,
171                            uint32_t* id);
172 
173   //////////////////////////////////////////////////////////////////////////////
174   // AST and Type traversal methods
175   //////////////////////////////////////////////////////////////////////////////
176 
177   // Expression resolving methods
178   // Returns the semantic node pointer on success, nullptr on failure.
179   sem::Expression* IndexAccessor(const ast::IndexAccessorExpression*);
180   sem::Expression* Binary(const ast::BinaryExpression*);
181   sem::Expression* Bitcast(const ast::BitcastExpression*);
182   sem::Call* Call(const ast::CallExpression*);
183   sem::Expression* Expression(const ast::Expression*);
184   sem::Function* Function(const ast::Function*);
185   sem::Call* FunctionCall(const ast::CallExpression*,
186                           sem::Function* target,
187                           const std::vector<const sem::Expression*> args,
188                           sem::Behaviors arg_behaviors);
189   sem::Expression* Identifier(const ast::IdentifierExpression*);
190   sem::Call* IntrinsicCall(const ast::CallExpression*,
191                            sem::IntrinsicType,
192                            const std::vector<const sem::Expression*> args,
193                            const std::vector<const sem::Type*> arg_tys);
194   sem::Expression* Literal(const ast::LiteralExpression*);
195   sem::Expression* MemberAccessor(const ast::MemberAccessorExpression*);
196   sem::Call* TypeConversion(const ast::CallExpression* expr,
197                             const sem::Type* ty,
198                             const sem::Expression* arg,
199                             const sem::Type* arg_ty);
200   sem::Call* TypeConstructor(const ast::CallExpression* expr,
201                              const sem::Type* ty,
202                              const std::vector<const sem::Expression*> args,
203                              const std::vector<const sem::Type*> arg_tys);
204   sem::Expression* UnaryOp(const ast::UnaryOpExpression*);
205 
206   // Statement resolving methods
207   // Each return true on success, false on failure.
208   sem::Statement* AssignmentStatement(const ast::AssignmentStatement*);
209   sem::BlockStatement* BlockStatement(const ast::BlockStatement*);
210   sem::Statement* BreakStatement(const ast::BreakStatement*);
211   sem::Statement* CallStatement(const ast::CallStatement*);
212   sem::CaseStatement* CaseStatement(const ast::CaseStatement*);
213   sem::Statement* ContinueStatement(const ast::ContinueStatement*);
214   sem::Statement* DiscardStatement(const ast::DiscardStatement*);
215   sem::ElseStatement* ElseStatement(const ast::ElseStatement*);
216   sem::Statement* FallthroughStatement(const ast::FallthroughStatement*);
217   sem::ForLoopStatement* ForLoopStatement(const ast::ForLoopStatement*);
218   sem::Statement* Parameter(const ast::Variable*);
219   sem::IfStatement* IfStatement(const ast::IfStatement*);
220   sem::LoopStatement* LoopStatement(const ast::LoopStatement*);
221   sem::Statement* ReturnStatement(const ast::ReturnStatement*);
222   sem::Statement* Statement(const ast::Statement*);
223   sem::SwitchStatement* SwitchStatement(const ast::SwitchStatement* s);
224   sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*);
225   bool Statements(const ast::StatementList&);
226 
227   bool GlobalVariable(const ast::Variable*);
228 
229   // AST and Type validation methods
230   // Each return true on success, false on failure.
231   bool ValidateAlias(const ast::Alias*);
232   bool ValidateArray(const sem::Array* arr, const Source& source);
233   bool ValidateArrayStrideDecoration(const ast::StrideDecoration* deco,
234                                      uint32_t el_size,
235                                      uint32_t el_align,
236                                      const Source& source);
237   bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s);
238   bool ValidateAtomicVariable(const sem::Variable* var);
239   bool ValidateAssignment(const ast::AssignmentStatement* a);
240   bool ValidateBitcast(const ast::BitcastExpression* cast, const sem::Type* to);
241   bool ValidateBreakStatement(const sem::Statement* stmt);
242   bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
243                                  const sem::Type* storage_type,
244                                  const bool is_input);
245   bool ValidateContinueStatement(const sem::Statement* stmt);
246   bool ValidateDiscardStatement(const sem::Statement* stmt);
247   bool ValidateElseStatement(const sem::ElseStatement* stmt);
248   bool ValidateEntryPoint(const sem::Function* func);
249   bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt);
250   bool ValidateFallthroughStatement(const sem::Statement* stmt);
251   bool ValidateFunction(const sem::Function* func);
252   bool ValidateFunctionCall(const sem::Call* call);
253   bool ValidateGlobalVariable(const sem::Variable* var);
254   bool ValidateIfStatement(const sem::IfStatement* stmt);
255   bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco,
256                                      const sem::Type* storage_type);
257   bool ValidateIntrinsicCall(const sem::Call* call);
258   bool ValidateLocationDecoration(const ast::LocationDecoration* location,
259                                   const sem::Type* type,
260                                   std::unordered_set<uint32_t>& locations,
261                                   const Source& source,
262                                   const bool is_input = false);
263   bool ValidateMatrix(const sem::Matrix* ty, const Source& source);
264   bool ValidateFunctionParameter(const ast::Function* func,
265                                  const sem::Variable* var);
266   bool ValidateParameter(const ast::Function* func, const sem::Variable* var);
267   bool ValidateReturn(const ast::ReturnStatement* ret);
268   bool ValidateStatements(const ast::StatementList& stmts);
269   bool ValidateStorageTexture(const ast::StorageTexture* t);
270   bool ValidateStructure(const sem::Struct* str);
271   bool ValidateStructureConstructorOrCast(const ast::CallExpression* ctor,
272                                           const sem::Struct* struct_type);
273   bool ValidateSwitch(const ast::SwitchStatement* s);
274   bool ValidateVariable(const sem::Variable* var);
275   bool ValidateVariableConstructorOrCast(const ast::Variable* var,
276                                          ast::StorageClass storage_class,
277                                          const sem::Type* storage_type,
278                                          const sem::Type* rhs_type);
279   bool ValidateVector(const sem::Vector* ty, const Source& source);
280   bool ValidateVectorConstructorOrCast(const ast::CallExpression* ctor,
281                                        const sem::Vector* vec_type);
282   bool ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor,
283                                        const sem::Matrix* matrix_type);
284   bool ValidateScalarConstructorOrCast(const ast::CallExpression* ctor,
285                                        const sem::Type* type);
286   bool ValidateArrayConstructorOrCast(const ast::CallExpression* ctor,
287                                       const sem::Array* arr_type);
288   bool ValidateTextureIntrinsicFunction(const sem::Call* call);
289   bool ValidateNoDuplicateDecorations(const ast::DecorationList& decorations);
290   // sem::Struct is assumed to have at least one member
291   bool ValidateStorageClassLayout(const sem::Struct* type,
292                                   ast::StorageClass sc);
293   bool ValidateStorageClassLayout(const sem::Variable* var);
294 
295   /// @returns true if the decoration list contains a
296   /// ast::DisableValidationDecoration with the validation mode equal to
297   /// `validation`
298   bool IsValidationDisabled(const ast::DecorationList& decorations,
299                             ast::DisabledValidation validation) const;
300 
301   /// @returns true if the decoration list does not contains a
302   /// ast::DisableValidationDecoration with the validation mode equal to
303   /// `validation`
304   bool IsValidationEnabled(const ast::DecorationList& decorations,
305                            ast::DisabledValidation validation) const;
306 
307   /// Resolves the WorkgroupSize for the given function, assigning it to
308   /// current_function_
309   bool WorkgroupSize(const ast::Function*);
310 
311   /// @returns the sem::Type for the ast::Type `ty`, building it if it
312   /// hasn't been constructed already. If an error is raised, nullptr is
313   /// returned.
314   /// @param ty the ast::Type
315   sem::Type* Type(const ast::Type* ty);
316 
317   /// @param named_type the named type to resolve
318   /// @returns the resolved semantic type
319   sem::Type* TypeDecl(const ast::TypeDecl* named_type);
320 
321   /// Builds and returns the semantic information for the array `arr`.
322   /// This method does not mark the ast::Array node, nor attach the generated
323   /// semantic information to the AST node.
324   /// @returns the semantic Array information, or nullptr if an error is
325   /// raised.
326   /// @param arr the Array to get semantic information for
327   sem::Array* Array(const ast::Array* arr);
328 
329   /// Builds and returns the semantic information for the alias `alias`.
330   /// This method does not mark the ast::Alias node, nor attach the generated
331   /// semantic information to the AST node.
332   /// @returns the aliased type, or nullptr if an error is raised.
333   sem::Type* Alias(const ast::Alias* alias);
334 
335   /// Builds and returns the semantic information for the structure `str`.
336   /// This method does not mark the ast::Struct node, nor attach the generated
337   /// semantic information to the AST node.
338   /// @returns the semantic Struct information, or nullptr if an error is
339   /// raised.
340   sem::Struct* Structure(const ast::Struct* str);
341 
342   /// @returns the semantic info for the variable `var`. If an error is
343   /// raised, nullptr is returned.
344   /// @note this method does not resolve the decorations as these are
345   /// context-dependent (global, local, parameter)
346   /// @param var the variable to create or return the `VariableInfo` for
347   /// @param kind what kind of variable we are declaring
348   /// @param index the index of the parameter, if this variable is a parameter
349   sem::Variable* Variable(const ast::Variable* var,
350                           VariableKind kind,
351                           uint32_t index = 0);
352 
353   /// Records the storage class usage for the given type, and any transient
354   /// dependencies of the type. Validates that the type can be used for the
355   /// given storage class, erroring if it cannot.
356   /// @param sc the storage class to apply to the type and transitent types
357   /// @param ty the type to apply the storage class on
358   /// @param usage the Source of the root variable declaration that uses the
359   /// given type and storage class. Used for generating sensible error
360   /// messages.
361   /// @returns true on success, false on error
362   bool ApplyStorageClassUsageToType(ast::StorageClass sc,
363                                     sem::Type* ty,
364                                     const Source& usage);
365 
366   /// @param storage_class the storage class
367   /// @returns the default access control for the given storage class
368   ast::Access DefaultAccessForStorageClass(ast::StorageClass storage_class);
369 
370   /// Allocate constant IDs for pipeline-overridable constants.
371   void AllocateOverridableConstantIds();
372 
373   /// Set the shadowing information on variable declarations.
374   /// @note this method must only be called after all semantic nodes are built.
375   void SetShadows();
376 
377   /// @returns the resolved type of the ast::Expression `expr`
378   /// @param expr the expression
379   sem::Type* TypeOf(const ast::Expression* expr);
380 
381   /// @returns the type name of the given semantic type, unwrapping
382   /// references.
383   std::string TypeNameOf(const sem::Type* ty);
384 
385   /// @returns the type name of the given semantic type, without unwrapping
386   /// references.
387   std::string RawTypeNameOf(const sem::Type* ty);
388 
389   /// @returns the semantic type of the AST literal `lit`
390   /// @param lit the literal
391   sem::Type* TypeOf(const ast::LiteralExpression* lit);
392 
393   /// StatementScope() does the following:
394   /// * Creates the AST -> SEM mapping.
395   /// * Assigns `sem` to #current_statement_
396   /// * Assigns `sem` to #current_compound_statement_ if `sem` derives from
397   ///   sem::CompoundStatement.
398   /// * Assigns `sem` to #current_block_ if `sem` derives from
399   ///   sem::BlockStatement.
400   /// * Then calls `callback`.
401   /// * Before returning #current_statement_, #current_compound_statement_, and
402   ///   #current_block_ are restored to their original values.
403   /// @returns `sem` if `callback` returns true, otherwise `nullptr`.
404   template <typename SEM, typename F>
405   SEM* StatementScope(const ast::Statement* ast, SEM* sem, F&& callback);
406 
407   /// Returns a human-readable string representation of the vector type name
408   /// with the given parameters.
409   /// @param size the vector dimension
410   /// @param element_type scalar vector sub-element type
411   /// @return pretty string representation
412   std::string VectorPretty(uint32_t size, const sem::Type* element_type);
413 
414   /// Mark records that the given AST node has been visited, and asserts that
415   /// the given node has not already been seen. Diamonds in the AST are
416   /// illegal.
417   /// @param node the AST node.
418   /// @returns true on success, false on error
419   bool Mark(const ast::Node* node);
420 
421   /// Adds the given error message to the diagnostics
422   void AddError(const std::string& msg, const Source& source) const;
423 
424   /// Adds the given warning message to the diagnostics
425   void AddWarning(const std::string& msg, const Source& source) const;
426 
427   /// Adds the given note message to the diagnostics
428   void AddNote(const std::string& msg, const Source& source) const;
429 
430   //////////////////////////////////////////////////////////////////////////////
431   /// Constant value evaluation methods
432   //////////////////////////////////////////////////////////////////////////////
433   /// Cast `Value` to `target_type`
434   /// @return the casted value
435   sem::Constant ConstantCast(const sem::Constant& value,
436                              const sem::Type* target_elem_type);
437 
438   sem::Constant EvaluateConstantValue(const ast::Expression* expr,
439                                       const sem::Type* type);
440   sem::Constant EvaluateConstantValue(const ast::LiteralExpression* literal,
441                                       const sem::Type* type);
442   sem::Constant EvaluateConstantValue(const ast::CallExpression* call,
443                                       const sem::Type* type);
444 
445   /// Sem is a helper for obtaining the semantic node for the given AST node.
446   template <typename SEM = sem::Info::InferFromAST,
447             typename AST_OR_TYPE = CastableBase>
Sem(const AST_OR_TYPE * ast)448   auto* Sem(const AST_OR_TYPE* ast) {
449     using T = sem::Info::GetResultType<SEM, AST_OR_TYPE>;
450     auto* sem = builder_->Sem().Get(ast);
451     if (!sem) {
452       TINT_ICE(Resolver, diagnostics_)
453           << "AST node '" << ast->TypeInfo().name << "' had no semantic info\n"
454           << "At: " << ast->source << "\n"
455           << "Pointer: " << ast;
456     }
457     return const_cast<T*>(As<T>(sem));
458   }
459 
460   /// @returns true if the symbol is the name of an intrinsic (builtin)
461   /// function.
462   bool IsIntrinsic(Symbol) const;
463 
464   /// @returns true if `expr` is the current CallStatement's CallExpression
465   bool IsCallStatement(const ast::Expression* expr) const;
466 
467   /// Searches the current statement and up through parents of the current
468   /// statement looking for a loop or for-loop continuing statement.
469   /// @returns the closest continuing statement to the current statement that
470   /// (transitively) owns the current statement.
471   /// @param stop_at_loop if true then the function will return nullptr if a
472   /// loop or for-loop was found before the continuing.
473   const ast::Statement* ClosestContinuing(bool stop_at_loop) const;
474 
475   /// @returns the resolved symbol (function, type or variable) for the given
476   /// ast::Identifier or ast::TypeName cast to the given semantic type.
477   template <typename SEM = sem::Node>
ResolvedSymbol(const ast::Node * node)478   SEM* ResolvedSymbol(const ast::Node* node) {
479     auto* resolved = utils::Lookup(dependencies_.resolved_symbols, node);
480     return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(resolved))
481                     : nullptr;
482   }
483 
484   struct TypeConversionSig {
485     const sem::Type* target;
486     const sem::Type* source;
487 
488     bool operator==(const TypeConversionSig&) const;
489 
490     /// Hasher provides a hash function for the TypeConversionSig
491     struct Hasher {
492       /// @param sig the TypeConversionSig to create a hash for
493       /// @return the hash value
494       std::size_t operator()(const TypeConversionSig& sig) const;
495     };
496   };
497 
498   struct TypeConstructorSig {
499     const sem::Type* type;
500     const std::vector<const sem::Type*> parameters;
501 
502     TypeConstructorSig(const sem::Type* ty,
503                        const std::vector<const sem::Type*> params);
504     TypeConstructorSig(const TypeConstructorSig&);
505     ~TypeConstructorSig();
506     bool operator==(const TypeConstructorSig&) const;
507 
508     /// Hasher provides a hash function for the TypeConstructorSig
509     struct Hasher {
510       /// @param sig the TypeConstructorSig to create a hash for
511       /// @return the hash value
512       std::size_t operator()(const TypeConstructorSig& sig) const;
513     };
514   };
515 
516   ProgramBuilder* const builder_;
517   diag::List& diagnostics_;
518   std::unique_ptr<IntrinsicTable> const intrinsic_table_;
519   DependencyGraph dependencies_;
520   std::vector<sem::Function*> entry_points_;
521   std::unordered_map<const sem::Type*, const Source&> atomic_composite_info_;
522   std::unordered_set<const ast::Node*> marked_;
523   std::unordered_map<uint32_t, const sem::Variable*> constant_ids_;
524   std::unordered_map<TypeConversionSig,
525                      sem::CallTarget*,
526                      TypeConversionSig::Hasher>
527       type_conversions_;
528   std::unordered_map<TypeConstructorSig,
529                      sem::CallTarget*,
530                      TypeConstructorSig::Hasher>
531       type_ctors_;
532 
533   sem::Function* current_function_ = nullptr;
534   sem::Statement* current_statement_ = nullptr;
535   sem::CompoundStatement* current_compound_statement_ = nullptr;
536   sem::BlockStatement* current_block_ = nullptr;
537 };
538 
539 }  // namespace resolver
540 }  // namespace tint
541 
542 #endif  // SRC_RESOLVER_RESOLVER_H_
543