1 // 2 // Copyright 2017 The ANGLE Project Authors. All rights reserved. 3 // Use of this source code is governed by a BSD-style license that can be 4 // found in the LICENSE file. 5 // 6 // IntermTraverse.h : base classes for AST traversers that walk the AST and 7 // also have the ability to transform it by replacing nodes. 8 9 #ifndef COMPILER_TRANSLATOR_TREEUTIL_INTERMTRAVERSE_H_ 10 #define COMPILER_TRANSLATOR_TREEUTIL_INTERMTRAVERSE_H_ 11 12 #include "compiler/translator/IntermNode.h" 13 #include "compiler/translator/tree_util/Visit.h" 14 15 namespace sh 16 { 17 18 class TCompiler; 19 class TSymbolTable; 20 class TSymbolUniqueId; 21 22 // For traversing the tree. User should derive from this class overriding the visit functions, 23 // and then pass an object of the subclass to a traverse method of a node. 24 // 25 // The traverse*() functions may also be overridden to do other bookkeeping on the tree to provide 26 // contextual information to the visit functions, such as whether the node is the target of an 27 // assignment. This is complex to maintain and so should only be done in special cases. 28 // 29 // When using this, just fill in the methods for nodes you want visited. 30 // Return false from a pre-visit to skip visiting that node's subtree. 31 // 32 // See also how to write AST transformations documentation: 33 // https://github.com/google/angle/blob/master/doc/WritingShaderASTTransformations.md 34 class TIntermTraverser : angle::NonCopyable 35 { 36 public: 37 POOL_ALLOCATOR_NEW_DELETE 38 TIntermTraverser(bool preVisitIn, 39 bool inVisitIn, 40 bool postVisitIn, 41 TSymbolTable *symbolTable = nullptr); 42 virtual ~TIntermTraverser(); 43 visitSymbol(TIntermSymbol * node)44 virtual void visitSymbol(TIntermSymbol *node) {} visitConstantUnion(TIntermConstantUnion * node)45 virtual void visitConstantUnion(TIntermConstantUnion *node) {} visitSwizzle(Visit visit,TIntermSwizzle * node)46 virtual bool visitSwizzle(Visit visit, TIntermSwizzle *node) { return true; } visitBinary(Visit visit,TIntermBinary * node)47 virtual bool visitBinary(Visit visit, TIntermBinary *node) { return true; } visitUnary(Visit visit,TIntermUnary * node)48 virtual bool visitUnary(Visit visit, TIntermUnary *node) { return true; } visitTernary(Visit visit,TIntermTernary * node)49 virtual bool visitTernary(Visit visit, TIntermTernary *node) { return true; } visitIfElse(Visit visit,TIntermIfElse * node)50 virtual bool visitIfElse(Visit visit, TIntermIfElse *node) { return true; } visitSwitch(Visit visit,TIntermSwitch * node)51 virtual bool visitSwitch(Visit visit, TIntermSwitch *node) { return true; } visitCase(Visit visit,TIntermCase * node)52 virtual bool visitCase(Visit visit, TIntermCase *node) { return true; } visitFunctionPrototype(TIntermFunctionPrototype * node)53 virtual void visitFunctionPrototype(TIntermFunctionPrototype *node) {} visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)54 virtual bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) 55 { 56 return true; 57 } visitAggregate(Visit visit,TIntermAggregate * node)58 virtual bool visitAggregate(Visit visit, TIntermAggregate *node) { return true; } visitBlock(Visit visit,TIntermBlock * node)59 virtual bool visitBlock(Visit visit, TIntermBlock *node) { return true; } visitGlobalQualifierDeclaration(Visit visit,TIntermGlobalQualifierDeclaration * node)60 virtual bool visitGlobalQualifierDeclaration(Visit visit, 61 TIntermGlobalQualifierDeclaration *node) 62 { 63 return true; 64 } visitDeclaration(Visit visit,TIntermDeclaration * node)65 virtual bool visitDeclaration(Visit visit, TIntermDeclaration *node) { return true; } visitLoop(Visit visit,TIntermLoop * node)66 virtual bool visitLoop(Visit visit, TIntermLoop *node) { return true; } visitBranch(Visit visit,TIntermBranch * node)67 virtual bool visitBranch(Visit visit, TIntermBranch *node) { return true; } visitPreprocessorDirective(TIntermPreprocessorDirective * node)68 virtual void visitPreprocessorDirective(TIntermPreprocessorDirective *node) {} 69 70 // The traverse functions contain logic for iterating over the children of the node 71 // and calling the visit functions in the appropriate places. They also track some 72 // context that may be used by the visit functions. 73 74 // The generic traverse() function is used for nodes that don't need special handling. 75 // It's templated in order to avoid virtual function calls, this gains around 2% compiler 76 // performance. 77 template <typename T> 78 void traverse(T *node); 79 80 // Specialized traverse functions are implemented for node types where traversal logic may need 81 // to be overridden or where some special bookkeeping needs to be done. 82 virtual void traverseBinary(TIntermBinary *node); 83 virtual void traverseUnary(TIntermUnary *node); 84 virtual void traverseFunctionDefinition(TIntermFunctionDefinition *node); 85 virtual void traverseAggregate(TIntermAggregate *node); 86 virtual void traverseBlock(TIntermBlock *node); 87 virtual void traverseLoop(TIntermLoop *node); 88 getMaxDepth()89 int getMaxDepth() const { return mMaxDepth; } 90 91 // If traversers need to replace nodes, they can add the replacements in 92 // mReplacements/mMultiReplacements during traversal and the user of the traverser should call 93 // this function after traversal to perform them. 94 // 95 // Compiler is used to validate the tree. Node is the same given to traverse(). Returns false 96 // if the tree is invalid after update. 97 [[nodiscard]] bool updateTree(TCompiler *compiler, TIntermNode *node); 98 99 protected: 100 void setMaxAllowedDepth(int depth); 101 102 // Should only be called from traverse*() functions incrementDepth(TIntermNode * current)103 bool incrementDepth(TIntermNode *current) 104 { 105 mMaxDepth = std::max(mMaxDepth, static_cast<int>(mPath.size())); 106 mPath.push_back(current); 107 return mMaxDepth < mMaxAllowedDepth; 108 } 109 110 // Should only be called from traverse*() functions decrementDepth()111 void decrementDepth() { mPath.pop_back(); } 112 getCurrentTraversalDepth()113 int getCurrentTraversalDepth() const { return static_cast<int>(mPath.size()) - 1; } getCurrentBlockDepth()114 int getCurrentBlockDepth() const { return static_cast<int>(mParentBlockStack.size()) - 1; } 115 116 // RAII helper for incrementDepth/decrementDepth 117 class [[nodiscard]] ScopedNodeInTraversalPath 118 { 119 public: ScopedNodeInTraversalPath(TIntermTraverser * traverser,TIntermNode * current)120 ScopedNodeInTraversalPath(TIntermTraverser *traverser, TIntermNode *current) 121 : mTraverser(traverser) 122 { 123 mWithinDepthLimit = mTraverser->incrementDepth(current); 124 } ~ScopedNodeInTraversalPath()125 ~ScopedNodeInTraversalPath() { mTraverser->decrementDepth(); } 126 isWithinDepthLimit()127 bool isWithinDepthLimit() { return mWithinDepthLimit; } 128 129 private: 130 TIntermTraverser *mTraverser; 131 bool mWithinDepthLimit; 132 }; 133 // Optimized traversal functions for leaf nodes directly access ScopedNodeInTraversalPath. 134 friend void TIntermSymbol::traverse(TIntermTraverser *); 135 friend void TIntermConstantUnion::traverse(TIntermTraverser *); 136 friend void TIntermFunctionPrototype::traverse(TIntermTraverser *); 137 getParentNode()138 TIntermNode *getParentNode() const 139 { 140 return mPath.size() <= 1 ? nullptr : mPath[mPath.size() - 2u]; 141 } 142 143 // Return the nth ancestor of the node being traversed. getAncestorNode(0) == getParentNode() getAncestorNode(unsigned int n)144 TIntermNode *getAncestorNode(unsigned int n) const 145 { 146 if (mPath.size() > n + 1u) 147 { 148 return mPath[mPath.size() - n - 2u]; 149 } 150 return nullptr; 151 } 152 153 // Returns what child index is currently being visited. For example when visiting the children 154 // of an aggregate, it can be used to find out which argument of the parent (aggregate) node 155 // they correspond to. Only valid in the PreVisit call of the child. getParentChildIndex(Visit visit)156 size_t getParentChildIndex(Visit visit) const 157 { 158 ASSERT(visit == PreVisit); 159 return mCurrentChildIndex; 160 } 161 // Returns what child index has just been processed. Only valid in the InVisit and PostVisit 162 // calls of the parent node. getLastTraversedChildIndex(Visit visit)163 size_t getLastTraversedChildIndex(Visit visit) const 164 { 165 ASSERT(visit != PreVisit); 166 return mCurrentChildIndex; 167 } 168 169 const TIntermBlock *getParentBlock() const; 170 getRootNode()171 TIntermNode *getRootNode() const 172 { 173 ASSERT(!mPath.empty()); 174 return mPath.front(); 175 } 176 177 void pushParentBlock(TIntermBlock *node); 178 void incrementParentBlockPos(); 179 void popParentBlock(); 180 181 // To replace a single node with multiple nodes in the parent aggregate. May be used with blocks 182 // but also with other nodes like declarations. 183 struct NodeReplaceWithMultipleEntry 184 { NodeReplaceWithMultipleEntryNodeReplaceWithMultipleEntry185 NodeReplaceWithMultipleEntry(TIntermAggregateBase *parentIn, 186 TIntermNode *originalIn, 187 TIntermSequence &&replacementsIn) 188 : parent(parentIn), original(originalIn), replacements(std::move(replacementsIn)) 189 {} 190 191 TIntermAggregateBase *parent; 192 TIntermNode *original; 193 TIntermSequence replacements; 194 }; 195 196 // Helper to insert statements in the parent block of the node currently being traversed. 197 // The statements will be inserted before the node being traversed once updateTree is called. 198 // Should only be called during PreVisit or PostVisit if called from block nodes. 199 // Note that two insertions to the same position in the same block are not supported. 200 void insertStatementsInParentBlock(const TIntermSequence &insertions); 201 202 // Same as above, but supports simultaneous insertion of statements before and after the node 203 // currently being traversed. 204 void insertStatementsInParentBlock(const TIntermSequence &insertionsBefore, 205 const TIntermSequence &insertionsAfter); 206 207 // Helper to insert a single statement. 208 void insertStatementInParentBlock(TIntermNode *statement); 209 210 // Explicitly specify where to insert statements. The statements are inserted before and after 211 // the specified position. The statements will be inserted once updateTree is called. Note that 212 // two insertions to the same position in the same block are not supported. 213 void insertStatementsInBlockAtPosition(TIntermBlock *parent, 214 size_t position, 215 const TIntermSequence &insertionsBefore, 216 const TIntermSequence &insertionsAfter); 217 218 enum class OriginalNode 219 { 220 BECOMES_CHILD, 221 IS_DROPPED 222 }; 223 224 void clearReplacementQueue(); 225 226 // Replace the node currently being visited with replacement. 227 void queueReplacement(TIntermNode *replacement, OriginalNode originalStatus); 228 // Explicitly specify a node to replace with replacement. 229 void queueReplacementWithParent(TIntermNode *parent, 230 TIntermNode *original, 231 TIntermNode *replacement, 232 OriginalNode originalStatus); 233 // Walk the ancestors and replace the access chain that leads to this symbol. This fixes up the 234 // types of the intermediate nodes, so it should be used when the type of the symbol changes. 235 // The AST transformation must still visit the (indirect) index nodes to transform the 236 // expression inside those nodes. Note that due to the way these replacements work, the AST 237 // transformation should not attempt to replace the actual index node itself, but only a subnode 238 // of that. 239 // 240 // Node 1 Node 6 241 // EOpIndexDirect EOpIndexDirect 242 // / \ / \ 243 // Node 2 Node 3 Node 7 Node 3 244 // EOpIndexIndirect N --> replaced with --> EOpIndexIndirect N 245 // / \ / \ 246 // Node 4 Node 5 Node 8 Node 5 247 // symbol expression replacement expression 248 // ^ ^ 249 // | | 250 // This symbol is being replaced, This node is directly placed in the 251 // and the replacement is given new access chain, and its parent is 252 // to this function. is changed. This is why a 253 // replacment attempt for this node 254 // itself will not work. 255 // 256 void queueAccessChainReplacement(TIntermTyped *replacement); 257 258 const bool preVisit; 259 const bool inVisit; 260 const bool postVisit; 261 262 int mMaxDepth; 263 int mMaxAllowedDepth; 264 265 bool mInGlobalScope; 266 267 // During traversing, save all the changes that need to happen into 268 // mReplacements/mMultiReplacements, then do them by calling updateTree(). 269 // Multi replacements are processed after single replacements. 270 std::vector<NodeReplaceWithMultipleEntry> mMultiReplacements; 271 272 TSymbolTable *mSymbolTable; 273 274 private: 275 // To insert multiple nodes into the parent block. 276 struct NodeInsertMultipleEntry 277 { NodeInsertMultipleEntryNodeInsertMultipleEntry278 NodeInsertMultipleEntry(TIntermBlock *_parent, 279 TIntermSequence::size_type _position, 280 TIntermSequence _insertionsBefore, 281 TIntermSequence _insertionsAfter) 282 : parent(_parent), 283 position(_position), 284 insertionsBefore(_insertionsBefore), 285 insertionsAfter(_insertionsAfter) 286 {} 287 288 TIntermBlock *parent; 289 TIntermSequence::size_type position; 290 TIntermSequence insertionsBefore; 291 TIntermSequence insertionsAfter; 292 }; 293 294 static bool CompareInsertion(const NodeInsertMultipleEntry &a, 295 const NodeInsertMultipleEntry &b); 296 297 // To replace a single node with another on the parent node 298 struct NodeUpdateEntry 299 { NodeUpdateEntryNodeUpdateEntry300 NodeUpdateEntry(TIntermNode *_parent, 301 TIntermNode *_original, 302 TIntermNode *_replacement, 303 bool _originalBecomesChildOfReplacement) 304 : parent(_parent), 305 original(_original), 306 replacement(_replacement), 307 originalBecomesChildOfReplacement(_originalBecomesChildOfReplacement) 308 {} 309 310 TIntermNode *parent; 311 TIntermNode *original; 312 TIntermNode *replacement; 313 bool originalBecomesChildOfReplacement; 314 }; 315 316 struct ParentBlock 317 { ParentBlockParentBlock318 ParentBlock(TIntermBlock *nodeIn, TIntermSequence::size_type posIn) 319 : node(nodeIn), pos(posIn) 320 {} 321 322 TIntermBlock *node; 323 TIntermSequence::size_type pos; 324 }; 325 326 std::vector<NodeInsertMultipleEntry> mInsertions; 327 std::vector<NodeUpdateEntry> mReplacements; 328 329 // All the nodes from root to the current node during traversing. 330 TVector<TIntermNode *> mPath; 331 // The current child of parent being traversed. 332 size_t mCurrentChildIndex; 333 334 // All the code blocks from the root to the current node's parent during traversal. 335 std::vector<ParentBlock> mParentBlockStack; 336 }; 337 338 // Traverser parent class that tracks where a node is a destination of a write operation and so is 339 // required to be an l-value. 340 class TLValueTrackingTraverser : public TIntermTraverser 341 { 342 public: 343 TLValueTrackingTraverser(bool preVisit, 344 bool inVisit, 345 bool postVisit, 346 TSymbolTable *symbolTable); ~TLValueTrackingTraverser()347 ~TLValueTrackingTraverser() override {} 348 349 void traverseBinary(TIntermBinary *node) final; 350 void traverseUnary(TIntermUnary *node) final; 351 void traverseAggregate(TIntermAggregate *node) final; 352 353 protected: isLValueRequiredHere()354 bool isLValueRequiredHere() const 355 { 356 return mOperatorRequiresLValue || mInFunctionCallOutParameter; 357 } 358 359 private: 360 // Track whether an l-value is required in the node that is currently being traversed by the 361 // surrounding operator. 362 // Use isLValueRequiredHere to check all conditions which require an l-value. setOperatorRequiresLValue(bool lValueRequired)363 void setOperatorRequiresLValue(bool lValueRequired) 364 { 365 mOperatorRequiresLValue = lValueRequired; 366 } operatorRequiresLValue()367 bool operatorRequiresLValue() const { return mOperatorRequiresLValue; } 368 369 // Track whether an l-value is required inside a function call. 370 void setInFunctionCallOutParameter(bool inOutParameter); 371 bool isInFunctionCallOutParameter() const; 372 373 bool mOperatorRequiresLValue; 374 bool mInFunctionCallOutParameter; 375 }; 376 377 } // namespace sh 378 379 #endif // COMPILER_TRANSLATOR_TREEUTIL_INTERMTRAVERSE_H_ 380