• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <algorithm>
8 
9 #include "compiler/translator/Compiler.h"
10 #include "compiler/translator/SymbolTable.h"
11 #include "compiler/translator/tree_util/AsNode.h"
12 #include "compiler/translator/tree_util/IntermRebuild.h"
13 
// GUARD2(cond, failVal): leave the enclosing function with `failVal` when `cond` is false.
// The do/while(false) wrapper makes the expansion behave as one statement that takes a ';'.
#define GUARD2(cond, failVal)                      \
    do                                             \
    {                                              \
        if (!(cond))                               \
            return failVal;                        \
    } while (false)

// GUARD(cond): same as GUARD2, for functions whose failure value is nullptr.
#define GUARD(cond) GUARD2(cond, nullptr)
24 
25 namespace sh
26 {
27 
// Returns true when every bit set in `needle` is also set in `haystack`.
template <typename T, typename U>
ANGLE_INLINE bool AllBits(T haystack, U needle)
{
    const auto overlap = haystack & needle;
    return overlap == needle;
}
33 
// Returns true when `haystack` and `needle` share at least one set bit.
template <typename T, typename U>
ANGLE_INLINE bool AnyBits(T haystack, U needle)
{
    const auto overlap = haystack & needle;
    return overlap != 0;
}
39 
40 ////////////////////////////////////////////////////////////////////////////////
41 
BaseResult(BaseResult & other)42 TIntermRebuild::BaseResult::BaseResult(BaseResult &other)
43     : mAction(other.mAction),
44       mVisit(other.mVisit),
45       mSingle(other.mSingle),
46       mMulti(std::move(other.mMulti))
47 {}
48 
BaseResult(TIntermNode & node,VisitBits visit)49 TIntermRebuild::BaseResult::BaseResult(TIntermNode &node, VisitBits visit)
50     : mAction(Action::ReplaceSingle), mVisit(visit), mSingle(&node)
51 {}
52 
BaseResult(TIntermNode * node,VisitBits visit)53 TIntermRebuild::BaseResult::BaseResult(TIntermNode *node, VisitBits visit)
54     : mAction(node ? Action::ReplaceSingle : Action::Drop),
55       mVisit(node ? visit : VisitBits::Neither),
56       mSingle(node)
57 {}
58 
BaseResult(std::nullptr_t)59 TIntermRebuild::BaseResult::BaseResult(std::nullptr_t)
60     : mAction(Action::Drop), mVisit(VisitBits::Neither), mSingle(nullptr)
61 {}
62 
BaseResult(Fail)63 TIntermRebuild::BaseResult::BaseResult(Fail)
64     : mAction(Action::Fail), mVisit(VisitBits::Neither), mSingle(nullptr)
65 {}
66 
BaseResult(std::vector<TIntermNode * > && nodes)67 TIntermRebuild::BaseResult::BaseResult(std::vector<TIntermNode *> &&nodes)
68     : mAction(Action::ReplaceMulti),
69       mVisit(VisitBits::Neither),
70       mSingle(nullptr),
71       mMulti(std::move(nodes))
72 {}
73 
moveAssignImpl(BaseResult & other)74 void TIntermRebuild::BaseResult::moveAssignImpl(BaseResult &other)
75 {
76     mAction = other.mAction;
77     mVisit  = other.mVisit;
78     mSingle = other.mSingle;
79     mMulti  = std::move(other.mMulti);
80 }
81 
Multi(std::vector<TIntermNode * > && nodes)82 TIntermRebuild::BaseResult TIntermRebuild::BaseResult::Multi(std::vector<TIntermNode *> &&nodes)
83 {
84     auto it = std::remove(nodes.begin(), nodes.end(), nullptr);
85     nodes.erase(it, nodes.end());
86     return std::move(nodes);
87 }
88 
isFail() const89 bool TIntermRebuild::BaseResult::isFail() const
90 {
91     return mAction == Action::Fail;
92 }
93 
isDrop() const94 bool TIntermRebuild::BaseResult::isDrop() const
95 {
96     return mAction == Action::Drop;
97 }
98 
single() const99 TIntermNode *TIntermRebuild::BaseResult::single() const
100 {
101     return mSingle;
102 }
103 
multi() const104 const std::vector<TIntermNode *> *TIntermRebuild::BaseResult::multi() const
105 {
106     if (mAction == Action::ReplaceMulti)
107     {
108         return &mMulti;
109     }
110     return nullptr;
111 }
112 
113 ////////////////////////////////////////////////////////////////////////////////
114 
115 using PreResult = TIntermRebuild::PreResult;
116 
PreResult(TIntermNode & node,VisitBits visit)117 PreResult::PreResult(TIntermNode &node, VisitBits visit) : BaseResult(node, visit) {}
PreResult(TIntermNode * node,VisitBits visit)118 PreResult::PreResult(TIntermNode *node, VisitBits visit) : BaseResult(node, visit) {}
PreResult(std::nullptr_t)119 PreResult::PreResult(std::nullptr_t) : BaseResult(nullptr) {}
PreResult(Fail)120 PreResult::PreResult(Fail) : BaseResult(Fail()) {}
121 
PreResult(BaseResult && other)122 PreResult::PreResult(BaseResult &&other) : BaseResult(other) {}
PreResult(PreResult && other)123 PreResult::PreResult(PreResult &&other) : BaseResult(other) {}
124 
operator =(PreResult && other)125 void PreResult::operator=(PreResult &&other)
126 {
127     moveAssignImpl(other);
128 }
129 
130 ////////////////////////////////////////////////////////////////////////////////
131 
132 using PostResult = TIntermRebuild::PostResult;
133 
PostResult(TIntermNode & node)134 PostResult::PostResult(TIntermNode &node) : BaseResult(node, VisitBits::Neither) {}
PostResult(TIntermNode * node)135 PostResult::PostResult(TIntermNode *node) : BaseResult(node, VisitBits::Neither) {}
PostResult(std::nullptr_t)136 PostResult::PostResult(std::nullptr_t) : BaseResult(nullptr) {}
PostResult(Fail)137 PostResult::PostResult(Fail) : BaseResult(Fail()) {}
138 
PostResult(PostResult && other)139 PostResult::PostResult(PostResult &&other) : BaseResult(other) {}
PostResult(BaseResult && other)140 PostResult::PostResult(BaseResult &&other) : BaseResult(other) {}
141 
operator =(PostResult && other)142 void PostResult::operator=(PostResult &&other)
143 {
144     moveAssignImpl(other);
145 }
146 
147 ////////////////////////////////////////////////////////////////////////////////
148 
TIntermRebuild(TCompiler & compiler,bool preVisit,bool postVisit)149 TIntermRebuild::TIntermRebuild(TCompiler &compiler, bool preVisit, bool postVisit)
150     : mCompiler(compiler),
151       mSymbolTable(compiler.getSymbolTable()),
152       mPreVisit(preVisit),
153       mPostVisit(postVisit)
154 {
155     ASSERT(preVisit || postVisit);
156 }
157 
~TIntermRebuild()158 TIntermRebuild::~TIntermRebuild()
159 {
160     ASSERT(!mNodeStack.value);
161     ASSERT(!mNodeStack.tail);
162 }
163 
getParentFunction() const164 const TFunction *TIntermRebuild::getParentFunction() const
165 {
166     return mParentFunc;
167 }
168 
getParentNode(size_t offset) const169 TIntermNode *TIntermRebuild::getParentNode(size_t offset) const
170 {
171     ASSERT(mNodeStack.tail);
172     auto parent = *mNodeStack.tail;
173     while (offset > 0)
174     {
175         --offset;
176         ASSERT(parent.tail);
177         parent = *parent.tail;
178     }
179     return parent.value;
180 }
181 
rebuildRoot(TIntermBlock & root)182 bool TIntermRebuild::rebuildRoot(TIntermBlock &root)
183 {
184     if (!rebuildInPlace(root))
185     {
186         return false;
187     }
188     return mCompiler.validateAST(&root);
189 }
190 
rebuildInPlace(TIntermAggregate & node)191 bool TIntermRebuild::rebuildInPlace(TIntermAggregate &node)
192 {
193     return rebuildInPlaceImpl(node);
194 }
195 
rebuildInPlace(TIntermBlock & node)196 bool TIntermRebuild::rebuildInPlace(TIntermBlock &node)
197 {
198     return rebuildInPlaceImpl(node);
199 }
200 
rebuildInPlace(TIntermDeclaration & node)201 bool TIntermRebuild::rebuildInPlace(TIntermDeclaration &node)
202 {
203     return rebuildInPlaceImpl(node);
204 }
205 
206 template <typename Node>
rebuildInPlaceImpl(Node & node)207 bool TIntermRebuild::rebuildInPlaceImpl(Node &node)
208 {
209     auto *newNode = traverseAnyAs<Node>(node);
210     if (!newNode)
211     {
212         return false;
213     }
214 
215     if (newNode != &node)
216     {
217         *node.getSequence() = std::move(*newNode->getSequence());
218     }
219 
220     return true;
221 }
222 
rebuild(TIntermNode & node)223 PostResult TIntermRebuild::rebuild(TIntermNode &node)
224 {
225     return traverseAny(node);
226 }
227 
228 ////////////////////////////////////////////////////////////////////////////////
229 
230 template <typename Node>
traverseAnyAs(TIntermNode & node)231 Node *TIntermRebuild::traverseAnyAs(TIntermNode &node)
232 {
233     PostResult result(traverseAny(node));
234     if (result.mAction == Action::Fail || !result.mSingle)
235     {
236         return nullptr;
237     }
238     return asNode<Node>(result.mSingle);
239 }
240 
241 template <typename Node>
traverseAnyAs(TIntermNode & node,Node * & out)242 bool TIntermRebuild::traverseAnyAs(TIntermNode &node, Node *&out)
243 {
244     PostResult result(traverseAny(node));
245     if (result.mAction == Action::Fail || result.mAction == Action::ReplaceMulti)
246     {
247         return false;
248     }
249     if (!result.mSingle)
250     {
251         return true;
252     }
253     out = asNode<Node>(result.mSingle);
254     return out != nullptr;
255 }
256 
traverseAggregateBaseChildren(TIntermAggregateBase & node)257 bool TIntermRebuild::traverseAggregateBaseChildren(TIntermAggregateBase &node)
258 {
259     auto *const children = node.getSequence();
260     ASSERT(children);
261     TIntermSequence newChildren;
262 
263     for (TIntermNode *child : *children)
264     {
265         ASSERT(child);
266         PostResult result(traverseAny(*child));
267 
268         switch (result.mAction)
269         {
270             case Action::ReplaceSingle:
271                 newChildren.push_back(result.mSingle);
272                 break;
273 
274             case Action::ReplaceMulti:
275                 for (TIntermNode *newNode : result.mMulti)
276                 {
277                     if (newNode)
278                     {
279                         newChildren.push_back(newNode);
280                     }
281                 }
282                 break;
283 
284             case Action::Drop:
285                 break;
286 
287             case Action::Fail:
288                 return false;
289 
290             default:
291                 ASSERT(false);
292                 return false;
293         }
294     }
295 
296     *children = std::move(newChildren);
297 
298     return true;
299 }
300 
301 ////////////////////////////////////////////////////////////////////////////////
302 
303 struct TIntermRebuild::NodeStackGuard
304 {
305     ConsList<TIntermNode *> oldNodeStack;
306     ConsList<TIntermNode *> &nodeStack;
NodeStackGuardsh::TIntermRebuild::NodeStackGuard307     NodeStackGuard(ConsList<TIntermNode *> &nodeStack)
308         : oldNodeStack(nodeStack), nodeStack(nodeStack)
309     {}
~NodeStackGuardsh::TIntermRebuild::NodeStackGuard310     ~NodeStackGuard() { nodeStack = oldNodeStack; }
311 };
312 
traverseAny(TIntermNode & originalNode)313 PostResult TIntermRebuild::traverseAny(TIntermNode &originalNode)
314 {
315     PreResult preResult = traversePre(originalNode);
316     if (!preResult.mSingle)
317     {
318         ASSERT(preResult.mVisit == VisitBits::Neither);
319         return std::move(preResult);
320     }
321 
322     TIntermNode *currNode       = preResult.mSingle;
323     const VisitBits visit       = preResult.mVisit;
324     const NodeType currNodeType = getNodeType(*currNode);
325 
326     currNode = traverseChildren(currNodeType, originalNode, *currNode, visit);
327     if (!currNode)
328     {
329         return Fail();
330     }
331 
332     return traversePost(currNodeType, originalNode, *currNode, visit);
333 }
334 
traversePre(TIntermNode & originalNode)335 PreResult TIntermRebuild::traversePre(TIntermNode &originalNode)
336 {
337     if (!mPreVisit)
338     {
339         return {originalNode, VisitBits::Both};
340     }
341 
342     NodeStackGuard guard(mNodeStack);
343     mNodeStack = {&originalNode, &guard.oldNodeStack};
344 
345     const NodeType originalNodeType = getNodeType(originalNode);
346 
347     switch (originalNodeType)
348     {
349         case NodeType::Unknown:
350             ASSERT(false);
351             return Fail();
352         case NodeType::Symbol:
353             return visitSymbolPre(*originalNode.getAsSymbolNode());
354         case NodeType::ConstantUnion:
355             return visitConstantUnionPre(*originalNode.getAsConstantUnion());
356         case NodeType::FunctionPrototype:
357             return visitFunctionPrototypePre(*originalNode.getAsFunctionPrototypeNode());
358         case NodeType::PreprocessorDirective:
359             return visitPreprocessorDirectivePre(*originalNode.getAsPreprocessorDirective());
360         case NodeType::Unary:
361             return visitUnaryPre(*originalNode.getAsUnaryNode());
362         case NodeType::Binary:
363             return visitBinaryPre(*originalNode.getAsBinaryNode());
364         case NodeType::Ternary:
365             return visitTernaryPre(*originalNode.getAsTernaryNode());
366         case NodeType::Swizzle:
367             return visitSwizzlePre(*originalNode.getAsSwizzleNode());
368         case NodeType::IfElse:
369             return visitIfElsePre(*originalNode.getAsIfElseNode());
370         case NodeType::Switch:
371             return visitSwitchPre(*originalNode.getAsSwitchNode());
372         case NodeType::Case:
373             return visitCasePre(*originalNode.getAsCaseNode());
374         case NodeType::FunctionDefinition:
375             return visitFunctionDefinitionPre(*originalNode.getAsFunctionDefinition());
376         case NodeType::Aggregate:
377             return visitAggregatePre(*originalNode.getAsAggregate());
378         case NodeType::Block:
379             return visitBlockPre(*originalNode.getAsBlock());
380         case NodeType::GlobalQualifierDeclaration:
381             return visitGlobalQualifierDeclarationPre(
382                 *originalNode.getAsGlobalQualifierDeclarationNode());
383         case NodeType::Declaration:
384             return visitDeclarationPre(*originalNode.getAsDeclarationNode());
385         case NodeType::Loop:
386             return visitLoopPre(*originalNode.getAsLoopNode());
387         case NodeType::Branch:
388             return visitBranchPre(*originalNode.getAsBranchNode());
389         default:
390             ASSERT(false);
391             return Fail();
392     }
393 }
394 
traverseChildren(NodeType currNodeType,const TIntermNode & originalNode,TIntermNode & currNode,VisitBits visit)395 TIntermNode *TIntermRebuild::traverseChildren(NodeType currNodeType,
396                                               const TIntermNode &originalNode,
397                                               TIntermNode &currNode,
398                                               VisitBits visit)
399 {
400     if (!AnyBits(visit, VisitBits::Children))
401     {
402         return &currNode;
403     }
404 
405     if (AnyBits(visit, VisitBits::ChildrenRequiresSame) && &originalNode != &currNode)
406     {
407         return &currNode;
408     }
409 
410     NodeStackGuard guard(mNodeStack);
411     mNodeStack = {&currNode, &guard.oldNodeStack};
412 
413     switch (currNodeType)
414     {
415         case NodeType::Unknown:
416             ASSERT(false);
417             return nullptr;
418         case NodeType::Symbol:
419             return &currNode;
420         case NodeType::ConstantUnion:
421             return &currNode;
422         case NodeType::FunctionPrototype:
423             return &currNode;
424         case NodeType::PreprocessorDirective:
425             return &currNode;
426         case NodeType::Unary:
427             return traverseUnaryChildren(*currNode.getAsUnaryNode());
428         case NodeType::Binary:
429             return traverseBinaryChildren(*currNode.getAsBinaryNode());
430         case NodeType::Ternary:
431             return traverseTernaryChildren(*currNode.getAsTernaryNode());
432         case NodeType::Swizzle:
433             return traverseSwizzleChildren(*currNode.getAsSwizzleNode());
434         case NodeType::IfElse:
435             return traverseIfElseChildren(*currNode.getAsIfElseNode());
436         case NodeType::Switch:
437             return traverseSwitchChildren(*currNode.getAsSwitchNode());
438         case NodeType::Case:
439             return traverseCaseChildren(*currNode.getAsCaseNode());
440         case NodeType::FunctionDefinition:
441             return traverseFunctionDefinitionChildren(*currNode.getAsFunctionDefinition());
442         case NodeType::Aggregate:
443             return traverseAggregateChildren(*currNode.getAsAggregate());
444         case NodeType::Block:
445             return traverseBlockChildren(*currNode.getAsBlock());
446         case NodeType::GlobalQualifierDeclaration:
447             return traverseGlobalQualifierDeclarationChildren(
448                 *currNode.getAsGlobalQualifierDeclarationNode());
449         case NodeType::Declaration:
450             return traverseDeclarationChildren(*currNode.getAsDeclarationNode());
451         case NodeType::Loop:
452             return traverseLoopChildren(*currNode.getAsLoopNode());
453         case NodeType::Branch:
454             return traverseBranchChildren(*currNode.getAsBranchNode());
455         default:
456             ASSERT(false);
457             return nullptr;
458     }
459 }
460 
traversePost(NodeType currNodeType,const TIntermNode & originalNode,TIntermNode & currNode,VisitBits visit)461 PostResult TIntermRebuild::traversePost(NodeType currNodeType,
462                                         const TIntermNode &originalNode,
463                                         TIntermNode &currNode,
464                                         VisitBits visit)
465 {
466     if (!mPostVisit)
467     {
468         return currNode;
469     }
470 
471     if (!AnyBits(visit, VisitBits::Post))
472     {
473         return currNode;
474     }
475 
476     if (AnyBits(visit, VisitBits::PostRequiresSame) && &originalNode != &currNode)
477     {
478         return currNode;
479     }
480 
481     NodeStackGuard guard(mNodeStack);
482     mNodeStack = {&currNode, &guard.oldNodeStack};
483 
484     switch (currNodeType)
485     {
486         case NodeType::Unknown:
487             ASSERT(false);
488             return Fail();
489         case NodeType::Symbol:
490             return visitSymbolPost(*currNode.getAsSymbolNode());
491         case NodeType::ConstantUnion:
492             return visitConstantUnionPost(*currNode.getAsConstantUnion());
493         case NodeType::FunctionPrototype:
494             return visitFunctionPrototypePost(*currNode.getAsFunctionPrototypeNode());
495         case NodeType::PreprocessorDirective:
496             return visitPreprocessorDirectivePost(*currNode.getAsPreprocessorDirective());
497         case NodeType::Unary:
498             return visitUnaryPost(*currNode.getAsUnaryNode());
499         case NodeType::Binary:
500             return visitBinaryPost(*currNode.getAsBinaryNode());
501         case NodeType::Ternary:
502             return visitTernaryPost(*currNode.getAsTernaryNode());
503         case NodeType::Swizzle:
504             return visitSwizzlePost(*currNode.getAsSwizzleNode());
505         case NodeType::IfElse:
506             return visitIfElsePost(*currNode.getAsIfElseNode());
507         case NodeType::Switch:
508             return visitSwitchPost(*currNode.getAsSwitchNode());
509         case NodeType::Case:
510             return visitCasePost(*currNode.getAsCaseNode());
511         case NodeType::FunctionDefinition:
512             return visitFunctionDefinitionPost(*currNode.getAsFunctionDefinition());
513         case NodeType::Aggregate:
514             return visitAggregatePost(*currNode.getAsAggregate());
515         case NodeType::Block:
516             return visitBlockPost(*currNode.getAsBlock());
517         case NodeType::GlobalQualifierDeclaration:
518             return visitGlobalQualifierDeclarationPost(
519                 *currNode.getAsGlobalQualifierDeclarationNode());
520         case NodeType::Declaration:
521             return visitDeclarationPost(*currNode.getAsDeclarationNode());
522         case NodeType::Loop:
523             return visitLoopPost(*currNode.getAsLoopNode());
524         case NodeType::Branch:
525             return visitBranchPost(*currNode.getAsBranchNode());
526         default:
527             ASSERT(false);
528             return Fail();
529     }
530 }
531 
532 ////////////////////////////////////////////////////////////////////////////////
533 
traverseAggregateChildren(TIntermAggregate & node)534 TIntermNode *TIntermRebuild::traverseAggregateChildren(TIntermAggregate &node)
535 {
536     if (traverseAggregateBaseChildren(node))
537     {
538         return &node;
539     }
540     return nullptr;
541 }
542 
traverseBlockChildren(TIntermBlock & node)543 TIntermNode *TIntermRebuild::traverseBlockChildren(TIntermBlock &node)
544 {
545     if (traverseAggregateBaseChildren(node))
546     {
547         return &node;
548     }
549     return nullptr;
550 }
551 
traverseDeclarationChildren(TIntermDeclaration & node)552 TIntermNode *TIntermRebuild::traverseDeclarationChildren(TIntermDeclaration &node)
553 {
554     if (traverseAggregateBaseChildren(node))
555     {
556         return &node;
557     }
558     return nullptr;
559 }
560 
traverseSwizzleChildren(TIntermSwizzle & node)561 TIntermNode *TIntermRebuild::traverseSwizzleChildren(TIntermSwizzle &node)
562 {
563     auto *const operand = node.getOperand();
564     ASSERT(operand);
565 
566     auto *newOperand = traverseAnyAs<TIntermTyped>(*operand);
567     GUARD(newOperand);
568 
569     if (newOperand != operand)
570     {
571         return new TIntermSwizzle(newOperand, node.getSwizzleOffsets());
572     }
573 
574     return &node;
575 }
576 
traverseBinaryChildren(TIntermBinary & node)577 TIntermNode *TIntermRebuild::traverseBinaryChildren(TIntermBinary &node)
578 {
579     auto *const left = node.getLeft();
580     ASSERT(left);
581     auto *const right = node.getRight();
582     ASSERT(right);
583 
584     auto *const newLeft = traverseAnyAs<TIntermTyped>(*left);
585     GUARD(newLeft);
586     auto *const newRight = traverseAnyAs<TIntermTyped>(*right);
587     GUARD(newRight);
588 
589     if (newLeft != left || newRight != right)
590     {
591         TOperator op = node.getOp();
592         switch (op)
593         {
594             case TOperator::EOpIndexDirectStruct:
595             {
596                 if (newLeft->getType().getInterfaceBlock())
597                 {
598                     op = TOperator::EOpIndexDirectInterfaceBlock;
599                 }
600             }
601             break;
602 
603             case TOperator::EOpIndexDirectInterfaceBlock:
604             {
605                 if (newLeft->getType().getStruct())
606                 {
607                     op = TOperator::EOpIndexDirectStruct;
608                 }
609             }
610             break;
611 
612             case TOperator::EOpComma:
613                 return TIntermBinary::CreateComma(newLeft, newRight, mCompiler.getShaderVersion());
614 
615             default:
616                 break;
617         }
618 
619         return new TIntermBinary(op, newLeft, newRight);
620     }
621 
622     return &node;
623 }
624 
traverseUnaryChildren(TIntermUnary & node)625 TIntermNode *TIntermRebuild::traverseUnaryChildren(TIntermUnary &node)
626 {
627     auto *const operand = node.getOperand();
628     ASSERT(operand);
629 
630     auto *const newOperand = traverseAnyAs<TIntermTyped>(*operand);
631     GUARD(newOperand);
632 
633     if (newOperand != operand)
634     {
635         return new TIntermUnary(node.getOp(), newOperand, node.getFunction());
636     }
637 
638     return &node;
639 }
640 
traverseTernaryChildren(TIntermTernary & node)641 TIntermNode *TIntermRebuild::traverseTernaryChildren(TIntermTernary &node)
642 {
643     auto *const cond = node.getCondition();
644     ASSERT(cond);
645     auto *const true_ = node.getTrueExpression();
646     ASSERT(true_);
647     auto *const false_ = node.getFalseExpression();
648     ASSERT(false_);
649 
650     auto *const newCond = traverseAnyAs<TIntermTyped>(*cond);
651     GUARD(newCond);
652     auto *const newTrue = traverseAnyAs<TIntermTyped>(*true_);
653     GUARD(newTrue);
654     auto *const newFalse = traverseAnyAs<TIntermTyped>(*false_);
655     GUARD(newFalse);
656 
657     if (newCond != cond || newTrue != true_ || newFalse != false_)
658     {
659         return new TIntermTernary(newCond, newTrue, newFalse);
660     }
661 
662     return &node;
663 }
664 
traverseIfElseChildren(TIntermIfElse & node)665 TIntermNode *TIntermRebuild::traverseIfElseChildren(TIntermIfElse &node)
666 {
667     auto *const cond = node.getCondition();
668     ASSERT(cond);
669     auto *const true_  = node.getTrueBlock();
670     auto *const false_ = node.getFalseBlock();
671 
672     auto *const newCond = traverseAnyAs<TIntermTyped>(*cond);
673     GUARD(newCond);
674     TIntermBlock *newTrue = nullptr;
675     if (true_)
676     {
677         GUARD(traverseAnyAs(*true_, newTrue));
678     }
679     TIntermBlock *newFalse = nullptr;
680     if (false_)
681     {
682         GUARD(traverseAnyAs(*false_, newFalse));
683     }
684 
685     if (newCond != cond || newTrue != true_ || newFalse != false_)
686     {
687         return new TIntermIfElse(newCond, newTrue, newFalse);
688     }
689 
690     return &node;
691 }
692 
traverseSwitchChildren(TIntermSwitch & node)693 TIntermNode *TIntermRebuild::traverseSwitchChildren(TIntermSwitch &node)
694 {
695     auto *const init = node.getInit();
696     ASSERT(init);
697     auto *const stmts = node.getStatementList();
698     ASSERT(stmts);
699 
700     auto *const newInit = traverseAnyAs<TIntermTyped>(*init);
701     GUARD(newInit);
702     auto *const newStmts = traverseAnyAs<TIntermBlock>(*stmts);
703     GUARD(newStmts);
704 
705     if (newInit != init || newStmts != stmts)
706     {
707         return new TIntermSwitch(newInit, newStmts);
708     }
709 
710     return &node;
711 }
712 
traverseCaseChildren(TIntermCase & node)713 TIntermNode *TIntermRebuild::traverseCaseChildren(TIntermCase &node)
714 {
715     auto *const cond = node.getCondition();
716 
717     TIntermTyped *newCond = nullptr;
718     if (cond)
719     {
720         GUARD(traverseAnyAs(*cond, newCond));
721     }
722 
723     if (newCond != cond)
724     {
725         return new TIntermCase(newCond);
726     }
727 
728     return &node;
729 }
730 
traverseFunctionDefinitionChildren(TIntermFunctionDefinition & node)731 TIntermNode *TIntermRebuild::traverseFunctionDefinitionChildren(TIntermFunctionDefinition &node)
732 {
733     GUARD(!mParentFunc);  // Function definitions cannot be nested.
734     mParentFunc = node.getFunction();
735     struct OnExit
736     {
737         const TFunction *&parentFunc;
738         OnExit(const TFunction *&parentFunc) : parentFunc(parentFunc) {}
739         ~OnExit() { parentFunc = nullptr; }
740     } onExit(mParentFunc);
741 
742     auto *const proto = node.getFunctionPrototype();
743     ASSERT(proto);
744     auto *const body = node.getBody();
745     ASSERT(body);
746 
747     auto *const newProto = traverseAnyAs<TIntermFunctionPrototype>(*proto);
748     GUARD(newProto);
749     auto *const newBody = traverseAnyAs<TIntermBlock>(*body);
750     GUARD(newBody);
751 
752     if (newProto != proto || newBody != body)
753     {
754         return new TIntermFunctionDefinition(newProto, newBody);
755     }
756 
757     return &node;
758 }
759 
traverseGlobalQualifierDeclarationChildren(TIntermGlobalQualifierDeclaration & node)760 TIntermNode *TIntermRebuild::traverseGlobalQualifierDeclarationChildren(
761     TIntermGlobalQualifierDeclaration &node)
762 {
763     auto *const symbol = node.getSymbol();
764     ASSERT(symbol);
765 
766     auto *const newSymbol = traverseAnyAs<TIntermSymbol>(*symbol);
767     GUARD(newSymbol);
768 
769     if (newSymbol != symbol)
770     {
771         return new TIntermGlobalQualifierDeclaration(newSymbol, node.isPrecise(), node.getLine());
772     }
773 
774     return &node;
775 }
776 
traverseLoopChildren(TIntermLoop & node)777 TIntermNode *TIntermRebuild::traverseLoopChildren(TIntermLoop &node)
778 {
779     const TLoopType loopType = node.getType();
780 
781     auto *const init = node.getInit();
782     auto *const cond = node.getCondition();
783     auto *const expr = node.getExpression();
784     auto *const body = node.getBody();
785     ASSERT(body);
786 
787 #if defined(ANGLE_ENABLE_ASSERTS)
788     switch (loopType)
789     {
790         case TLoopType::ELoopFor:
791             break;
792         case TLoopType::ELoopWhile:
793         case TLoopType::ELoopDoWhile:
794             ASSERT(cond);
795             ASSERT(!init && !expr);
796             break;
797         default:
798             ASSERT(false);
799             break;
800     }
801 #endif
802 
803     auto *const newBody = traverseAnyAs<TIntermBlock>(*body);
804     GUARD(newBody);
805     TIntermNode *newInit = nullptr;
806     if (init)
807     {
808         GUARD(traverseAnyAs(*init, newInit));
809     }
810     TIntermTyped *newCond = nullptr;
811     if (cond)
812     {
813         GUARD(traverseAnyAs(*cond, newCond));
814     }
815     TIntermTyped *newExpr = nullptr;
816     if (expr)
817     {
818         GUARD(traverseAnyAs(*expr, newExpr));
819     }
820 
821     if (newInit != init || newCond != cond || newExpr != expr || newBody != body)
822     {
823         switch (loopType)
824         {
825             case TLoopType::ELoopFor:
826                 GUARD(newBody);
827                 break;
828             case TLoopType::ELoopWhile:
829             case TLoopType::ELoopDoWhile:
830                 GUARD(newCond && newBody);
831                 GUARD(!newInit && !newExpr);
832                 break;
833             default:
834                 ASSERT(false);
835                 break;
836         }
837         return new TIntermLoop(loopType, newInit, newCond, newExpr, newBody);
838     }
839 
840     return &node;
841 }
842 
traverseBranchChildren(TIntermBranch & node)843 TIntermNode *TIntermRebuild::traverseBranchChildren(TIntermBranch &node)
844 {
845     auto *const expr = node.getExpression();
846 
847     TIntermTyped *newExpr = nullptr;
848     if (expr)
849     {
850         GUARD(traverseAnyAs<TIntermTyped>(*expr, newExpr));
851     }
852 
853     if (newExpr != expr)
854     {
855         return new TIntermBranch(node.getFlowOp(), newExpr);
856     }
857 
858     return &node;
859 }
860 
861 ////////////////////////////////////////////////////////////////////////////////
862 
visitSymbolPre(TIntermSymbol & node)863 PreResult TIntermRebuild::visitSymbolPre(TIntermSymbol &node)
864 {
865     return {node, VisitBits::Both};
866 }
867 
visitConstantUnionPre(TIntermConstantUnion & node)868 PreResult TIntermRebuild::visitConstantUnionPre(TIntermConstantUnion &node)
869 {
870     return {node, VisitBits::Both};
871 }
872 
visitFunctionPrototypePre(TIntermFunctionPrototype & node)873 PreResult TIntermRebuild::visitFunctionPrototypePre(TIntermFunctionPrototype &node)
874 {
875     return {node, VisitBits::Both};
876 }
877 
visitPreprocessorDirectivePre(TIntermPreprocessorDirective & node)878 PreResult TIntermRebuild::visitPreprocessorDirectivePre(TIntermPreprocessorDirective &node)
879 {
880     return {node, VisitBits::Both};
881 }
882 
visitUnaryPre(TIntermUnary & node)883 PreResult TIntermRebuild::visitUnaryPre(TIntermUnary &node)
884 {
885     return {node, VisitBits::Both};
886 }
887 
visitBinaryPre(TIntermBinary & node)888 PreResult TIntermRebuild::visitBinaryPre(TIntermBinary &node)
889 {
890     return {node, VisitBits::Both};
891 }
892 
visitTernaryPre(TIntermTernary & node)893 PreResult TIntermRebuild::visitTernaryPre(TIntermTernary &node)
894 {
895     return {node, VisitBits::Both};
896 }
897 
visitSwizzlePre(TIntermSwizzle & node)898 PreResult TIntermRebuild::visitSwizzlePre(TIntermSwizzle &node)
899 {
900     return {node, VisitBits::Both};
901 }
902 
visitIfElsePre(TIntermIfElse & node)903 PreResult TIntermRebuild::visitIfElsePre(TIntermIfElse &node)
904 {
905     return {node, VisitBits::Both};
906 }
907 
visitSwitchPre(TIntermSwitch & node)908 PreResult TIntermRebuild::visitSwitchPre(TIntermSwitch &node)
909 {
910     return {node, VisitBits::Both};
911 }
912 
visitCasePre(TIntermCase & node)913 PreResult TIntermRebuild::visitCasePre(TIntermCase &node)
914 {
915     return {node, VisitBits::Both};
916 }
917 
visitLoopPre(TIntermLoop & node)918 PreResult TIntermRebuild::visitLoopPre(TIntermLoop &node)
919 {
920     return {node, VisitBits::Both};
921 }
922 
visitBranchPre(TIntermBranch & node)923 PreResult TIntermRebuild::visitBranchPre(TIntermBranch &node)
924 {
925     return {node, VisitBits::Both};
926 }
927 
visitDeclarationPre(TIntermDeclaration & node)928 PreResult TIntermRebuild::visitDeclarationPre(TIntermDeclaration &node)
929 {
930     return {node, VisitBits::Both};
931 }
932 
visitBlockPre(TIntermBlock & node)933 PreResult TIntermRebuild::visitBlockPre(TIntermBlock &node)
934 {
935     return {node, VisitBits::Both};
936 }
937 
visitAggregatePre(TIntermAggregate & node)938 PreResult TIntermRebuild::visitAggregatePre(TIntermAggregate &node)
939 {
940     return {node, VisitBits::Both};
941 }
942 
visitFunctionDefinitionPre(TIntermFunctionDefinition & node)943 PreResult TIntermRebuild::visitFunctionDefinitionPre(TIntermFunctionDefinition &node)
944 {
945     return {node, VisitBits::Both};
946 }
947 
visitGlobalQualifierDeclarationPre(TIntermGlobalQualifierDeclaration & node)948 PreResult TIntermRebuild::visitGlobalQualifierDeclarationPre(
949     TIntermGlobalQualifierDeclaration &node)
950 {
951     return {node, VisitBits::Both};
952 }
953 
954 ////////////////////////////////////////////////////////////////////////////////
955 
visitSymbolPost(TIntermSymbol & node)956 PostResult TIntermRebuild::visitSymbolPost(TIntermSymbol &node)
957 {
958     return node;
959 }
960 
visitConstantUnionPost(TIntermConstantUnion & node)961 PostResult TIntermRebuild::visitConstantUnionPost(TIntermConstantUnion &node)
962 {
963     return node;
964 }
965 
visitFunctionPrototypePost(TIntermFunctionPrototype & node)966 PostResult TIntermRebuild::visitFunctionPrototypePost(TIntermFunctionPrototype &node)
967 {
968     return node;
969 }
970 
visitPreprocessorDirectivePost(TIntermPreprocessorDirective & node)971 PostResult TIntermRebuild::visitPreprocessorDirectivePost(TIntermPreprocessorDirective &node)
972 {
973     return node;
974 }
975 
visitUnaryPost(TIntermUnary & node)976 PostResult TIntermRebuild::visitUnaryPost(TIntermUnary &node)
977 {
978     return node;
979 }
980 
visitBinaryPost(TIntermBinary & node)981 PostResult TIntermRebuild::visitBinaryPost(TIntermBinary &node)
982 {
983     return node;
984 }
985 
visitTernaryPost(TIntermTernary & node)986 PostResult TIntermRebuild::visitTernaryPost(TIntermTernary &node)
987 {
988     return node;
989 }
990 
visitSwizzlePost(TIntermSwizzle & node)991 PostResult TIntermRebuild::visitSwizzlePost(TIntermSwizzle &node)
992 {
993     return node;
994 }
995 
visitIfElsePost(TIntermIfElse & node)996 PostResult TIntermRebuild::visitIfElsePost(TIntermIfElse &node)
997 {
998     return node;
999 }
1000 
visitSwitchPost(TIntermSwitch & node)1001 PostResult TIntermRebuild::visitSwitchPost(TIntermSwitch &node)
1002 {
1003     return node;
1004 }
1005 
visitCasePost(TIntermCase & node)1006 PostResult TIntermRebuild::visitCasePost(TIntermCase &node)
1007 {
1008     return node;
1009 }
1010 
visitLoopPost(TIntermLoop & node)1011 PostResult TIntermRebuild::visitLoopPost(TIntermLoop &node)
1012 {
1013     return node;
1014 }
1015 
visitBranchPost(TIntermBranch & node)1016 PostResult TIntermRebuild::visitBranchPost(TIntermBranch &node)
1017 {
1018     return node;
1019 }
1020 
visitDeclarationPost(TIntermDeclaration & node)1021 PostResult TIntermRebuild::visitDeclarationPost(TIntermDeclaration &node)
1022 {
1023     return node;
1024 }
1025 
visitBlockPost(TIntermBlock & node)1026 PostResult TIntermRebuild::visitBlockPost(TIntermBlock &node)
1027 {
1028     return node;
1029 }
1030 
visitAggregatePost(TIntermAggregate & node)1031 PostResult TIntermRebuild::visitAggregatePost(TIntermAggregate &node)
1032 {
1033     return node;
1034 }
1035 
visitFunctionDefinitionPost(TIntermFunctionDefinition & node)1036 PostResult TIntermRebuild::visitFunctionDefinitionPost(TIntermFunctionDefinition &node)
1037 {
1038     return node;
1039 }
1040 
visitGlobalQualifierDeclarationPost(TIntermGlobalQualifierDeclaration & node)1041 PostResult TIntermRebuild::visitGlobalQualifierDeclarationPost(
1042     TIntermGlobalQualifierDeclaration &node)
1043 {
1044     return node;
1045 }
1046 
1047 }  // namespace sh
1048