• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2017 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // IntermTraverse.h : base classes for AST traversers that walk the AST and
7 //   also have the ability to transform it by replacing nodes.
8 
9 #ifndef COMPILER_TRANSLATOR_TREEUTIL_INTERMTRAVERSE_H_
10 #define COMPILER_TRANSLATOR_TREEUTIL_INTERMTRAVERSE_H_
11 
12 #include "compiler/translator/IntermNode.h"
13 #include "compiler/translator/tree_util/Visit.h"
14 
15 namespace sh
16 {
17 
18 class TSymbolTable;
19 class TSymbolUniqueId;
20 
21 // For traversing the tree.  User should derive from this class overriding the visit functions,
22 // and then pass an object of the subclass to a traverse method of a node.
23 //
24 // The traverse*() functions may also be overridden to do other bookkeeping on the tree to provide
25 // contextual information to the visit functions, such as whether the node is the target of an
26 // assignment. This is complex to maintain and so should only be done in special cases.
27 //
28 // When using this, just fill in the methods for nodes you want visited.
29 // Return false from a pre-visit to skip visiting that node's subtree.
30 //
31 // See also how to write AST transformations documentation:
32 //   https://github.com/google/angle/blob/master/doc/WritingShaderASTTransformations.md
33 class TIntermTraverser : angle::NonCopyable
34 {
35   public:
36     POOL_ALLOCATOR_NEW_DELETE
37     TIntermTraverser(bool preVisit,
38                      bool inVisit,
39                      bool postVisit,
40                      TSymbolTable *symbolTable = nullptr);
41     virtual ~TIntermTraverser();
42 
visitSymbol(TIntermSymbol * node)43     virtual void visitSymbol(TIntermSymbol *node) {}
visitConstantUnion(TIntermConstantUnion * node)44     virtual void visitConstantUnion(TIntermConstantUnion *node) {}
visitSwizzle(Visit visit,TIntermSwizzle * node)45     virtual bool visitSwizzle(Visit visit, TIntermSwizzle *node) { return true; }
visitBinary(Visit visit,TIntermBinary * node)46     virtual bool visitBinary(Visit visit, TIntermBinary *node) { return true; }
visitUnary(Visit visit,TIntermUnary * node)47     virtual bool visitUnary(Visit visit, TIntermUnary *node) { return true; }
visitTernary(Visit visit,TIntermTernary * node)48     virtual bool visitTernary(Visit visit, TIntermTernary *node) { return true; }
visitIfElse(Visit visit,TIntermIfElse * node)49     virtual bool visitIfElse(Visit visit, TIntermIfElse *node) { return true; }
visitSwitch(Visit visit,TIntermSwitch * node)50     virtual bool visitSwitch(Visit visit, TIntermSwitch *node) { return true; }
visitCase(Visit visit,TIntermCase * node)51     virtual bool visitCase(Visit visit, TIntermCase *node) { return true; }
visitFunctionPrototype(TIntermFunctionPrototype * node)52     virtual void visitFunctionPrototype(TIntermFunctionPrototype *node) {}
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)53     virtual bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
54     {
55         return true;
56     }
visitAggregate(Visit visit,TIntermAggregate * node)57     virtual bool visitAggregate(Visit visit, TIntermAggregate *node) { return true; }
visitBlock(Visit visit,TIntermBlock * node)58     virtual bool visitBlock(Visit visit, TIntermBlock *node) { return true; }
visitInvariantDeclaration(Visit visit,TIntermInvariantDeclaration * node)59     virtual bool visitInvariantDeclaration(Visit visit, TIntermInvariantDeclaration *node)
60     {
61         return true;
62     }
visitDeclaration(Visit visit,TIntermDeclaration * node)63     virtual bool visitDeclaration(Visit visit, TIntermDeclaration *node) { return true; }
visitLoop(Visit visit,TIntermLoop * node)64     virtual bool visitLoop(Visit visit, TIntermLoop *node) { return true; }
visitBranch(Visit visit,TIntermBranch * node)65     virtual bool visitBranch(Visit visit, TIntermBranch *node) { return true; }
visitPreprocessorDirective(TIntermPreprocessorDirective * node)66     virtual void visitPreprocessorDirective(TIntermPreprocessorDirective *node) {}
67 
68     // The traverse functions contain logic for iterating over the children of the node
69     // and calling the visit functions in the appropriate places. They also track some
70     // context that may be used by the visit functions.
71 
72     // The generic traverse() function is used for nodes that don't need special handling.
73     // It's templated in order to avoid virtual function calls, this gains around 2% compiler
74     // performance.
75     template <typename T>
76     void traverse(T *node);
77 
78     // Specialized traverse functions are implemented for node types where traversal logic may need
79     // to be overridden or where some special bookkeeping needs to be done.
80     virtual void traverseBinary(TIntermBinary *node);
81     virtual void traverseUnary(TIntermUnary *node);
82     virtual void traverseFunctionDefinition(TIntermFunctionDefinition *node);
83     virtual void traverseAggregate(TIntermAggregate *node);
84     virtual void traverseBlock(TIntermBlock *node);
85     virtual void traverseLoop(TIntermLoop *node);
86 
getMaxDepth()87     int getMaxDepth() const { return mMaxDepth; }
88 
89     // If traversers need to replace nodes, they can add the replacements in
90     // mReplacements/mMultiReplacements during traversal and the user of the traverser should call
91     // this function after traversal to perform them.
92     void updateTree();
93 
94   protected:
95     void setMaxAllowedDepth(int depth);
96 
97     // Should only be called from traverse*() functions
incrementDepth(TIntermNode * current)98     bool incrementDepth(TIntermNode *current)
99     {
100         mMaxDepth = std::max(mMaxDepth, static_cast<int>(mPath.size()));
101         mPath.push_back(current);
102         return mMaxDepth < mMaxAllowedDepth;
103     }
104 
105     // Should only be called from traverse*() functions
decrementDepth()106     void decrementDepth() { mPath.pop_back(); }
107 
getCurrentTraversalDepth()108     int getCurrentTraversalDepth() const { return static_cast<int>(mPath.size()) - 1; }
109 
110     // RAII helper for incrementDepth/decrementDepth
111     class ScopedNodeInTraversalPath
112     {
113       public:
ScopedNodeInTraversalPath(TIntermTraverser * traverser,TIntermNode * current)114         ScopedNodeInTraversalPath(TIntermTraverser *traverser, TIntermNode *current)
115             : mTraverser(traverser)
116         {
117             mWithinDepthLimit = mTraverser->incrementDepth(current);
118         }
~ScopedNodeInTraversalPath()119         ~ScopedNodeInTraversalPath() { mTraverser->decrementDepth(); }
120 
isWithinDepthLimit()121         bool isWithinDepthLimit() { return mWithinDepthLimit; }
122 
123       private:
124         TIntermTraverser *mTraverser;
125         bool mWithinDepthLimit;
126     };
127     // Optimized traversal functions for leaf nodes directly access ScopedNodeInTraversalPath.
128     friend void TIntermSymbol::traverse(TIntermTraverser *);
129     friend void TIntermConstantUnion::traverse(TIntermTraverser *);
130     friend void TIntermFunctionPrototype::traverse(TIntermTraverser *);
131 
getParentNode()132     TIntermNode *getParentNode() { return mPath.size() <= 1 ? nullptr : mPath[mPath.size() - 2u]; }
133 
134     // Return the nth ancestor of the node being traversed. getAncestorNode(0) == getParentNode()
getAncestorNode(unsigned int n)135     TIntermNode *getAncestorNode(unsigned int n)
136     {
137         if (mPath.size() > n + 1u)
138         {
139             return mPath[mPath.size() - n - 2u];
140         }
141         return nullptr;
142     }
143 
144     const TIntermBlock *getParentBlock() const;
145 
146     void pushParentBlock(TIntermBlock *node);
147     void incrementParentBlockPos();
148     void popParentBlock();
149 
150     // To replace a single node with multiple nodes in the parent aggregate. May be used with blocks
151     // but also with other nodes like declarations.
152     struct NodeReplaceWithMultipleEntry
153     {
NodeReplaceWithMultipleEntryNodeReplaceWithMultipleEntry154         NodeReplaceWithMultipleEntry(TIntermAggregateBase *parentIn,
155                                      TIntermNode *originalIn,
156                                      TIntermSequence replacementsIn)
157             : parent(parentIn), original(originalIn), replacements(std::move(replacementsIn))
158         {}
159 
160         TIntermAggregateBase *parent;
161         TIntermNode *original;
162         TIntermSequence replacements;
163     };
164 
165     // Helper to insert statements in the parent block of the node currently being traversed.
166     // The statements will be inserted before the node being traversed once updateTree is called.
167     // Should only be called during PreVisit or PostVisit if called from block nodes.
168     // Note that two insertions to the same position in the same block are not supported.
169     void insertStatementsInParentBlock(const TIntermSequence &insertions);
170 
171     // Same as above, but supports simultaneous insertion of statements before and after the node
172     // currently being traversed.
173     void insertStatementsInParentBlock(const TIntermSequence &insertionsBefore,
174                                        const TIntermSequence &insertionsAfter);
175 
176     // Helper to insert a single statement.
177     void insertStatementInParentBlock(TIntermNode *statement);
178 
179     // Explicitly specify where to insert statements. The statements are inserted before and after
180     // the specified position. The statements will be inserted once updateTree is called. Note that
181     // two insertions to the same position in the same block are not supported.
182     void insertStatementsInBlockAtPosition(TIntermBlock *parent,
183                                            size_t position,
184                                            const TIntermSequence &insertionsBefore,
185                                            const TIntermSequence &insertionsAfter);
186 
187     enum class OriginalNode
188     {
189         BECOMES_CHILD,
190         IS_DROPPED
191     };
192 
193     void clearReplacementQueue();
194 
195     // Replace the node currently being visited with replacement.
196     void queueReplacement(TIntermNode *replacement, OriginalNode originalStatus);
197     // Explicitly specify a node to replace with replacement.
198     void queueReplacementWithParent(TIntermNode *parent,
199                                     TIntermNode *original,
200                                     TIntermNode *replacement,
201                                     OriginalNode originalStatus);
202 
203     const bool preVisit;
204     const bool inVisit;
205     const bool postVisit;
206 
207     int mMaxDepth;
208     int mMaxAllowedDepth;
209 
210     bool mInGlobalScope;
211 
212     // During traversing, save all the changes that need to happen into
213     // mReplacements/mMultiReplacements, then do them by calling updateTree().
214     // Multi replacements are processed after single replacements.
215     std::vector<NodeReplaceWithMultipleEntry> mMultiReplacements;
216 
217     TSymbolTable *mSymbolTable;
218 
219   private:
220     // To insert multiple nodes into the parent block.
221     struct NodeInsertMultipleEntry
222     {
NodeInsertMultipleEntryNodeInsertMultipleEntry223         NodeInsertMultipleEntry(TIntermBlock *_parent,
224                                 TIntermSequence::size_type _position,
225                                 TIntermSequence _insertionsBefore,
226                                 TIntermSequence _insertionsAfter)
227             : parent(_parent),
228               position(_position),
229               insertionsBefore(_insertionsBefore),
230               insertionsAfter(_insertionsAfter)
231         {}
232 
233         TIntermBlock *parent;
234         TIntermSequence::size_type position;
235         TIntermSequence insertionsBefore;
236         TIntermSequence insertionsAfter;
237     };
238 
239     static bool CompareInsertion(const NodeInsertMultipleEntry &a,
240                                  const NodeInsertMultipleEntry &b);
241 
242     // To replace a single node with another on the parent node
243     struct NodeUpdateEntry
244     {
NodeUpdateEntryNodeUpdateEntry245         NodeUpdateEntry(TIntermNode *_parent,
246                         TIntermNode *_original,
247                         TIntermNode *_replacement,
248                         bool _originalBecomesChildOfReplacement)
249             : parent(_parent),
250               original(_original),
251               replacement(_replacement),
252               originalBecomesChildOfReplacement(_originalBecomesChildOfReplacement)
253         {}
254 
255         TIntermNode *parent;
256         TIntermNode *original;
257         TIntermNode *replacement;
258         bool originalBecomesChildOfReplacement;
259     };
260 
261     struct ParentBlock
262     {
ParentBlockParentBlock263         ParentBlock(TIntermBlock *nodeIn, TIntermSequence::size_type posIn)
264             : node(nodeIn), pos(posIn)
265         {}
266 
267         TIntermBlock *node;
268         TIntermSequence::size_type pos;
269     };
270 
271     std::vector<NodeInsertMultipleEntry> mInsertions;
272     std::vector<NodeUpdateEntry> mReplacements;
273 
274     // All the nodes from root to the current node during traversing.
275     TVector<TIntermNode *> mPath;
276 
277     // All the code blocks from the root to the current node's parent during traversal.
278     std::vector<ParentBlock> mParentBlockStack;
279 };
280 
281 // Traverser parent class that tracks where a node is a destination of a write operation and so is
282 // required to be an l-value.
283 class TLValueTrackingTraverser : public TIntermTraverser
284 {
285   public:
286     TLValueTrackingTraverser(bool preVisit,
287                              bool inVisit,
288                              bool postVisit,
289                              TSymbolTable *symbolTable);
~TLValueTrackingTraverser()290     virtual ~TLValueTrackingTraverser() {}
291 
292     void traverseBinary(TIntermBinary *node) final;
293     void traverseUnary(TIntermUnary *node) final;
294     void traverseAggregate(TIntermAggregate *node) final;
295 
296   protected:
isLValueRequiredHere()297     bool isLValueRequiredHere() const
298     {
299         return mOperatorRequiresLValue || mInFunctionCallOutParameter;
300     }
301 
302   private:
303     // Track whether an l-value is required in the node that is currently being traversed by the
304     // surrounding operator.
305     // Use isLValueRequiredHere to check all conditions which require an l-value.
setOperatorRequiresLValue(bool lValueRequired)306     void setOperatorRequiresLValue(bool lValueRequired)
307     {
308         mOperatorRequiresLValue = lValueRequired;
309     }
operatorRequiresLValue()310     bool operatorRequiresLValue() const { return mOperatorRequiresLValue; }
311 
312     // Track whether an l-value is required inside a function call.
313     void setInFunctionCallOutParameter(bool inOutParameter);
314     bool isInFunctionCallOutParameter() const;
315 
316     bool mOperatorRequiresLValue;
317     bool mInFunctionCallOutParameter;
318 };
319 
320 }  // namespace sh
321 
322 #endif  // COMPILER_TRANSLATOR_TREEUTIL_INTERMTRAVERSE_H_
323