• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2017 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // IntermTraverse.h : base classes for AST traversers that walk the AST and
7 //   also have the ability to transform it by replacing nodes.
8 
9 #ifndef COMPILER_TRANSLATOR_TREEUTIL_INTERMTRAVERSE_H_
10 #define COMPILER_TRANSLATOR_TREEUTIL_INTERMTRAVERSE_H_
11 
12 #include "compiler/translator/IntermNode.h"
13 #include "compiler/translator/tree_util/Visit.h"
14 
15 namespace sh
16 {
17 
18 class TCompiler;
19 class TSymbolTable;
20 class TSymbolUniqueId;
21 
22 // For traversing the tree.  User should derive from this class overriding the visit functions,
23 // and then pass an object of the subclass to a traverse method of a node.
24 //
25 // The traverse*() functions may also be overridden to do other bookkeeping on the tree to provide
26 // contextual information to the visit functions, such as whether the node is the target of an
27 // assignment. This is complex to maintain and so should only be done in special cases.
28 //
29 // When using this, just fill in the methods for nodes you want visited.
30 // Return false from a pre-visit to skip visiting that node's subtree.
31 //
32 // See also how to write AST transformations documentation:
33 //   https://github.com/google/angle/blob/master/doc/WritingShaderASTTransformations.md
34 class TIntermTraverser : angle::NonCopyable
35 {
36   public:
37     POOL_ALLOCATOR_NEW_DELETE
38     TIntermTraverser(bool preVisitIn,
39                      bool inVisitIn,
40                      bool postVisitIn,
41                      TSymbolTable *symbolTable = nullptr);
42     virtual ~TIntermTraverser();
43 
visitSymbol(TIntermSymbol * node)44     virtual void visitSymbol(TIntermSymbol *node) {}
visitConstantUnion(TIntermConstantUnion * node)45     virtual void visitConstantUnion(TIntermConstantUnion *node) {}
visitSwizzle(Visit visit,TIntermSwizzle * node)46     virtual bool visitSwizzle(Visit visit, TIntermSwizzle *node) { return true; }
visitBinary(Visit visit,TIntermBinary * node)47     virtual bool visitBinary(Visit visit, TIntermBinary *node) { return true; }
visitUnary(Visit visit,TIntermUnary * node)48     virtual bool visitUnary(Visit visit, TIntermUnary *node) { return true; }
visitTernary(Visit visit,TIntermTernary * node)49     virtual bool visitTernary(Visit visit, TIntermTernary *node) { return true; }
visitIfElse(Visit visit,TIntermIfElse * node)50     virtual bool visitIfElse(Visit visit, TIntermIfElse *node) { return true; }
visitSwitch(Visit visit,TIntermSwitch * node)51     virtual bool visitSwitch(Visit visit, TIntermSwitch *node) { return true; }
visitCase(Visit visit,TIntermCase * node)52     virtual bool visitCase(Visit visit, TIntermCase *node) { return true; }
visitFunctionPrototype(TIntermFunctionPrototype * node)53     virtual void visitFunctionPrototype(TIntermFunctionPrototype *node) {}
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)54     virtual bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
55     {
56         return true;
57     }
visitAggregate(Visit visit,TIntermAggregate * node)58     virtual bool visitAggregate(Visit visit, TIntermAggregate *node) { return true; }
visitBlock(Visit visit,TIntermBlock * node)59     virtual bool visitBlock(Visit visit, TIntermBlock *node) { return true; }
visitGlobalQualifierDeclaration(Visit visit,TIntermGlobalQualifierDeclaration * node)60     virtual bool visitGlobalQualifierDeclaration(Visit visit,
61                                                  TIntermGlobalQualifierDeclaration *node)
62     {
63         return true;
64     }
visitDeclaration(Visit visit,TIntermDeclaration * node)65     virtual bool visitDeclaration(Visit visit, TIntermDeclaration *node) { return true; }
visitLoop(Visit visit,TIntermLoop * node)66     virtual bool visitLoop(Visit visit, TIntermLoop *node) { return true; }
visitBranch(Visit visit,TIntermBranch * node)67     virtual bool visitBranch(Visit visit, TIntermBranch *node) { return true; }
visitPreprocessorDirective(TIntermPreprocessorDirective * node)68     virtual void visitPreprocessorDirective(TIntermPreprocessorDirective *node) {}
69 
70     // The traverse functions contain logic for iterating over the children of the node
71     // and calling the visit functions in the appropriate places. They also track some
72     // context that may be used by the visit functions.
73 
74     // The generic traverse() function is used for nodes that don't need special handling.
75     // It's templated in order to avoid virtual function calls, this gains around 2% compiler
76     // performance.
77     template <typename T>
78     void traverse(T *node);
79 
80     // Specialized traverse functions are implemented for node types where traversal logic may need
81     // to be overridden or where some special bookkeeping needs to be done.
82     virtual void traverseBinary(TIntermBinary *node);
83     virtual void traverseUnary(TIntermUnary *node);
84     virtual void traverseFunctionDefinition(TIntermFunctionDefinition *node);
85     virtual void traverseAggregate(TIntermAggregate *node);
86     virtual void traverseBlock(TIntermBlock *node);
87     virtual void traverseLoop(TIntermLoop *node);
88 
getMaxDepth()89     int getMaxDepth() const { return mMaxDepth; }
90 
91     // If traversers need to replace nodes, they can add the replacements in
92     // mReplacements/mMultiReplacements during traversal and the user of the traverser should call
93     // this function after traversal to perform them.
94     //
95     // Compiler is used to validate the tree.  Node is the same given to traverse().  Returns false
96     // if the tree is invalid after update.
97     [[nodiscard]] bool updateTree(TCompiler *compiler, TIntermNode *node);
98 
99   protected:
100     void setMaxAllowedDepth(int depth);
101 
102     // Should only be called from traverse*() functions
incrementDepth(TIntermNode * current)103     bool incrementDepth(TIntermNode *current)
104     {
105         mMaxDepth = std::max(mMaxDepth, static_cast<int>(mPath.size()));
106         mPath.push_back(current);
107         return mMaxDepth < mMaxAllowedDepth;
108     }
109 
110     // Should only be called from traverse*() functions
decrementDepth()111     void decrementDepth() { mPath.pop_back(); }
112 
getCurrentTraversalDepth()113     int getCurrentTraversalDepth() const { return static_cast<int>(mPath.size()) - 1; }
getCurrentBlockDepth()114     int getCurrentBlockDepth() const { return static_cast<int>(mParentBlockStack.size()) - 1; }
115 
116     // RAII helper for incrementDepth/decrementDepth
117     class [[nodiscard]] ScopedNodeInTraversalPath
118     {
119       public:
ScopedNodeInTraversalPath(TIntermTraverser * traverser,TIntermNode * current)120         ScopedNodeInTraversalPath(TIntermTraverser *traverser, TIntermNode *current)
121             : mTraverser(traverser)
122         {
123             mWithinDepthLimit = mTraverser->incrementDepth(current);
124         }
~ScopedNodeInTraversalPath()125         ~ScopedNodeInTraversalPath() { mTraverser->decrementDepth(); }
126 
isWithinDepthLimit()127         bool isWithinDepthLimit() { return mWithinDepthLimit; }
128 
129       private:
130         TIntermTraverser *mTraverser;
131         bool mWithinDepthLimit;
132     };
133     // Optimized traversal functions for leaf nodes directly access ScopedNodeInTraversalPath.
134     friend void TIntermSymbol::traverse(TIntermTraverser *);
135     friend void TIntermConstantUnion::traverse(TIntermTraverser *);
136     friend void TIntermFunctionPrototype::traverse(TIntermTraverser *);
137 
getParentNode()138     TIntermNode *getParentNode() const
139     {
140         return mPath.size() <= 1 ? nullptr : mPath[mPath.size() - 2u];
141     }
142 
143     // Return the nth ancestor of the node being traversed. getAncestorNode(0) == getParentNode()
getAncestorNode(unsigned int n)144     TIntermNode *getAncestorNode(unsigned int n) const
145     {
146         if (mPath.size() > n + 1u)
147         {
148             return mPath[mPath.size() - n - 2u];
149         }
150         return nullptr;
151     }
152 
153     // Returns what child index is currently being visited.  For example when visiting the children
154     // of an aggregate, it can be used to find out which argument of the parent (aggregate) node
155     // they correspond to.  Only valid in the PreVisit call of the child.
getParentChildIndex(Visit visit)156     size_t getParentChildIndex(Visit visit) const
157     {
158         ASSERT(visit == PreVisit);
159         return mCurrentChildIndex;
160     }
161     // Returns what child index has just been processed.  Only valid in the InVisit and PostVisit
162     // calls of the parent node.
getLastTraversedChildIndex(Visit visit)163     size_t getLastTraversedChildIndex(Visit visit) const
164     {
165         ASSERT(visit != PreVisit);
166         return mCurrentChildIndex;
167     }
168 
169     const TIntermBlock *getParentBlock() const;
170 
getRootNode()171     TIntermNode *getRootNode() const
172     {
173         ASSERT(!mPath.empty());
174         return mPath.front();
175     }
176 
177     void pushParentBlock(TIntermBlock *node);
178     void incrementParentBlockPos();
179     void popParentBlock();
180 
181     // To replace a single node with multiple nodes in the parent aggregate. May be used with blocks
182     // but also with other nodes like declarations.
183     struct NodeReplaceWithMultipleEntry
184     {
NodeReplaceWithMultipleEntryNodeReplaceWithMultipleEntry185         NodeReplaceWithMultipleEntry(TIntermAggregateBase *parentIn,
186                                      TIntermNode *originalIn,
187                                      TIntermSequence &&replacementsIn)
188             : parent(parentIn), original(originalIn), replacements(std::move(replacementsIn))
189         {}
190 
191         TIntermAggregateBase *parent;
192         TIntermNode *original;
193         TIntermSequence replacements;
194     };
195 
196     // Helper to insert statements in the parent block of the node currently being traversed.
197     // The statements will be inserted before the node being traversed once updateTree is called.
198     // Should only be called during PreVisit or PostVisit if called from block nodes.
199     // Note that two insertions to the same position in the same block are not supported.
200     void insertStatementsInParentBlock(const TIntermSequence &insertions);
201 
202     // Same as above, but supports simultaneous insertion of statements before and after the node
203     // currently being traversed.
204     void insertStatementsInParentBlock(const TIntermSequence &insertionsBefore,
205                                        const TIntermSequence &insertionsAfter);
206 
207     // Helper to insert a single statement.
208     void insertStatementInParentBlock(TIntermNode *statement);
209 
210     // Explicitly specify where to insert statements. The statements are inserted before and after
211     // the specified position. The statements will be inserted once updateTree is called. Note that
212     // two insertions to the same position in the same block are not supported.
213     void insertStatementsInBlockAtPosition(TIntermBlock *parent,
214                                            size_t position,
215                                            const TIntermSequence &insertionsBefore,
216                                            const TIntermSequence &insertionsAfter);
217 
218     enum class OriginalNode
219     {
220         BECOMES_CHILD,
221         IS_DROPPED
222     };
223 
224     void clearReplacementQueue();
225 
226     // Replace the node currently being visited with replacement.
227     void queueReplacement(TIntermNode *replacement, OriginalNode originalStatus);
228     // Explicitly specify a node to replace with replacement.
229     void queueReplacementWithParent(TIntermNode *parent,
230                                     TIntermNode *original,
231                                     TIntermNode *replacement,
232                                     OriginalNode originalStatus);
233     // Walk the ancestors and replace the access chain that leads to this symbol.  This fixes up the
234     // types of the intermediate nodes, so it should be used when the type of the symbol changes.
235     // The AST transformation must still visit the (indirect) index nodes to transform the
236     // expression inside those nodes.  Note that due to the way these replacements work, the AST
237     // transformation should not attempt to replace the actual index node itself, but only a subnode
238     // of that.
239     //
240     //                    Node 1                                                Node 6
241     //                 EOpIndexDirect                                        EOpIndexDirect
242     //                /          \                                              /       \
243     //           Node 2        Node 3                                   Node 7        Node 3
244     //       EOpIndexIndirect     N     --> replaced with -->       EOpIndexIndirect     N
245     //         /        \                                            /        \
246     //      Node 4    Node 5                                      Node 8      Node 5
247     //      symbol   expression                                replacement   expression
248     //        ^                                                                 ^
249     //        |                                                                 |
250     //    This symbol is being replaced,                        This node is directly placed in the
251     //    and the replacement is given                          new access chain, and its parent is
252     //    to this function.                                     is changed.  This is why a
253     //                                                          replacment attempt for this node
254     //                                                          itself will not work.
255     //
256     void queueAccessChainReplacement(TIntermTyped *replacement);
257 
258     const bool preVisit;
259     const bool inVisit;
260     const bool postVisit;
261 
262     int mMaxDepth;
263     int mMaxAllowedDepth;
264 
265     bool mInGlobalScope;
266 
267     // During traversing, save all the changes that need to happen into
268     // mReplacements/mMultiReplacements, then do them by calling updateTree().
269     // Multi replacements are processed after single replacements.
270     std::vector<NodeReplaceWithMultipleEntry> mMultiReplacements;
271 
272     TSymbolTable *mSymbolTable;
273 
274   private:
275     // To insert multiple nodes into the parent block.
276     struct NodeInsertMultipleEntry
277     {
NodeInsertMultipleEntryNodeInsertMultipleEntry278         NodeInsertMultipleEntry(TIntermBlock *_parent,
279                                 TIntermSequence::size_type _position,
280                                 TIntermSequence _insertionsBefore,
281                                 TIntermSequence _insertionsAfter)
282             : parent(_parent),
283               position(_position),
284               insertionsBefore(_insertionsBefore),
285               insertionsAfter(_insertionsAfter)
286         {}
287 
288         TIntermBlock *parent;
289         TIntermSequence::size_type position;
290         TIntermSequence insertionsBefore;
291         TIntermSequence insertionsAfter;
292     };
293 
294     static bool CompareInsertion(const NodeInsertMultipleEntry &a,
295                                  const NodeInsertMultipleEntry &b);
296 
297     // To replace a single node with another on the parent node
298     struct NodeUpdateEntry
299     {
NodeUpdateEntryNodeUpdateEntry300         NodeUpdateEntry(TIntermNode *_parent,
301                         TIntermNode *_original,
302                         TIntermNode *_replacement,
303                         bool _originalBecomesChildOfReplacement)
304             : parent(_parent),
305               original(_original),
306               replacement(_replacement),
307               originalBecomesChildOfReplacement(_originalBecomesChildOfReplacement)
308         {}
309 
310         TIntermNode *parent;
311         TIntermNode *original;
312         TIntermNode *replacement;
313         bool originalBecomesChildOfReplacement;
314     };
315 
316     struct ParentBlock
317     {
ParentBlockParentBlock318         ParentBlock(TIntermBlock *nodeIn, TIntermSequence::size_type posIn)
319             : node(nodeIn), pos(posIn)
320         {}
321 
322         TIntermBlock *node;
323         TIntermSequence::size_type pos;
324     };
325 
326     std::vector<NodeInsertMultipleEntry> mInsertions;
327     std::vector<NodeUpdateEntry> mReplacements;
328 
329     // All the nodes from root to the current node during traversing.
330     TVector<TIntermNode *> mPath;
331     // The current child of parent being traversed.
332     size_t mCurrentChildIndex;
333 
334     // All the code blocks from the root to the current node's parent during traversal.
335     std::vector<ParentBlock> mParentBlockStack;
336 };
337 
338 // Traverser parent class that tracks where a node is a destination of a write operation and so is
339 // required to be an l-value.
340 class TLValueTrackingTraverser : public TIntermTraverser
341 {
342   public:
343     TLValueTrackingTraverser(bool preVisit,
344                              bool inVisit,
345                              bool postVisit,
346                              TSymbolTable *symbolTable);
~TLValueTrackingTraverser()347     ~TLValueTrackingTraverser() override {}
348 
349     void traverseBinary(TIntermBinary *node) final;
350     void traverseUnary(TIntermUnary *node) final;
351     void traverseAggregate(TIntermAggregate *node) final;
352 
353   protected:
isLValueRequiredHere()354     bool isLValueRequiredHere() const
355     {
356         return mOperatorRequiresLValue || mInFunctionCallOutParameter;
357     }
358 
359   private:
360     // Track whether an l-value is required in the node that is currently being traversed by the
361     // surrounding operator.
362     // Use isLValueRequiredHere to check all conditions which require an l-value.
setOperatorRequiresLValue(bool lValueRequired)363     void setOperatorRequiresLValue(bool lValueRequired)
364     {
365         mOperatorRequiresLValue = lValueRequired;
366     }
operatorRequiresLValue()367     bool operatorRequiresLValue() const { return mOperatorRequiresLValue; }
368 
369     // Track whether an l-value is required inside a function call.
370     void setInFunctionCallOutParameter(bool inOutParameter);
371     bool isInFunctionCallOutParameter() const;
372 
373     bool mOperatorRequiresLValue;
374     bool mInFunctionCallOutParameter;
375 };
376 
377 }  // namespace sh
378 
379 #endif  // COMPILER_TRANSLATOR_TREEUTIL_INTERMTRAVERSE_H_
380