1 // 2 // Copyright 2017 The ANGLE Project Authors. All rights reserved. 3 // Use of this source code is governed by a BSD-style license that can be 4 // found in the LICENSE file. 5 // 6 // IntermTraverse.h : base classes for AST traversers that walk the AST and 7 // also have the ability to transform it by replacing nodes. 8 9 #ifndef COMPILER_TRANSLATOR_TREEUTIL_INTERMTRAVERSE_H_ 10 #define COMPILER_TRANSLATOR_TREEUTIL_INTERMTRAVERSE_H_ 11 12 #include "compiler/translator/IntermNode.h" 13 #include "compiler/translator/tree_util/Visit.h" 14 15 namespace sh 16 { 17 18 class TSymbolTable; 19 class TSymbolUniqueId; 20 21 // For traversing the tree. User should derive from this class overriding the visit functions, 22 // and then pass an object of the subclass to a traverse method of a node. 23 // 24 // The traverse*() functions may also be overridden to do other bookkeeping on the tree to provide 25 // contextual information to the visit functions, such as whether the node is the target of an 26 // assignment. This is complex to maintain and so should only be done in special cases. 27 // 28 // When using this, just fill in the methods for nodes you want visited. 29 // Return false from a pre-visit to skip visiting that node's subtree. 30 // 31 // See also how to write AST transformations documentation: 32 // https://github.com/google/angle/blob/master/doc/WritingShaderASTTransformations.md 33 class TIntermTraverser : angle::NonCopyable 34 { 35 public: 36 POOL_ALLOCATOR_NEW_DELETE 37 TIntermTraverser(bool preVisit, 38 bool inVisit, 39 bool postVisit, 40 TSymbolTable *symbolTable = nullptr); 41 virtual ~TIntermTraverser(); 42 visitSymbol(TIntermSymbol * node)43 virtual void visitSymbol(TIntermSymbol *node) {} visitConstantUnion(TIntermConstantUnion * node)44 virtual void visitConstantUnion(TIntermConstantUnion *node) {} visitSwizzle(Visit visit,TIntermSwizzle * node)45 virtual bool visitSwizzle(Visit visit, TIntermSwizzle *node) { return true; } visitBinary(Visit visit,TIntermBinary * node)46 virtual bool visitBinary(Visit visit, TIntermBinary *node) { return true; } visitUnary(Visit visit,TIntermUnary * node)47 virtual bool visitUnary(Visit visit, TIntermUnary *node) { return true; } visitTernary(Visit visit,TIntermTernary * node)48 virtual bool visitTernary(Visit visit, TIntermTernary *node) { return true; } visitIfElse(Visit visit,TIntermIfElse * node)49 virtual bool visitIfElse(Visit visit, TIntermIfElse *node) { return true; } visitSwitch(Visit visit,TIntermSwitch * node)50 virtual bool visitSwitch(Visit visit, TIntermSwitch *node) { return true; } visitCase(Visit visit,TIntermCase * node)51 virtual bool visitCase(Visit visit, TIntermCase *node) { return true; } visitFunctionPrototype(TIntermFunctionPrototype * node)52 virtual void visitFunctionPrototype(TIntermFunctionPrototype *node) {} visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)53 virtual bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) 54 { 55 return true; 56 } visitAggregate(Visit visit,TIntermAggregate * node)57 virtual bool visitAggregate(Visit visit, TIntermAggregate *node) { return true; } visitBlock(Visit visit,TIntermBlock * node)58 virtual bool visitBlock(Visit visit, TIntermBlock *node) { return true; } visitInvariantDeclaration(Visit visit,TIntermInvariantDeclaration * node)59 virtual bool visitInvariantDeclaration(Visit visit, TIntermInvariantDeclaration *node) 60 { 61 return true; 62 } visitDeclaration(Visit visit,TIntermDeclaration * node)63 virtual bool visitDeclaration(Visit visit, TIntermDeclaration *node) { return true; } visitLoop(Visit visit,TIntermLoop * node)64 virtual bool visitLoop(Visit visit, TIntermLoop *node) { return true; } visitBranch(Visit visit,TIntermBranch * node)65 virtual bool visitBranch(Visit visit, TIntermBranch *node) { return true; } visitPreprocessorDirective(TIntermPreprocessorDirective * node)66 virtual void visitPreprocessorDirective(TIntermPreprocessorDirective *node) {} 67 68 // The traverse functions contain logic for iterating over the children of the node 69 // and calling the visit functions in the appropriate places. They also track some 70 // context that may be used by the visit functions. 71 72 // The generic traverse() function is used for nodes that don't need special handling. 73 // It's templated in order to avoid virtual function calls, this gains around 2% compiler 74 // performance. 75 template <typename T> 76 void traverse(T *node); 77 78 // Specialized traverse functions are implemented for node types where traversal logic may need 79 // to be overridden or where some special bookkeeping needs to be done. 80 virtual void traverseBinary(TIntermBinary *node); 81 virtual void traverseUnary(TIntermUnary *node); 82 virtual void traverseFunctionDefinition(TIntermFunctionDefinition *node); 83 virtual void traverseAggregate(TIntermAggregate *node); 84 virtual void traverseBlock(TIntermBlock *node); 85 virtual void traverseLoop(TIntermLoop *node); 86 getMaxDepth()87 int getMaxDepth() const { return mMaxDepth; } 88 89 // If traversers need to replace nodes, they can add the replacements in 90 // mReplacements/mMultiReplacements during traversal and the user of the traverser should call 91 // this function after traversal to perform them. 92 void updateTree(); 93 94 protected: 95 void setMaxAllowedDepth(int depth); 96 97 // Should only be called from traverse*() functions incrementDepth(TIntermNode * current)98 bool incrementDepth(TIntermNode *current) 99 { 100 mMaxDepth = std::max(mMaxDepth, static_cast<int>(mPath.size())); 101 mPath.push_back(current); 102 return mMaxDepth < mMaxAllowedDepth; 103 } 104 105 // Should only be called from traverse*() functions decrementDepth()106 void decrementDepth() { mPath.pop_back(); } 107 getCurrentTraversalDepth()108 int getCurrentTraversalDepth() const { return static_cast<int>(mPath.size()) - 1; } 109 110 // RAII helper for incrementDepth/decrementDepth 111 class ScopedNodeInTraversalPath 112 { 113 public: ScopedNodeInTraversalPath(TIntermTraverser * traverser,TIntermNode * current)114 ScopedNodeInTraversalPath(TIntermTraverser *traverser, TIntermNode *current) 115 : mTraverser(traverser) 116 { 117 mWithinDepthLimit = mTraverser->incrementDepth(current); 118 } ~ScopedNodeInTraversalPath()119 ~ScopedNodeInTraversalPath() { mTraverser->decrementDepth(); } 120 isWithinDepthLimit()121 bool isWithinDepthLimit() { return mWithinDepthLimit; } 122 123 private: 124 TIntermTraverser *mTraverser; 125 bool mWithinDepthLimit; 126 }; 127 // Optimized traversal functions for leaf nodes directly access ScopedNodeInTraversalPath. 128 friend void TIntermSymbol::traverse(TIntermTraverser *); 129 friend void TIntermConstantUnion::traverse(TIntermTraverser *); 130 friend void TIntermFunctionPrototype::traverse(TIntermTraverser *); 131 getParentNode()132 TIntermNode *getParentNode() { return mPath.size() <= 1 ? nullptr : mPath[mPath.size() - 2u]; } 133 134 // Return the nth ancestor of the node being traversed. getAncestorNode(0) == getParentNode() getAncestorNode(unsigned int n)135 TIntermNode *getAncestorNode(unsigned int n) 136 { 137 if (mPath.size() > n + 1u) 138 { 139 return mPath[mPath.size() - n - 2u]; 140 } 141 return nullptr; 142 } 143 144 const TIntermBlock *getParentBlock() const; 145 146 void pushParentBlock(TIntermBlock *node); 147 void incrementParentBlockPos(); 148 void popParentBlock(); 149 150 // To replace a single node with multiple nodes in the parent aggregate. May be used with blocks 151 // but also with other nodes like declarations. 152 struct NodeReplaceWithMultipleEntry 153 { NodeReplaceWithMultipleEntryNodeReplaceWithMultipleEntry154 NodeReplaceWithMultipleEntry(TIntermAggregateBase *parentIn, 155 TIntermNode *originalIn, 156 TIntermSequence replacementsIn) 157 : parent(parentIn), original(originalIn), replacements(std::move(replacementsIn)) 158 {} 159 160 TIntermAggregateBase *parent; 161 TIntermNode *original; 162 TIntermSequence replacements; 163 }; 164 165 // Helper to insert statements in the parent block of the node currently being traversed. 166 // The statements will be inserted before the node being traversed once updateTree is called. 167 // Should only be called during PreVisit or PostVisit if called from block nodes. 168 // Note that two insertions to the same position in the same block are not supported. 169 void insertStatementsInParentBlock(const TIntermSequence &insertions); 170 171 // Same as above, but supports simultaneous insertion of statements before and after the node 172 // currently being traversed. 173 void insertStatementsInParentBlock(const TIntermSequence &insertionsBefore, 174 const TIntermSequence &insertionsAfter); 175 176 // Helper to insert a single statement. 177 void insertStatementInParentBlock(TIntermNode *statement); 178 179 // Explicitly specify where to insert statements. The statements are inserted before and after 180 // the specified position. The statements will be inserted once updateTree is called. Note that 181 // two insertions to the same position in the same block are not supported. 182 void insertStatementsInBlockAtPosition(TIntermBlock *parent, 183 size_t position, 184 const TIntermSequence &insertionsBefore, 185 const TIntermSequence &insertionsAfter); 186 187 enum class OriginalNode 188 { 189 BECOMES_CHILD, 190 IS_DROPPED 191 }; 192 193 void clearReplacementQueue(); 194 195 // Replace the node currently being visited with replacement. 196 void queueReplacement(TIntermNode *replacement, OriginalNode originalStatus); 197 // Explicitly specify a node to replace with replacement. 198 void queueReplacementWithParent(TIntermNode *parent, 199 TIntermNode *original, 200 TIntermNode *replacement, 201 OriginalNode originalStatus); 202 203 const bool preVisit; 204 const bool inVisit; 205 const bool postVisit; 206 207 int mMaxDepth; 208 int mMaxAllowedDepth; 209 210 bool mInGlobalScope; 211 212 // During traversing, save all the changes that need to happen into 213 // mReplacements/mMultiReplacements, then do them by calling updateTree(). 214 // Multi replacements are processed after single replacements. 215 std::vector<NodeReplaceWithMultipleEntry> mMultiReplacements; 216 217 TSymbolTable *mSymbolTable; 218 219 private: 220 // To insert multiple nodes into the parent block. 221 struct NodeInsertMultipleEntry 222 { NodeInsertMultipleEntryNodeInsertMultipleEntry223 NodeInsertMultipleEntry(TIntermBlock *_parent, 224 TIntermSequence::size_type _position, 225 TIntermSequence _insertionsBefore, 226 TIntermSequence _insertionsAfter) 227 : parent(_parent), 228 position(_position), 229 insertionsBefore(_insertionsBefore), 230 insertionsAfter(_insertionsAfter) 231 {} 232 233 TIntermBlock *parent; 234 TIntermSequence::size_type position; 235 TIntermSequence insertionsBefore; 236 TIntermSequence insertionsAfter; 237 }; 238 239 static bool CompareInsertion(const NodeInsertMultipleEntry &a, 240 const NodeInsertMultipleEntry &b); 241 242 // To replace a single node with another on the parent node 243 struct NodeUpdateEntry 244 { NodeUpdateEntryNodeUpdateEntry245 NodeUpdateEntry(TIntermNode *_parent, 246 TIntermNode *_original, 247 TIntermNode *_replacement, 248 bool _originalBecomesChildOfReplacement) 249 : parent(_parent), 250 original(_original), 251 replacement(_replacement), 252 originalBecomesChildOfReplacement(_originalBecomesChildOfReplacement) 253 {} 254 255 TIntermNode *parent; 256 TIntermNode *original; 257 TIntermNode *replacement; 258 bool originalBecomesChildOfReplacement; 259 }; 260 261 struct ParentBlock 262 { ParentBlockParentBlock263 ParentBlock(TIntermBlock *nodeIn, TIntermSequence::size_type posIn) 264 : node(nodeIn), pos(posIn) 265 {} 266 267 TIntermBlock *node; 268 TIntermSequence::size_type pos; 269 }; 270 271 std::vector<NodeInsertMultipleEntry> mInsertions; 272 std::vector<NodeUpdateEntry> mReplacements; 273 274 // All the nodes from root to the current node during traversing. 275 TVector<TIntermNode *> mPath; 276 277 // All the code blocks from the root to the current node's parent during traversal. 278 std::vector<ParentBlock> mParentBlockStack; 279 }; 280 281 // Traverser parent class that tracks where a node is a destination of a write operation and so is 282 // required to be an l-value. 283 class TLValueTrackingTraverser : public TIntermTraverser 284 { 285 public: 286 TLValueTrackingTraverser(bool preVisit, 287 bool inVisit, 288 bool postVisit, 289 TSymbolTable *symbolTable); ~TLValueTrackingTraverser()290 virtual ~TLValueTrackingTraverser() {} 291 292 void traverseBinary(TIntermBinary *node) final; 293 void traverseUnary(TIntermUnary *node) final; 294 void traverseAggregate(TIntermAggregate *node) final; 295 296 protected: isLValueRequiredHere()297 bool isLValueRequiredHere() const 298 { 299 return mOperatorRequiresLValue || mInFunctionCallOutParameter; 300 } 301 302 private: 303 // Track whether an l-value is required in the node that is currently being traversed by the 304 // surrounding operator. 305 // Use isLValueRequiredHere to check all conditions which require an l-value. setOperatorRequiresLValue(bool lValueRequired)306 void setOperatorRequiresLValue(bool lValueRequired) 307 { 308 mOperatorRequiresLValue = lValueRequired; 309 } operatorRequiresLValue()310 bool operatorRequiresLValue() const { return mOperatorRequiresLValue; } 311 312 // Track whether an l-value is required inside a function call. 313 void setInFunctionCallOutParameter(bool inOutParameter); 314 bool isInFunctionCallOutParameter() const; 315 316 bool mOperatorRequiresLValue; 317 bool mInFunctionCallOutParameter; 318 }; 319 320 } // namespace sh 321 322 #endif // COMPILER_TRANSLATOR_TREEUTIL_INTERMTRAVERSE_H_ 323