1 // Copyright 2020 The Tint Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef SRC_RESOLVER_RESOLVER_H_ 16 #define SRC_RESOLVER_RESOLVER_H_ 17 18 #include <memory> 19 #include <set> 20 #include <string> 21 #include <unordered_map> 22 #include <unordered_set> 23 #include <utility> 24 #include <vector> 25 26 #include "src/intrinsic_table.h" 27 #include "src/program_builder.h" 28 #include "src/resolver/dependency_graph.h" 29 #include "src/scope_stack.h" 30 #include "src/sem/binding_point.h" 31 #include "src/sem/block_statement.h" 32 #include "src/sem/constant.h" 33 #include "src/sem/function.h" 34 #include "src/sem/struct.h" 35 #include "src/utils/map.h" 36 #include "src/utils/unique_vector.h" 37 38 namespace tint { 39 40 // Forward declarations 41 namespace ast { 42 class IndexAccessorExpression; 43 class BinaryExpression; 44 class BitcastExpression; 45 class CallExpression; 46 class CallStatement; 47 class CaseStatement; 48 class ForLoopStatement; 49 class Function; 50 class IdentifierExpression; 51 class LoopStatement; 52 class MemberAccessorExpression; 53 class ReturnStatement; 54 class SwitchStatement; 55 class UnaryOpExpression; 56 class Variable; 57 } // namespace ast 58 namespace sem { 59 class Array; 60 class Atomic; 61 class BlockStatement; 62 class CaseStatement; 63 class ElseStatement; 64 class ForLoopStatement; 65 class IfStatement; 66 class Intrinsic; 67 class LoopStatement; 68 class Statement; 69 class SwitchStatement; 70 class TypeConstructor; 71 } // namespace sem 72 73 namespace resolver { 74 75 /// Resolves types for all items in the given tint program 76 class Resolver { 77 public: 78 /// Constructor 79 /// @param builder the program builder 80 explicit Resolver(ProgramBuilder* builder); 81 82 /// Destructor 83 ~Resolver(); 84 85 /// @returns error messages from the resolver error()86 std::string error() const { return diagnostics_.str(); } 87 88 /// @returns true if the resolver was successful 89 bool Resolve(); 90 91 /// @param type the given type 92 /// @returns true if the given type is a plain type 93 bool IsPlain(const sem::Type* type) const; 94 95 /// @param type the given type 96 /// @returns true if the given type is storable 97 bool IsStorable(const sem::Type* type) const; 98 99 /// @param type the given type 100 /// @returns true if the given type is host-shareable 101 bool IsHostShareable(const sem::Type* type) const; 102 103 private: 104 /// Describes the context in which a variable is declared 105 enum class VariableKind { kParameter, kLocal, kGlobal }; 106 107 std::set<std::pair<const sem::Struct*, ast::StorageClass>> 108 valid_struct_storage_layouts_; 109 110 /// Structure holding semantic information about a block (i.e. scope), such as 111 /// parent block and variables declared in the block. 112 /// Used to validate variable scoping rules. 113 struct BlockInfo { 114 enum class Type { kGeneric, kLoop, kLoopContinuing, kSwitchCase }; 115 116 BlockInfo(const ast::BlockStatement* block, Type type, BlockInfo* parent); 117 ~BlockInfo(); 118 119 template <typename Pred> FindFirstParentBlockInfo120 BlockInfo* FindFirstParent(Pred&& pred) { 121 BlockInfo* curr = this; 122 while (curr && !pred(curr)) { 123 curr = curr->parent; 124 } 125 return curr; 126 } 127 FindFirstParentBlockInfo128 BlockInfo* FindFirstParent(BlockInfo::Type ty) { 129 return FindFirstParent( 130 [ty](auto* block_info) { return block_info->type == ty; }); 131 } 132 133 ast::BlockStatement const* const block; 134 const Type type; 135 BlockInfo* const parent; 136 std::vector<const ast::Variable*> decls; 137 138 // first_continue is set to the index of the first variable in decls 139 // declared after the first continue statement in a loop block, if any. 140 constexpr static size_t kNoContinue = size_t(~0); 141 size_t first_continue = kNoContinue; 142 }; 143 144 // Structure holding information for a TypeDecl 145 struct TypeDeclInfo { 146 ast::TypeDecl const* const ast; 147 sem::Type* const sem; 148 }; 149 150 /// Resolves the program, without creating final the semantic nodes. 151 /// @returns true on success, false on error 152 bool ResolveInternal(); 153 154 bool ValidatePipelineStages(); 155 156 /// Creates the nodes and adds them to the sem::Info mappings of the 157 /// ProgramBuilder. 158 void CreateSemanticNodes() const; 159 160 /// Retrieves information for the requested import. 161 /// @param src the source of the import 162 /// @param path the import path 163 /// @param name the method name to get information on 164 /// @param params the parameters to the method call 165 /// @param id out parameter for the external call ID. Must not be a nullptr. 166 /// @returns the return type of `name` in `path` or nullptr on error. 167 sem::Type* GetImportData(const Source& src, 168 const std::string& path, 169 const std::string& name, 170 const ast::ExpressionList& params, 171 uint32_t* id); 172 173 ////////////////////////////////////////////////////////////////////////////// 174 // AST and Type traversal methods 175 ////////////////////////////////////////////////////////////////////////////// 176 177 // Expression resolving methods 178 // Returns the semantic node pointer on success, nullptr on failure. 179 sem::Expression* IndexAccessor(const ast::IndexAccessorExpression*); 180 sem::Expression* Binary(const ast::BinaryExpression*); 181 sem::Expression* Bitcast(const ast::BitcastExpression*); 182 sem::Call* Call(const ast::CallExpression*); 183 sem::Expression* Expression(const ast::Expression*); 184 sem::Function* Function(const ast::Function*); 185 sem::Call* FunctionCall(const ast::CallExpression*, 186 sem::Function* target, 187 const std::vector<const sem::Expression*> args, 188 sem::Behaviors arg_behaviors); 189 sem::Expression* Identifier(const ast::IdentifierExpression*); 190 sem::Call* IntrinsicCall(const ast::CallExpression*, 191 sem::IntrinsicType, 192 const std::vector<const sem::Expression*> args, 193 const std::vector<const sem::Type*> arg_tys); 194 sem::Expression* Literal(const ast::LiteralExpression*); 195 sem::Expression* MemberAccessor(const ast::MemberAccessorExpression*); 196 sem::Call* TypeConversion(const ast::CallExpression* expr, 197 const sem::Type* ty, 198 const sem::Expression* arg, 199 const sem::Type* arg_ty); 200 sem::Call* TypeConstructor(const ast::CallExpression* expr, 201 const sem::Type* ty, 202 const std::vector<const sem::Expression*> args, 203 const std::vector<const sem::Type*> arg_tys); 204 sem::Expression* UnaryOp(const ast::UnaryOpExpression*); 205 206 // Statement resolving methods 207 // Each return true on success, false on failure. 208 sem::Statement* AssignmentStatement(const ast::AssignmentStatement*); 209 sem::BlockStatement* BlockStatement(const ast::BlockStatement*); 210 sem::Statement* BreakStatement(const ast::BreakStatement*); 211 sem::Statement* CallStatement(const ast::CallStatement*); 212 sem::CaseStatement* CaseStatement(const ast::CaseStatement*); 213 sem::Statement* ContinueStatement(const ast::ContinueStatement*); 214 sem::Statement* DiscardStatement(const ast::DiscardStatement*); 215 sem::ElseStatement* ElseStatement(const ast::ElseStatement*); 216 sem::Statement* FallthroughStatement(const ast::FallthroughStatement*); 217 sem::ForLoopStatement* ForLoopStatement(const ast::ForLoopStatement*); 218 sem::Statement* Parameter(const ast::Variable*); 219 sem::IfStatement* IfStatement(const ast::IfStatement*); 220 sem::LoopStatement* LoopStatement(const ast::LoopStatement*); 221 sem::Statement* ReturnStatement(const ast::ReturnStatement*); 222 sem::Statement* Statement(const ast::Statement*); 223 sem::SwitchStatement* SwitchStatement(const ast::SwitchStatement* s); 224 sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*); 225 bool Statements(const ast::StatementList&); 226 227 bool GlobalVariable(const ast::Variable*); 228 229 // AST and Type validation methods 230 // Each return true on success, false on failure. 231 bool ValidateAlias(const ast::Alias*); 232 bool ValidateArray(const sem::Array* arr, const Source& source); 233 bool ValidateArrayStrideDecoration(const ast::StrideDecoration* deco, 234 uint32_t el_size, 235 uint32_t el_align, 236 const Source& source); 237 bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s); 238 bool ValidateAtomicVariable(const sem::Variable* var); 239 bool ValidateAssignment(const ast::AssignmentStatement* a); 240 bool ValidateBitcast(const ast::BitcastExpression* cast, const sem::Type* to); 241 bool ValidateBreakStatement(const sem::Statement* stmt); 242 bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco, 243 const sem::Type* storage_type, 244 const bool is_input); 245 bool ValidateContinueStatement(const sem::Statement* stmt); 246 bool ValidateDiscardStatement(const sem::Statement* stmt); 247 bool ValidateElseStatement(const sem::ElseStatement* stmt); 248 bool ValidateEntryPoint(const sem::Function* func); 249 bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt); 250 bool ValidateFallthroughStatement(const sem::Statement* stmt); 251 bool ValidateFunction(const sem::Function* func); 252 bool ValidateFunctionCall(const sem::Call* call); 253 bool ValidateGlobalVariable(const sem::Variable* var); 254 bool ValidateIfStatement(const sem::IfStatement* stmt); 255 bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco, 256 const sem::Type* storage_type); 257 bool ValidateIntrinsicCall(const sem::Call* call); 258 bool ValidateLocationDecoration(const ast::LocationDecoration* location, 259 const sem::Type* type, 260 std::unordered_set<uint32_t>& locations, 261 const Source& source, 262 const bool is_input = false); 263 bool ValidateMatrix(const sem::Matrix* ty, const Source& source); 264 bool ValidateFunctionParameter(const ast::Function* func, 265 const sem::Variable* var); 266 bool ValidateParameter(const ast::Function* func, const sem::Variable* var); 267 bool ValidateReturn(const ast::ReturnStatement* ret); 268 bool ValidateStatements(const ast::StatementList& stmts); 269 bool ValidateStorageTexture(const ast::StorageTexture* t); 270 bool ValidateStructure(const sem::Struct* str); 271 bool ValidateStructureConstructorOrCast(const ast::CallExpression* ctor, 272 const sem::Struct* struct_type); 273 bool ValidateSwitch(const ast::SwitchStatement* s); 274 bool ValidateVariable(const sem::Variable* var); 275 bool ValidateVariableConstructorOrCast(const ast::Variable* var, 276 ast::StorageClass storage_class, 277 const sem::Type* storage_type, 278 const sem::Type* rhs_type); 279 bool ValidateVector(const sem::Vector* ty, const Source& source); 280 bool ValidateVectorConstructorOrCast(const ast::CallExpression* ctor, 281 const sem::Vector* vec_type); 282 bool ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor, 283 const sem::Matrix* matrix_type); 284 bool ValidateScalarConstructorOrCast(const ast::CallExpression* ctor, 285 const sem::Type* type); 286 bool ValidateArrayConstructorOrCast(const ast::CallExpression* ctor, 287 const sem::Array* arr_type); 288 bool ValidateTextureIntrinsicFunction(const sem::Call* call); 289 bool ValidateNoDuplicateDecorations(const ast::DecorationList& decorations); 290 // sem::Struct is assumed to have at least one member 291 bool ValidateStorageClassLayout(const sem::Struct* type, 292 ast::StorageClass sc); 293 bool ValidateStorageClassLayout(const sem::Variable* var); 294 295 /// @returns true if the decoration list contains a 296 /// ast::DisableValidationDecoration with the validation mode equal to 297 /// `validation` 298 bool IsValidationDisabled(const ast::DecorationList& decorations, 299 ast::DisabledValidation validation) const; 300 301 /// @returns true if the decoration list does not contains a 302 /// ast::DisableValidationDecoration with the validation mode equal to 303 /// `validation` 304 bool IsValidationEnabled(const ast::DecorationList& decorations, 305 ast::DisabledValidation validation) const; 306 307 /// Resolves the WorkgroupSize for the given function, assigning it to 308 /// current_function_ 309 bool WorkgroupSize(const ast::Function*); 310 311 /// @returns the sem::Type for the ast::Type `ty`, building it if it 312 /// hasn't been constructed already. If an error is raised, nullptr is 313 /// returned. 314 /// @param ty the ast::Type 315 sem::Type* Type(const ast::Type* ty); 316 317 /// @param named_type the named type to resolve 318 /// @returns the resolved semantic type 319 sem::Type* TypeDecl(const ast::TypeDecl* named_type); 320 321 /// Builds and returns the semantic information for the array `arr`. 322 /// This method does not mark the ast::Array node, nor attach the generated 323 /// semantic information to the AST node. 324 /// @returns the semantic Array information, or nullptr if an error is 325 /// raised. 326 /// @param arr the Array to get semantic information for 327 sem::Array* Array(const ast::Array* arr); 328 329 /// Builds and returns the semantic information for the alias `alias`. 330 /// This method does not mark the ast::Alias node, nor attach the generated 331 /// semantic information to the AST node. 332 /// @returns the aliased type, or nullptr if an error is raised. 333 sem::Type* Alias(const ast::Alias* alias); 334 335 /// Builds and returns the semantic information for the structure `str`. 336 /// This method does not mark the ast::Struct node, nor attach the generated 337 /// semantic information to the AST node. 338 /// @returns the semantic Struct information, or nullptr if an error is 339 /// raised. 340 sem::Struct* Structure(const ast::Struct* str); 341 342 /// @returns the semantic info for the variable `var`. If an error is 343 /// raised, nullptr is returned. 344 /// @note this method does not resolve the decorations as these are 345 /// context-dependent (global, local, parameter) 346 /// @param var the variable to create or return the `VariableInfo` for 347 /// @param kind what kind of variable we are declaring 348 /// @param index the index of the parameter, if this variable is a parameter 349 sem::Variable* Variable(const ast::Variable* var, 350 VariableKind kind, 351 uint32_t index = 0); 352 353 /// Records the storage class usage for the given type, and any transient 354 /// dependencies of the type. Validates that the type can be used for the 355 /// given storage class, erroring if it cannot. 356 /// @param sc the storage class to apply to the type and transitent types 357 /// @param ty the type to apply the storage class on 358 /// @param usage the Source of the root variable declaration that uses the 359 /// given type and storage class. Used for generating sensible error 360 /// messages. 361 /// @returns true on success, false on error 362 bool ApplyStorageClassUsageToType(ast::StorageClass sc, 363 sem::Type* ty, 364 const Source& usage); 365 366 /// @param storage_class the storage class 367 /// @returns the default access control for the given storage class 368 ast::Access DefaultAccessForStorageClass(ast::StorageClass storage_class); 369 370 /// Allocate constant IDs for pipeline-overridable constants. 371 void AllocateOverridableConstantIds(); 372 373 /// Set the shadowing information on variable declarations. 374 /// @note this method must only be called after all semantic nodes are built. 375 void SetShadows(); 376 377 /// @returns the resolved type of the ast::Expression `expr` 378 /// @param expr the expression 379 sem::Type* TypeOf(const ast::Expression* expr); 380 381 /// @returns the type name of the given semantic type, unwrapping 382 /// references. 383 std::string TypeNameOf(const sem::Type* ty); 384 385 /// @returns the type name of the given semantic type, without unwrapping 386 /// references. 387 std::string RawTypeNameOf(const sem::Type* ty); 388 389 /// @returns the semantic type of the AST literal `lit` 390 /// @param lit the literal 391 sem::Type* TypeOf(const ast::LiteralExpression* lit); 392 393 /// StatementScope() does the following: 394 /// * Creates the AST -> SEM mapping. 395 /// * Assigns `sem` to #current_statement_ 396 /// * Assigns `sem` to #current_compound_statement_ if `sem` derives from 397 /// sem::CompoundStatement. 398 /// * Assigns `sem` to #current_block_ if `sem` derives from 399 /// sem::BlockStatement. 400 /// * Then calls `callback`. 401 /// * Before returning #current_statement_, #current_compound_statement_, and 402 /// #current_block_ are restored to their original values. 403 /// @returns `sem` if `callback` returns true, otherwise `nullptr`. 404 template <typename SEM, typename F> 405 SEM* StatementScope(const ast::Statement* ast, SEM* sem, F&& callback); 406 407 /// Returns a human-readable string representation of the vector type name 408 /// with the given parameters. 409 /// @param size the vector dimension 410 /// @param element_type scalar vector sub-element type 411 /// @return pretty string representation 412 std::string VectorPretty(uint32_t size, const sem::Type* element_type); 413 414 /// Mark records that the given AST node has been visited, and asserts that 415 /// the given node has not already been seen. Diamonds in the AST are 416 /// illegal. 417 /// @param node the AST node. 418 /// @returns true on success, false on error 419 bool Mark(const ast::Node* node); 420 421 /// Adds the given error message to the diagnostics 422 void AddError(const std::string& msg, const Source& source) const; 423 424 /// Adds the given warning message to the diagnostics 425 void AddWarning(const std::string& msg, const Source& source) const; 426 427 /// Adds the given note message to the diagnostics 428 void AddNote(const std::string& msg, const Source& source) const; 429 430 ////////////////////////////////////////////////////////////////////////////// 431 /// Constant value evaluation methods 432 ////////////////////////////////////////////////////////////////////////////// 433 /// Cast `Value` to `target_type` 434 /// @return the casted value 435 sem::Constant ConstantCast(const sem::Constant& value, 436 const sem::Type* target_elem_type); 437 438 sem::Constant EvaluateConstantValue(const ast::Expression* expr, 439 const sem::Type* type); 440 sem::Constant EvaluateConstantValue(const ast::LiteralExpression* literal, 441 const sem::Type* type); 442 sem::Constant EvaluateConstantValue(const ast::CallExpression* call, 443 const sem::Type* type); 444 445 /// Sem is a helper for obtaining the semantic node for the given AST node. 446 template <typename SEM = sem::Info::InferFromAST, 447 typename AST_OR_TYPE = CastableBase> Sem(const AST_OR_TYPE * ast)448 auto* Sem(const AST_OR_TYPE* ast) { 449 using T = sem::Info::GetResultType<SEM, AST_OR_TYPE>; 450 auto* sem = builder_->Sem().Get(ast); 451 if (!sem) { 452 TINT_ICE(Resolver, diagnostics_) 453 << "AST node '" << ast->TypeInfo().name << "' had no semantic info\n" 454 << "At: " << ast->source << "\n" 455 << "Pointer: " << ast; 456 } 457 return const_cast<T*>(As<T>(sem)); 458 } 459 460 /// @returns true if the symbol is the name of an intrinsic (builtin) 461 /// function. 462 bool IsIntrinsic(Symbol) const; 463 464 /// @returns true if `expr` is the current CallStatement's CallExpression 465 bool IsCallStatement(const ast::Expression* expr) const; 466 467 /// Searches the current statement and up through parents of the current 468 /// statement looking for a loop or for-loop continuing statement. 469 /// @returns the closest continuing statement to the current statement that 470 /// (transitively) owns the current statement. 471 /// @param stop_at_loop if true then the function will return nullptr if a 472 /// loop or for-loop was found before the continuing. 473 const ast::Statement* ClosestContinuing(bool stop_at_loop) const; 474 475 /// @returns the resolved symbol (function, type or variable) for the given 476 /// ast::Identifier or ast::TypeName cast to the given semantic type. 477 template <typename SEM = sem::Node> ResolvedSymbol(const ast::Node * node)478 SEM* ResolvedSymbol(const ast::Node* node) { 479 auto* resolved = utils::Lookup(dependencies_.resolved_symbols, node); 480 return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(resolved)) 481 : nullptr; 482 } 483 484 struct TypeConversionSig { 485 const sem::Type* target; 486 const sem::Type* source; 487 488 bool operator==(const TypeConversionSig&) const; 489 490 /// Hasher provides a hash function for the TypeConversionSig 491 struct Hasher { 492 /// @param sig the TypeConversionSig to create a hash for 493 /// @return the hash value 494 std::size_t operator()(const TypeConversionSig& sig) const; 495 }; 496 }; 497 498 struct TypeConstructorSig { 499 const sem::Type* type; 500 const std::vector<const sem::Type*> parameters; 501 502 TypeConstructorSig(const sem::Type* ty, 503 const std::vector<const sem::Type*> params); 504 TypeConstructorSig(const TypeConstructorSig&); 505 ~TypeConstructorSig(); 506 bool operator==(const TypeConstructorSig&) const; 507 508 /// Hasher provides a hash function for the TypeConstructorSig 509 struct Hasher { 510 /// @param sig the TypeConstructorSig to create a hash for 511 /// @return the hash value 512 std::size_t operator()(const TypeConstructorSig& sig) const; 513 }; 514 }; 515 516 ProgramBuilder* const builder_; 517 diag::List& diagnostics_; 518 std::unique_ptr<IntrinsicTable> const intrinsic_table_; 519 DependencyGraph dependencies_; 520 std::vector<sem::Function*> entry_points_; 521 std::unordered_map<const sem::Type*, const Source&> atomic_composite_info_; 522 std::unordered_set<const ast::Node*> marked_; 523 std::unordered_map<uint32_t, const sem::Variable*> constant_ids_; 524 std::unordered_map<TypeConversionSig, 525 sem::CallTarget*, 526 TypeConversionSig::Hasher> 527 type_conversions_; 528 std::unordered_map<TypeConstructorSig, 529 sem::CallTarget*, 530 TypeConstructorSig::Hasher> 531 type_ctors_; 532 533 sem::Function* current_function_ = nullptr; 534 sem::Statement* current_statement_ = nullptr; 535 sem::CompoundStatement* current_compound_statement_ = nullptr; 536 sem::BlockStatement* current_block_ = nullptr; 537 }; 538 539 } // namespace resolver 540 } // namespace tint 541 542 #endif // SRC_RESOLVER_RESOLVER_H_ 543