1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include "compiler/translator/tree_util/IntermTraverse.h"
8
9 #include "compiler/translator/Compiler.h"
10 #include "compiler/translator/InfoSink.h"
11 #include "compiler/translator/SymbolTable.h"
12 #include "compiler/translator/tree_util/IntermNode_util.h"
13
14 namespace sh
15 {
16
17 // Traverse the intermediate representation tree, and call a node type specific visit function for
18 // each node. Traversal is done recursively through the node member function traverse(). Nodes with
19 // children can have their whole subtree skipped if preVisit is turned on and the type specific
20 // function returns false.
21 template <typename T>
traverse(T * node)22 void TIntermTraverser::traverse(T *node)
23 {
24 ScopedNodeInTraversalPath addToPath(this, node);
25 if (!addToPath.isWithinDepthLimit())
26 return;
27
28 bool visit = true;
29
30 // Visit the node before children if pre-visiting.
31 if (preVisit)
32 visit = node->visit(PreVisit, this);
33
34 if (visit)
35 {
36 size_t childIndex = 0;
37 size_t childCount = node->getChildCount();
38
39 while (childIndex < childCount && visit)
40 {
41 node->getChildNode(childIndex)->traverse(this);
42 if (inVisit && childIndex != childCount - 1)
43 {
44 visit = node->visit(InVisit, this);
45 }
46 ++childIndex;
47 }
48
49 if (visit && postVisit)
50 node->visit(PostVisit, this);
51 }
52 }
53
54 // Instantiate template for RewriteAtomicFunctionExpressions, in case this gets inlined thus not
55 // exported from the TU.
56 template void TIntermTraverser::traverse(TIntermNode *);
57
traverse(TIntermTraverser * it)58 void TIntermNode::traverse(TIntermTraverser *it)
59 {
60 it->traverse(this);
61 }
62
traverse(TIntermTraverser * it)63 void TIntermSymbol::traverse(TIntermTraverser *it)
64 {
65 TIntermTraverser::ScopedNodeInTraversalPath addToPath(it, this);
66 it->visitSymbol(this);
67 }
68
traverse(TIntermTraverser * it)69 void TIntermConstantUnion::traverse(TIntermTraverser *it)
70 {
71 TIntermTraverser::ScopedNodeInTraversalPath addToPath(it, this);
72 it->visitConstantUnion(this);
73 }
74
traverse(TIntermTraverser * it)75 void TIntermFunctionPrototype::traverse(TIntermTraverser *it)
76 {
77 TIntermTraverser::ScopedNodeInTraversalPath addToPath(it, this);
78 it->visitFunctionPrototype(this);
79 }
80
traverse(TIntermTraverser * it)81 void TIntermBinary::traverse(TIntermTraverser *it)
82 {
83 it->traverseBinary(this);
84 }
85
traverse(TIntermTraverser * it)86 void TIntermUnary::traverse(TIntermTraverser *it)
87 {
88 it->traverseUnary(this);
89 }
90
traverse(TIntermTraverser * it)91 void TIntermFunctionDefinition::traverse(TIntermTraverser *it)
92 {
93 it->traverseFunctionDefinition(this);
94 }
95
traverse(TIntermTraverser * it)96 void TIntermBlock::traverse(TIntermTraverser *it)
97 {
98 it->traverseBlock(this);
99 }
100
traverse(TIntermTraverser * it)101 void TIntermAggregate::traverse(TIntermTraverser *it)
102 {
103 it->traverseAggregate(this);
104 }
105
traverse(TIntermTraverser * it)106 void TIntermLoop::traverse(TIntermTraverser *it)
107 {
108 it->traverseLoop(this);
109 }
110
traverse(TIntermTraverser * it)111 void TIntermPreprocessorDirective::traverse(TIntermTraverser *it)
112 {
113 it->visitPreprocessorDirective(this);
114 }
115
visit(Visit visit,TIntermTraverser * it)116 bool TIntermSymbol::visit(Visit visit, TIntermTraverser *it)
117 {
118 it->visitSymbol(this);
119 return false;
120 }
121
visit(Visit visit,TIntermTraverser * it)122 bool TIntermConstantUnion::visit(Visit visit, TIntermTraverser *it)
123 {
124 it->visitConstantUnion(this);
125 return false;
126 }
127
visit(Visit visit,TIntermTraverser * it)128 bool TIntermFunctionPrototype::visit(Visit visit, TIntermTraverser *it)
129 {
130 it->visitFunctionPrototype(this);
131 return false;
132 }
133
visit(Visit visit,TIntermTraverser * it)134 bool TIntermFunctionDefinition::visit(Visit visit, TIntermTraverser *it)
135 {
136 return it->visitFunctionDefinition(visit, this);
137 }
138
visit(Visit visit,TIntermTraverser * it)139 bool TIntermUnary::visit(Visit visit, TIntermTraverser *it)
140 {
141 return it->visitUnary(visit, this);
142 }
143
visit(Visit visit,TIntermTraverser * it)144 bool TIntermSwizzle::visit(Visit visit, TIntermTraverser *it)
145 {
146 return it->visitSwizzle(visit, this);
147 }
148
visit(Visit visit,TIntermTraverser * it)149 bool TIntermBinary::visit(Visit visit, TIntermTraverser *it)
150 {
151 return it->visitBinary(visit, this);
152 }
153
visit(Visit visit,TIntermTraverser * it)154 bool TIntermTernary::visit(Visit visit, TIntermTraverser *it)
155 {
156 return it->visitTernary(visit, this);
157 }
158
visit(Visit visit,TIntermTraverser * it)159 bool TIntermAggregate::visit(Visit visit, TIntermTraverser *it)
160 {
161 return it->visitAggregate(visit, this);
162 }
163
visit(Visit visit,TIntermTraverser * it)164 bool TIntermDeclaration::visit(Visit visit, TIntermTraverser *it)
165 {
166 return it->visitDeclaration(visit, this);
167 }
168
visit(Visit visit,TIntermTraverser * it)169 bool TIntermGlobalQualifierDeclaration::visit(Visit visit, TIntermTraverser *it)
170 {
171 return it->visitGlobalQualifierDeclaration(visit, this);
172 }
173
visit(Visit visit,TIntermTraverser * it)174 bool TIntermBlock::visit(Visit visit, TIntermTraverser *it)
175 {
176 return it->visitBlock(visit, this);
177 }
178
visit(Visit visit,TIntermTraverser * it)179 bool TIntermIfElse::visit(Visit visit, TIntermTraverser *it)
180 {
181 return it->visitIfElse(visit, this);
182 }
183
visit(Visit visit,TIntermTraverser * it)184 bool TIntermLoop::visit(Visit visit, TIntermTraverser *it)
185 {
186 return it->visitLoop(visit, this);
187 }
188
visit(Visit visit,TIntermTraverser * it)189 bool TIntermBranch::visit(Visit visit, TIntermTraverser *it)
190 {
191 return it->visitBranch(visit, this);
192 }
193
visit(Visit visit,TIntermTraverser * it)194 bool TIntermSwitch::visit(Visit visit, TIntermTraverser *it)
195 {
196 return it->visitSwitch(visit, this);
197 }
198
visit(Visit visit,TIntermTraverser * it)199 bool TIntermCase::visit(Visit visit, TIntermTraverser *it)
200 {
201 return it->visitCase(visit, this);
202 }
203
visit(Visit visit,TIntermTraverser * it)204 bool TIntermPreprocessorDirective::visit(Visit visit, TIntermTraverser *it)
205 {
206 it->visitPreprocessorDirective(this);
207 return false;
208 }
209
TIntermTraverser(bool preVisit,bool inVisit,bool postVisit,TSymbolTable * symbolTable)210 TIntermTraverser::TIntermTraverser(bool preVisit,
211 bool inVisit,
212 bool postVisit,
213 TSymbolTable *symbolTable)
214 : preVisit(preVisit),
215 inVisit(inVisit),
216 postVisit(postVisit),
217 mMaxDepth(0),
218 mMaxAllowedDepth(std::numeric_limits<int>::max()),
219 mInGlobalScope(true),
220 mSymbolTable(symbolTable)
221 {
222 // Only enabling inVisit is not supported.
223 ASSERT(!(inVisit && !preVisit && !postVisit));
224 }
225
~TIntermTraverser()226 TIntermTraverser::~TIntermTraverser() {}
227
setMaxAllowedDepth(int depth)228 void TIntermTraverser::setMaxAllowedDepth(int depth)
229 {
230 mMaxAllowedDepth = depth;
231 }
232
getParentBlock() const233 const TIntermBlock *TIntermTraverser::getParentBlock() const
234 {
235 if (!mParentBlockStack.empty())
236 {
237 return mParentBlockStack.back().node;
238 }
239 return nullptr;
240 }
241
pushParentBlock(TIntermBlock * node)242 void TIntermTraverser::pushParentBlock(TIntermBlock *node)
243 {
244 mParentBlockStack.push_back(ParentBlock(node, 0));
245 }
246
incrementParentBlockPos()247 void TIntermTraverser::incrementParentBlockPos()
248 {
249 ++mParentBlockStack.back().pos;
250 }
251
popParentBlock()252 void TIntermTraverser::popParentBlock()
253 {
254 ASSERT(!mParentBlockStack.empty());
255 mParentBlockStack.pop_back();
256 }
257
insertStatementsInParentBlock(const TIntermSequence & insertions)258 void TIntermTraverser::insertStatementsInParentBlock(const TIntermSequence &insertions)
259 {
260 TIntermSequence emptyInsertionsAfter;
261 insertStatementsInParentBlock(insertions, emptyInsertionsAfter);
262 }
263
insertStatementsInParentBlock(const TIntermSequence & insertionsBefore,const TIntermSequence & insertionsAfter)264 void TIntermTraverser::insertStatementsInParentBlock(const TIntermSequence &insertionsBefore,
265 const TIntermSequence &insertionsAfter)
266 {
267 ASSERT(!mParentBlockStack.empty());
268 ParentBlock &parentBlock = mParentBlockStack.back();
269 if (mPath.back() == parentBlock.node)
270 {
271 ASSERT(mParentBlockStack.size() >= 2u);
272 // The current node is a block node, so the parent block is not the topmost one in the block
273 // stack, but the one below that.
274 parentBlock = mParentBlockStack.at(mParentBlockStack.size() - 2u);
275 }
276 NodeInsertMultipleEntry insert(parentBlock.node, parentBlock.pos, insertionsBefore,
277 insertionsAfter);
278 mInsertions.push_back(insert);
279 }
280
insertStatementInParentBlock(TIntermNode * statement)281 void TIntermTraverser::insertStatementInParentBlock(TIntermNode *statement)
282 {
283 TIntermSequence insertions;
284 insertions.push_back(statement);
285 insertStatementsInParentBlock(insertions);
286 }
287
insertStatementsInBlockAtPosition(TIntermBlock * parent,size_t position,const TIntermSequence & insertionsBefore,const TIntermSequence & insertionsAfter)288 void TIntermTraverser::insertStatementsInBlockAtPosition(TIntermBlock *parent,
289 size_t position,
290 const TIntermSequence &insertionsBefore,
291 const TIntermSequence &insertionsAfter)
292 {
293 ASSERT(parent);
294 ASSERT(position >= 0);
295 ASSERT(position < parent->getChildCount());
296
297 mInsertions.emplace_back(parent, position, insertionsBefore, insertionsAfter);
298 }
299
setInFunctionCallOutParameter(bool inOutParameter)300 void TLValueTrackingTraverser::setInFunctionCallOutParameter(bool inOutParameter)
301 {
302 mInFunctionCallOutParameter = inOutParameter;
303 }
304
isInFunctionCallOutParameter() const305 bool TLValueTrackingTraverser::isInFunctionCallOutParameter() const
306 {
307 return mInFunctionCallOutParameter;
308 }
309
traverseBinary(TIntermBinary * node)310 void TIntermTraverser::traverseBinary(TIntermBinary *node)
311 {
312 traverse(node);
313 }
314
traverseBinary(TIntermBinary * node)315 void TLValueTrackingTraverser::traverseBinary(TIntermBinary *node)
316 {
317 ScopedNodeInTraversalPath addToPath(this, node);
318 if (!addToPath.isWithinDepthLimit())
319 return;
320
321 bool visit = true;
322
323 // visit the node before children if pre-visiting.
324 if (preVisit)
325 visit = node->visit(PreVisit, this);
326
327 // Visit the children, in the right order.
328 if (visit)
329 {
330 if (node->isAssignment())
331 {
332 ASSERT(!isLValueRequiredHere());
333 setOperatorRequiresLValue(true);
334 }
335
336 node->getLeft()->traverse(this);
337
338 if (node->isAssignment())
339 setOperatorRequiresLValue(false);
340
341 if (inVisit)
342 visit = node->visit(InVisit, this);
343
344 if (visit)
345 {
346 // Some binary operations like indexing can be inside an expression which must be an
347 // l-value.
348 bool parentOperatorRequiresLValue = operatorRequiresLValue();
349 bool parentInFunctionCallOutParameter = isInFunctionCallOutParameter();
350
351 // Index is not required to be an l-value even when the surrounding expression is
352 // required to be an l-value.
353 TOperator op = node->getOp();
354 if (op == EOpIndexDirect || op == EOpIndexDirectInterfaceBlock ||
355 op == EOpIndexDirectStruct || op == EOpIndexIndirect)
356 {
357 setOperatorRequiresLValue(false);
358 setInFunctionCallOutParameter(false);
359 }
360
361 node->getRight()->traverse(this);
362
363 setOperatorRequiresLValue(parentOperatorRequiresLValue);
364 setInFunctionCallOutParameter(parentInFunctionCallOutParameter);
365
366 // Visit the node after the children, if requested and the traversal
367 // hasn't been cancelled yet.
368 if (postVisit)
369 visit = node->visit(PostVisit, this);
370 }
371 }
372 }
373
traverseUnary(TIntermUnary * node)374 void TIntermTraverser::traverseUnary(TIntermUnary *node)
375 {
376 traverse(node);
377 }
378
traverseUnary(TIntermUnary * node)379 void TLValueTrackingTraverser::traverseUnary(TIntermUnary *node)
380 {
381 ScopedNodeInTraversalPath addToPath(this, node);
382 if (!addToPath.isWithinDepthLimit())
383 return;
384
385 bool visit = true;
386
387 if (preVisit)
388 visit = node->visit(PreVisit, this);
389
390 if (visit)
391 {
392 ASSERT(!operatorRequiresLValue());
393 switch (node->getOp())
394 {
395 case EOpPostIncrement:
396 case EOpPostDecrement:
397 case EOpPreIncrement:
398 case EOpPreDecrement:
399 setOperatorRequiresLValue(true);
400 break;
401 default:
402 break;
403 }
404
405 node->getOperand()->traverse(this);
406
407 setOperatorRequiresLValue(false);
408
409 if (postVisit)
410 visit = node->visit(PostVisit, this);
411 }
412 }
413
414 // Traverse a function definition node. This keeps track of global scope.
traverseFunctionDefinition(TIntermFunctionDefinition * node)415 void TIntermTraverser::traverseFunctionDefinition(TIntermFunctionDefinition *node)
416 {
417 ScopedNodeInTraversalPath addToPath(this, node);
418 if (!addToPath.isWithinDepthLimit())
419 return;
420
421 bool visit = true;
422
423 if (preVisit)
424 visit = node->visit(PreVisit, this);
425
426 if (visit)
427 {
428 node->getFunctionPrototype()->traverse(this);
429 if (inVisit)
430 visit = node->visit(InVisit, this);
431 if (visit)
432 {
433 mInGlobalScope = false;
434 node->getBody()->traverse(this);
435 mInGlobalScope = true;
436 if (postVisit)
437 visit = node->visit(PostVisit, this);
438 }
439 }
440 }
441
442 // Traverse a block node. This keeps track of the position of traversed child nodes within the block
443 // so that nodes may be inserted before or after them.
traverseBlock(TIntermBlock * node)444 void TIntermTraverser::traverseBlock(TIntermBlock *node)
445 {
446 ScopedNodeInTraversalPath addToPath(this, node);
447 if (!addToPath.isWithinDepthLimit())
448 return;
449
450 pushParentBlock(node);
451
452 bool visit = true;
453
454 TIntermSequence *sequence = node->getSequence();
455
456 if (preVisit)
457 visit = node->visit(PreVisit, this);
458
459 if (visit)
460 {
461 for (auto *child : *sequence)
462 {
463 if (visit)
464 {
465 child->traverse(this);
466 if (inVisit)
467 {
468 if (child != sequence->back())
469 visit = node->visit(InVisit, this);
470 }
471
472 incrementParentBlockPos();
473 }
474 }
475
476 if (visit && postVisit)
477 visit = node->visit(PostVisit, this);
478 }
479
480 popParentBlock();
481 }
482
traverseAggregate(TIntermAggregate * node)483 void TIntermTraverser::traverseAggregate(TIntermAggregate *node)
484 {
485 traverse(node);
486 }
487
CompareInsertion(const NodeInsertMultipleEntry & a,const NodeInsertMultipleEntry & b)488 bool TIntermTraverser::CompareInsertion(const NodeInsertMultipleEntry &a,
489 const NodeInsertMultipleEntry &b)
490 {
491 if (a.parent != b.parent)
492 {
493 return a.parent < b.parent;
494 }
495 return a.position < b.position;
496 }
497
updateTree(TCompiler * compiler,TIntermNode * node)498 bool TIntermTraverser::updateTree(TCompiler *compiler, TIntermNode *node)
499 {
500 // Sort the insertions so that insertion position is increasing and same position insertions are
501 // not reordered. The insertions are processed in reverse order so that multiple insertions to
502 // the same parent node are handled correctly.
503 std::stable_sort(mInsertions.begin(), mInsertions.end(), CompareInsertion);
504 for (size_t ii = 0; ii < mInsertions.size(); ++ii)
505 {
506 // If two insertions are to the same position, insert them in the order they were specified.
507 // The std::stable_sort call above will automatically guarantee this.
508 const NodeInsertMultipleEntry &insertion = mInsertions[mInsertions.size() - ii - 1];
509 ASSERT(insertion.parent);
510 if (!insertion.insertionsAfter.empty())
511 {
512 bool inserted = insertion.parent->insertChildNodes(insertion.position + 1,
513 insertion.insertionsAfter);
514 ASSERT(inserted);
515 }
516 if (!insertion.insertionsBefore.empty())
517 {
518 bool inserted =
519 insertion.parent->insertChildNodes(insertion.position, insertion.insertionsBefore);
520 ASSERT(inserted);
521 }
522 }
523 for (size_t ii = 0; ii < mReplacements.size(); ++ii)
524 {
525 const NodeUpdateEntry &replacement = mReplacements[ii];
526 ASSERT(replacement.parent);
527 bool replaced =
528 replacement.parent->replaceChildNode(replacement.original, replacement.replacement);
529 ASSERT(replaced);
530
531 if (!replacement.originalBecomesChildOfReplacement)
532 {
533 // In AST traversing, a parent is visited before its children.
534 // After we replace a node, if its immediate child is to
535 // be replaced, we need to make sure we don't update the replaced
536 // node; instead, we update the replacement node.
537 for (size_t jj = ii + 1; jj < mReplacements.size(); ++jj)
538 {
539 NodeUpdateEntry &replacement2 = mReplacements[jj];
540 if (replacement2.parent == replacement.original)
541 replacement2.parent = replacement.replacement;
542 }
543 }
544 }
545 for (size_t ii = 0; ii < mMultiReplacements.size(); ++ii)
546 {
547 const NodeReplaceWithMultipleEntry &replacement = mMultiReplacements[ii];
548 ASSERT(replacement.parent);
549 bool replaced = replacement.parent->replaceChildNodeWithMultiple(replacement.original,
550 replacement.replacements);
551 ASSERT(replaced);
552 }
553
554 clearReplacementQueue();
555
556 return compiler->validateAST(node);
557 }
558
clearReplacementQueue()559 void TIntermTraverser::clearReplacementQueue()
560 {
561 mReplacements.clear();
562 mMultiReplacements.clear();
563 mInsertions.clear();
564 }
565
queueReplacement(TIntermNode * replacement,OriginalNode originalStatus)566 void TIntermTraverser::queueReplacement(TIntermNode *replacement, OriginalNode originalStatus)
567 {
568 queueReplacementWithParent(getParentNode(), mPath.back(), replacement, originalStatus);
569 }
570
queueReplacementWithParent(TIntermNode * parent,TIntermNode * original,TIntermNode * replacement,OriginalNode originalStatus)571 void TIntermTraverser::queueReplacementWithParent(TIntermNode *parent,
572 TIntermNode *original,
573 TIntermNode *replacement,
574 OriginalNode originalStatus)
575 {
576 bool originalBecomesChild = (originalStatus == OriginalNode::BECOMES_CHILD);
577 mReplacements.push_back(NodeUpdateEntry(parent, original, replacement, originalBecomesChild));
578 }
579
TLValueTrackingTraverser(bool preVisitIn,bool inVisitIn,bool postVisitIn,TSymbolTable * symbolTable)580 TLValueTrackingTraverser::TLValueTrackingTraverser(bool preVisitIn,
581 bool inVisitIn,
582 bool postVisitIn,
583 TSymbolTable *symbolTable)
584 : TIntermTraverser(preVisitIn, inVisitIn, postVisitIn, symbolTable),
585 mOperatorRequiresLValue(false),
586 mInFunctionCallOutParameter(false)
587 {
588 ASSERT(symbolTable);
589 }
590
traverseAggregate(TIntermAggregate * node)591 void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node)
592 {
593 ScopedNodeInTraversalPath addToPath(this, node);
594 if (!addToPath.isWithinDepthLimit())
595 return;
596
597 bool visit = true;
598
599 TIntermSequence *sequence = node->getSequence();
600
601 if (preVisit)
602 visit = node->visit(PreVisit, this);
603
604 if (visit)
605 {
606 size_t paramIndex = 0u;
607 for (auto *child : *sequence)
608 {
609 if (visit)
610 {
611 if (node->getFunction())
612 {
613 // Both built-ins and user defined functions should have the function symbol
614 // set.
615 ASSERT(paramIndex < node->getFunction()->getParamCount());
616 TQualifier qualifier =
617 node->getFunction()->getParam(paramIndex)->getType().getQualifier();
618 setInFunctionCallOutParameter(qualifier == EvqOut || qualifier == EvqInOut);
619 ++paramIndex;
620 }
621 else
622 {
623 ASSERT(node->isConstructor());
624 }
625 child->traverse(this);
626 if (inVisit)
627 {
628 if (child != sequence->back())
629 visit = node->visit(InVisit, this);
630 }
631 }
632 }
633 setInFunctionCallOutParameter(false);
634
635 if (visit && postVisit)
636 visit = node->visit(PostVisit, this);
637 }
638 }
639
traverseLoop(TIntermLoop * node)640 void TIntermTraverser::traverseLoop(TIntermLoop *node)
641 {
642 traverse(node);
643 }
644 } // namespace sh
645