• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <algorithm>
8 
9 #include "compiler/translator/Compiler.h"
10 #include "compiler/translator/SymbolTable.h"
11 #include "compiler/translator/msl/AsNode.h"
12 #include "compiler/translator/msl/IntermRebuild.h"
13 
// GUARD2(cond, failVal): when `cond` evaluates to false, return `failVal` from the
// enclosing function. The do { } while (false) wrapper turns the expansion into a single
// statement, so the macro composes safely with unbraced if/else.
#define GUARD2(cond, failVal)         \
    do                                \
    {                                 \
        if (!static_cast<bool>(cond)) \
        {                             \
            return failVal;           \
        }                             \
    } while (false)

// GUARD(cond): GUARD2 for functions returning a pointer; bails out with nullptr.
#define GUARD(cond) GUARD2(cond, nullptr)
24 
25 namespace sh
26 {
27 
28 template <typename T, typename U>
AllBits(T haystack,U needle)29 ANGLE_INLINE bool AllBits(T haystack, U needle)
30 {
31     return (haystack & needle) == needle;
32 }
33 
34 template <typename T, typename U>
AnyBits(T haystack,U needle)35 ANGLE_INLINE bool AnyBits(T haystack, U needle)
36 {
37     return (haystack & needle) != 0;
38 }
39 
40 ////////////////////////////////////////////////////////////////////////////////
41 
BaseResult(BaseResult & other)42 TIntermRebuild::BaseResult::BaseResult(BaseResult &other)
43     : mAction(other.mAction),
44       mVisit(other.mVisit),
45       mSingle(other.mSingle),
46       mMulti(std::move(other.mMulti))
47 {}
48 
BaseResult(TIntermNode & node,VisitBits visit)49 TIntermRebuild::BaseResult::BaseResult(TIntermNode &node, VisitBits visit)
50     : mAction(Action::ReplaceSingle), mVisit(visit), mSingle(&node)
51 {}
52 
BaseResult(TIntermNode * node,VisitBits visit)53 TIntermRebuild::BaseResult::BaseResult(TIntermNode *node, VisitBits visit)
54     : mAction(node ? Action::ReplaceSingle : Action::Drop),
55       mVisit(node ? visit : VisitBits::Neither),
56       mSingle(node)
57 {}
58 
BaseResult(std::nullptr_t)59 TIntermRebuild::BaseResult::BaseResult(std::nullptr_t)
60     : mAction(Action::Drop), mVisit(VisitBits::Neither), mSingle(nullptr)
61 {}
62 
BaseResult(Fail)63 TIntermRebuild::BaseResult::BaseResult(Fail)
64     : mAction(Action::Fail), mVisit(VisitBits::Neither), mSingle(nullptr)
65 {}
66 
BaseResult(std::vector<TIntermNode * > && nodes)67 TIntermRebuild::BaseResult::BaseResult(std::vector<TIntermNode *> &&nodes)
68     : mAction(Action::ReplaceMulti),
69       mVisit(VisitBits::Neither),
70       mSingle(nullptr),
71       mMulti(std::move(nodes))
72 {}
73 
moveAssignImpl(BaseResult & other)74 void TIntermRebuild::BaseResult::moveAssignImpl(BaseResult &other)
75 {
76     mAction = other.mAction;
77     mVisit  = other.mVisit;
78     mSingle = other.mSingle;
79     mMulti  = std::move(other.mMulti);
80 }
81 
Multi(std::vector<TIntermNode * > && nodes)82 TIntermRebuild::BaseResult TIntermRebuild::BaseResult::Multi(std::vector<TIntermNode *> &&nodes)
83 {
84     auto it = std::remove(nodes.begin(), nodes.end(), nullptr);
85     nodes.erase(it, nodes.end());
86     return std::move(nodes);
87 }
88 
isFail() const89 bool TIntermRebuild::BaseResult::isFail() const
90 {
91     return mAction == Action::Fail;
92 }
93 
isDrop() const94 bool TIntermRebuild::BaseResult::isDrop() const
95 {
96     return mAction == Action::Drop;
97 }
98 
single() const99 TIntermNode *TIntermRebuild::BaseResult::single() const
100 {
101     return mSingle;
102 }
103 
multi() const104 const std::vector<TIntermNode *> *TIntermRebuild::BaseResult::multi() const
105 {
106     if (mAction == Action::ReplaceMulti)
107     {
108         return &mMulti;
109     }
110     return nullptr;
111 }
112 
113 ////////////////////////////////////////////////////////////////////////////////
114 
115 using PreResult = TIntermRebuild::PreResult;
116 
PreResult(TIntermNode & node,VisitBits visit)117 PreResult::PreResult(TIntermNode &node, VisitBits visit) : BaseResult(node, visit) {}
PreResult(TIntermNode * node,VisitBits visit)118 PreResult::PreResult(TIntermNode *node, VisitBits visit) : BaseResult(node, visit) {}
PreResult(std::nullptr_t)119 PreResult::PreResult(std::nullptr_t) : BaseResult(nullptr) {}
PreResult(Fail)120 PreResult::PreResult(Fail) : BaseResult(Fail()) {}
121 
PreResult(BaseResult && other)122 PreResult::PreResult(BaseResult &&other) : BaseResult(other) {}
PreResult(PreResult && other)123 PreResult::PreResult(PreResult &&other) : BaseResult(other) {}
124 
operator =(PreResult && other)125 void PreResult::operator=(PreResult &&other)
126 {
127     moveAssignImpl(other);
128 }
129 
130 ////////////////////////////////////////////////////////////////////////////////
131 
132 using PostResult = TIntermRebuild::PostResult;
133 
PostResult(TIntermNode & node)134 PostResult::PostResult(TIntermNode &node) : BaseResult(node, VisitBits::Neither) {}
PostResult(TIntermNode * node)135 PostResult::PostResult(TIntermNode *node) : BaseResult(node, VisitBits::Neither) {}
PostResult(std::nullptr_t)136 PostResult::PostResult(std::nullptr_t) : BaseResult(nullptr) {}
PostResult(Fail)137 PostResult::PostResult(Fail) : BaseResult(Fail()) {}
138 
PostResult(PostResult && other)139 PostResult::PostResult(PostResult &&other) : BaseResult(other) {}
PostResult(BaseResult && other)140 PostResult::PostResult(BaseResult &&other) : BaseResult(other) {}
141 
operator =(PostResult && other)142 void PostResult::operator=(PostResult &&other)
143 {
144     moveAssignImpl(other);
145 }
146 
147 ////////////////////////////////////////////////////////////////////////////////
148 
TIntermRebuild(TCompiler & compiler,bool preVisit,bool postVisit)149 TIntermRebuild::TIntermRebuild(TCompiler &compiler, bool preVisit, bool postVisit)
150     : mCompiler(compiler),
151       mSymbolTable(compiler.getSymbolTable()),
152       mPreVisit(preVisit),
153       mPostVisit(postVisit)
154 {
155     ASSERT(preVisit || postVisit);
156 }
157 
~TIntermRebuild()158 TIntermRebuild::~TIntermRebuild()
159 {
160     ASSERT(!mNodeStack.value);
161     ASSERT(!mNodeStack.tail);
162 }
163 
getParentFunction() const164 const TFunction *TIntermRebuild::getParentFunction() const
165 {
166     return mParentFunc;
167 }
168 
getParentNode(size_t offset) const169 TIntermNode *TIntermRebuild::getParentNode(size_t offset) const
170 {
171     ASSERT(mNodeStack.tail);
172     auto parent = *mNodeStack.tail;
173     while (offset > 0)
174     {
175         --offset;
176         ASSERT(parent.tail);
177         parent = *parent.tail;
178     }
179     return parent.value;
180 }
181 
rebuildRoot(TIntermBlock & root)182 bool TIntermRebuild::rebuildRoot(TIntermBlock &root)
183 {
184     if (!rebuildInPlace(root))
185     {
186         return false;
187     }
188     return mCompiler.validateAST(&root);
189 }
190 
rebuildInPlace(TIntermAggregate & node)191 bool TIntermRebuild::rebuildInPlace(TIntermAggregate &node)
192 {
193     return rebuildInPlaceImpl(node);
194 }
195 
rebuildInPlace(TIntermBlock & node)196 bool TIntermRebuild::rebuildInPlace(TIntermBlock &node)
197 {
198     return rebuildInPlaceImpl(node);
199 }
200 
rebuildInPlace(TIntermDeclaration & node)201 bool TIntermRebuild::rebuildInPlace(TIntermDeclaration &node)
202 {
203     return rebuildInPlaceImpl(node);
204 }
205 
206 template <typename Node>
rebuildInPlaceImpl(Node & node)207 bool TIntermRebuild::rebuildInPlaceImpl(Node &node)
208 {
209     auto *newNode = traverseAnyAs<Node>(node);
210     if (!newNode)
211     {
212         return false;
213     }
214 
215     if (newNode != &node)
216     {
217         *node.getSequence() = std::move(*newNode->getSequence());
218     }
219 
220     return true;
221 }
222 
rebuild(TIntermNode & node)223 PostResult TIntermRebuild::rebuild(TIntermNode &node)
224 {
225     return traverseAny(node);
226 }
227 
228 ////////////////////////////////////////////////////////////////////////////////
229 
230 template <typename Node>
traverseAnyAs(TIntermNode & node)231 Node *TIntermRebuild::traverseAnyAs(TIntermNode &node)
232 {
233     PostResult result(traverseAny(node));
234     if (result.mAction == Action::Fail || !result.mSingle)
235     {
236         return nullptr;
237     }
238     return asNode<Node>(result.mSingle);
239 }
240 
241 template <typename Node>
traverseAnyAs(TIntermNode & node,Node * & out)242 bool TIntermRebuild::traverseAnyAs(TIntermNode &node, Node *&out)
243 {
244     PostResult result(traverseAny(node));
245     if (result.mAction == Action::Fail || result.mAction == Action::ReplaceMulti)
246     {
247         return false;
248     }
249     if (!result.mSingle)
250     {
251         return true;
252     }
253     out = asNode<Node>(result.mSingle);
254     return out != nullptr;
255 }
256 
traverseAggregateBaseChildren(TIntermAggregateBase & node)257 bool TIntermRebuild::traverseAggregateBaseChildren(TIntermAggregateBase &node)
258 {
259     auto *const children = node.getSequence();
260     ASSERT(children);
261     TIntermSequence newChildren;
262 
263     for (TIntermNode *child : *children)
264     {
265         ASSERT(child);
266         PostResult result(traverseAny(*child));
267 
268         switch (result.mAction)
269         {
270             case Action::ReplaceSingle:
271                 newChildren.push_back(result.mSingle);
272                 break;
273 
274             case Action::ReplaceMulti:
275                 for (TIntermNode *newNode : result.mMulti)
276                 {
277                     if (newNode)
278                     {
279                         newChildren.push_back(newNode);
280                     }
281                 }
282                 break;
283 
284             case Action::Drop:
285                 break;
286 
287             case Action::Fail:
288                 return false;
289 
290             default:
291                 ASSERT(false);
292                 return false;
293         }
294     }
295 
296     *children = std::move(newChildren);
297 
298     return true;
299 }
300 
301 ////////////////////////////////////////////////////////////////////////////////
302 
303 struct TIntermRebuild::NodeStackGuard
304 {
305     ConsList<TIntermNode *> oldNodeStack;
306     ConsList<TIntermNode *> &nodeStack;
NodeStackGuardsh::TIntermRebuild::NodeStackGuard307     NodeStackGuard(ConsList<TIntermNode *> &nodeStack, TIntermNode *node)
308         : oldNodeStack(nodeStack), nodeStack(nodeStack)
309     {
310         nodeStack = {node, &oldNodeStack};
311     }
~NodeStackGuardsh::TIntermRebuild::NodeStackGuard312     ~NodeStackGuard() { nodeStack = oldNodeStack; }
313 };
314 
traverseAny(TIntermNode & originalNode)315 PostResult TIntermRebuild::traverseAny(TIntermNode &originalNode)
316 {
317     PreResult preResult = traversePre(originalNode);
318     if (!preResult.mSingle)
319     {
320         ASSERT(preResult.mVisit == VisitBits::Neither);
321         return std::move(preResult);
322     }
323 
324     TIntermNode *currNode       = preResult.mSingle;
325     const VisitBits visit       = preResult.mVisit;
326     const NodeType currNodeType = getNodeType(*currNode);
327 
328     currNode = traverseChildren(currNodeType, originalNode, *currNode, visit);
329     if (!currNode)
330     {
331         return Fail();
332     }
333 
334     return traversePost(currNodeType, originalNode, *currNode, visit);
335 }
336 
traversePre(TIntermNode & originalNode)337 PreResult TIntermRebuild::traversePre(TIntermNode &originalNode)
338 {
339     if (!mPreVisit)
340     {
341         return {originalNode, VisitBits::Both};
342     }
343 
344     NodeStackGuard guard(mNodeStack, &originalNode);
345 
346     const NodeType originalNodeType = getNodeType(originalNode);
347 
348     switch (originalNodeType)
349     {
350         case NodeType::Unknown:
351             ASSERT(false);
352             return Fail();
353         case NodeType::Symbol:
354             return visitSymbolPre(*originalNode.getAsSymbolNode());
355         case NodeType::ConstantUnion:
356             return visitConstantUnionPre(*originalNode.getAsConstantUnion());
357         case NodeType::FunctionPrototype:
358             return visitFunctionPrototypePre(*originalNode.getAsFunctionPrototypeNode());
359         case NodeType::PreprocessorDirective:
360             return visitPreprocessorDirectivePre(*originalNode.getAsPreprocessorDirective());
361         case NodeType::Unary:
362             return visitUnaryPre(*originalNode.getAsUnaryNode());
363         case NodeType::Binary:
364             return visitBinaryPre(*originalNode.getAsBinaryNode());
365         case NodeType::Ternary:
366             return visitTernaryPre(*originalNode.getAsTernaryNode());
367         case NodeType::Swizzle:
368             return visitSwizzlePre(*originalNode.getAsSwizzleNode());
369         case NodeType::IfElse:
370             return visitIfElsePre(*originalNode.getAsIfElseNode());
371         case NodeType::Switch:
372             return visitSwitchPre(*originalNode.getAsSwitchNode());
373         case NodeType::Case:
374             return visitCasePre(*originalNode.getAsCaseNode());
375         case NodeType::FunctionDefinition:
376             return visitFunctionDefinitionPre(*originalNode.getAsFunctionDefinition());
377         case NodeType::Aggregate:
378             return visitAggregatePre(*originalNode.getAsAggregate());
379         case NodeType::Block:
380             return visitBlockPre(*originalNode.getAsBlock());
381         case NodeType::GlobalQualifierDeclaration:
382             return visitGlobalQualifierDeclarationPre(
383                 *originalNode.getAsGlobalQualifierDeclarationNode());
384         case NodeType::Declaration:
385             return visitDeclarationPre(*originalNode.getAsDeclarationNode());
386         case NodeType::Loop:
387             return visitLoopPre(*originalNode.getAsLoopNode());
388         case NodeType::Branch:
389             return visitBranchPre(*originalNode.getAsBranchNode());
390         default:
391             ASSERT(false);
392             return Fail();
393     }
394 }
395 
traverseChildren(NodeType currNodeType,const TIntermNode & originalNode,TIntermNode & currNode,VisitBits visit)396 TIntermNode *TIntermRebuild::traverseChildren(NodeType currNodeType,
397                                               const TIntermNode &originalNode,
398                                               TIntermNode &currNode,
399                                               VisitBits visit)
400 {
401     if (!AnyBits(visit, VisitBits::Children))
402     {
403         return &currNode;
404     }
405 
406     if (AnyBits(visit, VisitBits::ChildrenRequiresSame) && &originalNode != &currNode)
407     {
408         return &currNode;
409     }
410 
411     NodeStackGuard guard(mNodeStack, &currNode);
412 
413     switch (currNodeType)
414     {
415         case NodeType::Unknown:
416             ASSERT(false);
417             return nullptr;
418         case NodeType::Symbol:
419             return &currNode;
420         case NodeType::ConstantUnion:
421             return &currNode;
422         case NodeType::FunctionPrototype:
423             return &currNode;
424         case NodeType::PreprocessorDirective:
425             return &currNode;
426         case NodeType::Unary:
427             return traverseUnaryChildren(*currNode.getAsUnaryNode());
428         case NodeType::Binary:
429             return traverseBinaryChildren(*currNode.getAsBinaryNode());
430         case NodeType::Ternary:
431             return traverseTernaryChildren(*currNode.getAsTernaryNode());
432         case NodeType::Swizzle:
433             return traverseSwizzleChildren(*currNode.getAsSwizzleNode());
434         case NodeType::IfElse:
435             return traverseIfElseChildren(*currNode.getAsIfElseNode());
436         case NodeType::Switch:
437             return traverseSwitchChildren(*currNode.getAsSwitchNode());
438         case NodeType::Case:
439             return traverseCaseChildren(*currNode.getAsCaseNode());
440         case NodeType::FunctionDefinition:
441             return traverseFunctionDefinitionChildren(*currNode.getAsFunctionDefinition());
442         case NodeType::Aggregate:
443             return traverseAggregateChildren(*currNode.getAsAggregate());
444         case NodeType::Block:
445             return traverseBlockChildren(*currNode.getAsBlock());
446         case NodeType::GlobalQualifierDeclaration:
447             return traverseGlobalQualifierDeclarationChildren(
448                 *currNode.getAsGlobalQualifierDeclarationNode());
449         case NodeType::Declaration:
450             return traverseDeclarationChildren(*currNode.getAsDeclarationNode());
451         case NodeType::Loop:
452             return traverseLoopChildren(*currNode.getAsLoopNode());
453         case NodeType::Branch:
454             return traverseBranchChildren(*currNode.getAsBranchNode());
455         default:
456             ASSERT(false);
457             return nullptr;
458     }
459 }
460 
traversePost(NodeType currNodeType,const TIntermNode & originalNode,TIntermNode & currNode,VisitBits visit)461 PostResult TIntermRebuild::traversePost(NodeType currNodeType,
462                                         const TIntermNode &originalNode,
463                                         TIntermNode &currNode,
464                                         VisitBits visit)
465 {
466     if (!mPostVisit)
467     {
468         return currNode;
469     }
470 
471     if (!AnyBits(visit, VisitBits::Post))
472     {
473         return currNode;
474     }
475 
476     if (AnyBits(visit, VisitBits::PostRequiresSame) && &originalNode != &currNode)
477     {
478         return currNode;
479     }
480 
481     NodeStackGuard guard(mNodeStack, &currNode);
482 
483     switch (currNodeType)
484     {
485         case NodeType::Unknown:
486             ASSERT(false);
487             return Fail();
488         case NodeType::Symbol:
489             return visitSymbolPost(*currNode.getAsSymbolNode());
490         case NodeType::ConstantUnion:
491             return visitConstantUnionPost(*currNode.getAsConstantUnion());
492         case NodeType::FunctionPrototype:
493             return visitFunctionPrototypePost(*currNode.getAsFunctionPrototypeNode());
494         case NodeType::PreprocessorDirective:
495             return visitPreprocessorDirectivePost(*currNode.getAsPreprocessorDirective());
496         case NodeType::Unary:
497             return visitUnaryPost(*currNode.getAsUnaryNode());
498         case NodeType::Binary:
499             return visitBinaryPost(*currNode.getAsBinaryNode());
500         case NodeType::Ternary:
501             return visitTernaryPost(*currNode.getAsTernaryNode());
502         case NodeType::Swizzle:
503             return visitSwizzlePost(*currNode.getAsSwizzleNode());
504         case NodeType::IfElse:
505             return visitIfElsePost(*currNode.getAsIfElseNode());
506         case NodeType::Switch:
507             return visitSwitchPost(*currNode.getAsSwitchNode());
508         case NodeType::Case:
509             return visitCasePost(*currNode.getAsCaseNode());
510         case NodeType::FunctionDefinition:
511             return visitFunctionDefinitionPost(*currNode.getAsFunctionDefinition());
512         case NodeType::Aggregate:
513             return visitAggregatePost(*currNode.getAsAggregate());
514         case NodeType::Block:
515             return visitBlockPost(*currNode.getAsBlock());
516         case NodeType::GlobalQualifierDeclaration:
517             return visitGlobalQualifierDeclarationPost(
518                 *currNode.getAsGlobalQualifierDeclarationNode());
519         case NodeType::Declaration:
520             return visitDeclarationPost(*currNode.getAsDeclarationNode());
521         case NodeType::Loop:
522             return visitLoopPost(*currNode.getAsLoopNode());
523         case NodeType::Branch:
524             return visitBranchPost(*currNode.getAsBranchNode());
525         default:
526             ASSERT(false);
527             return Fail();
528     }
529 }
530 
531 ////////////////////////////////////////////////////////////////////////////////
532 
traverseAggregateChildren(TIntermAggregate & node)533 TIntermNode *TIntermRebuild::traverseAggregateChildren(TIntermAggregate &node)
534 {
535     if (traverseAggregateBaseChildren(node))
536     {
537         return &node;
538     }
539     return nullptr;
540 }
541 
traverseBlockChildren(TIntermBlock & node)542 TIntermNode *TIntermRebuild::traverseBlockChildren(TIntermBlock &node)
543 {
544     if (traverseAggregateBaseChildren(node))
545     {
546         return &node;
547     }
548     return nullptr;
549 }
550 
traverseDeclarationChildren(TIntermDeclaration & node)551 TIntermNode *TIntermRebuild::traverseDeclarationChildren(TIntermDeclaration &node)
552 {
553     if (traverseAggregateBaseChildren(node))
554     {
555         return &node;
556     }
557     return nullptr;
558 }
559 
traverseSwizzleChildren(TIntermSwizzle & node)560 TIntermNode *TIntermRebuild::traverseSwizzleChildren(TIntermSwizzle &node)
561 {
562     auto *const operand = node.getOperand();
563     ASSERT(operand);
564 
565     auto *newOperand = traverseAnyAs<TIntermTyped>(*operand);
566     GUARD(newOperand);
567 
568     if (newOperand != operand)
569     {
570         return new TIntermSwizzle(newOperand, node.getSwizzleOffsets());
571     }
572 
573     return &node;
574 }
575 
traverseBinaryChildren(TIntermBinary & node)576 TIntermNode *TIntermRebuild::traverseBinaryChildren(TIntermBinary &node)
577 {
578     auto *const left = node.getLeft();
579     ASSERT(left);
580     auto *const right = node.getRight();
581     ASSERT(right);
582 
583     auto *const newLeft = traverseAnyAs<TIntermTyped>(*left);
584     GUARD(newLeft);
585     auto *const newRight = traverseAnyAs<TIntermTyped>(*right);
586     GUARD(newRight);
587 
588     if (newLeft != left || newRight != right)
589     {
590         TOperator op = node.getOp();
591         switch (op)
592         {
593             case TOperator::EOpIndexDirectStruct:
594             {
595                 if (newLeft->getType().getInterfaceBlock())
596                 {
597                     op = TOperator::EOpIndexDirectInterfaceBlock;
598                 }
599             }
600             break;
601 
602             case TOperator::EOpIndexDirectInterfaceBlock:
603             {
604                 if (newLeft->getType().getStruct())
605                 {
606                     op = TOperator::EOpIndexDirectStruct;
607                 }
608             }
609             break;
610 
611             case TOperator::EOpComma:
612                 return TIntermBinary::CreateComma(newLeft, newRight, mCompiler.getShaderVersion());
613 
614             default:
615                 break;
616         }
617 
618         return new TIntermBinary(op, newLeft, newRight);
619     }
620 
621     return &node;
622 }
623 
traverseUnaryChildren(TIntermUnary & node)624 TIntermNode *TIntermRebuild::traverseUnaryChildren(TIntermUnary &node)
625 {
626     auto *const operand = node.getOperand();
627     ASSERT(operand);
628 
629     auto *const newOperand = traverseAnyAs<TIntermTyped>(*operand);
630     GUARD(newOperand);
631 
632     if (newOperand != operand)
633     {
634         return new TIntermUnary(node.getOp(), newOperand, node.getFunction());
635     }
636 
637     return &node;
638 }
639 
traverseTernaryChildren(TIntermTernary & node)640 TIntermNode *TIntermRebuild::traverseTernaryChildren(TIntermTernary &node)
641 {
642     auto *const cond = node.getCondition();
643     ASSERT(cond);
644     auto *const true_ = node.getTrueExpression();
645     ASSERT(true_);
646     auto *const false_ = node.getFalseExpression();
647     ASSERT(false_);
648 
649     auto *const newCond = traverseAnyAs<TIntermTyped>(*cond);
650     GUARD(newCond);
651     auto *const newTrue = traverseAnyAs<TIntermTyped>(*true_);
652     GUARD(newTrue);
653     auto *const newFalse = traverseAnyAs<TIntermTyped>(*false_);
654     GUARD(newFalse);
655 
656     if (newCond != cond || newTrue != true_ || newFalse != false_)
657     {
658         return new TIntermTernary(newCond, newTrue, newFalse);
659     }
660 
661     return &node;
662 }
663 
traverseIfElseChildren(TIntermIfElse & node)664 TIntermNode *TIntermRebuild::traverseIfElseChildren(TIntermIfElse &node)
665 {
666     auto *const cond = node.getCondition();
667     ASSERT(cond);
668     auto *const true_  = node.getTrueBlock();
669     auto *const false_ = node.getFalseBlock();
670 
671     auto *const newCond = traverseAnyAs<TIntermTyped>(*cond);
672     GUARD(newCond);
673     TIntermBlock *newTrue = nullptr;
674     if (true_)
675     {
676         GUARD(traverseAnyAs(*true_, newTrue));
677     }
678     TIntermBlock *newFalse = nullptr;
679     if (false_)
680     {
681         GUARD(traverseAnyAs(*false_, newFalse));
682     }
683 
684     if (newCond != cond || newTrue != true_ || newFalse != false_)
685     {
686         return new TIntermIfElse(newCond, newTrue, newFalse);
687     }
688 
689     return &node;
690 }
691 
traverseSwitchChildren(TIntermSwitch & node)692 TIntermNode *TIntermRebuild::traverseSwitchChildren(TIntermSwitch &node)
693 {
694     auto *const init = node.getInit();
695     ASSERT(init);
696     auto *const stmts = node.getStatementList();
697     ASSERT(stmts);
698 
699     auto *const newInit = traverseAnyAs<TIntermTyped>(*init);
700     GUARD(newInit);
701     auto *const newStmts = traverseAnyAs<TIntermBlock>(*stmts);
702     GUARD(newStmts);
703 
704     if (newInit != init || newStmts != stmts)
705     {
706         return new TIntermSwitch(newInit, newStmts);
707     }
708 
709     return &node;
710 }
711 
traverseCaseChildren(TIntermCase & node)712 TIntermNode *TIntermRebuild::traverseCaseChildren(TIntermCase &node)
713 {
714     auto *const cond = node.getCondition();
715 
716     TIntermTyped *newCond = nullptr;
717     if (cond)
718     {
719         GUARD(traverseAnyAs(*cond, newCond));
720     }
721 
722     if (newCond != cond)
723     {
724         return new TIntermCase(newCond);
725     }
726 
727     return &node;
728 }
729 
traverseFunctionDefinitionChildren(TIntermFunctionDefinition & node)730 TIntermNode *TIntermRebuild::traverseFunctionDefinitionChildren(TIntermFunctionDefinition &node)
731 {
732     GUARD(!mParentFunc);  // Function definitions cannot be nested.
733     mParentFunc = node.getFunction();
734     struct OnExit
735     {
736         const TFunction *&parentFunc;
737         OnExit(const TFunction *&parentFunc) : parentFunc(parentFunc) {}
738         ~OnExit() { parentFunc = nullptr; }
739     } onExit(mParentFunc);
740 
741     auto *const proto = node.getFunctionPrototype();
742     ASSERT(proto);
743     auto *const body = node.getBody();
744     ASSERT(body);
745 
746     auto *const newProto = traverseAnyAs<TIntermFunctionPrototype>(*proto);
747     GUARD(newProto);
748     auto *const newBody = traverseAnyAs<TIntermBlock>(*body);
749     GUARD(newBody);
750 
751     if (newProto != proto || newBody != body)
752     {
753         return new TIntermFunctionDefinition(newProto, newBody);
754     }
755 
756     return &node;
757 }
758 
traverseGlobalQualifierDeclarationChildren(TIntermGlobalQualifierDeclaration & node)759 TIntermNode *TIntermRebuild::traverseGlobalQualifierDeclarationChildren(
760     TIntermGlobalQualifierDeclaration &node)
761 {
762     auto *const symbol = node.getSymbol();
763     ASSERT(symbol);
764 
765     auto *const newSymbol = traverseAnyAs<TIntermSymbol>(*symbol);
766     GUARD(newSymbol);
767 
768     if (newSymbol != symbol)
769     {
770         return new TIntermGlobalQualifierDeclaration(newSymbol, node.isPrecise(), node.getLine());
771     }
772 
773     return &node;
774 }
775 
traverseLoopChildren(TIntermLoop & node)776 TIntermNode *TIntermRebuild::traverseLoopChildren(TIntermLoop &node)
777 {
778     const TLoopType loopType = node.getType();
779 
780     auto *const init = node.getInit();
781     auto *const cond = node.getCondition();
782     auto *const expr = node.getExpression();
783     auto *const body = node.getBody();
784     ASSERT(body);
785 
786 #if defined(ANGLE_ENABLE_ASSERTS)
787     switch (loopType)
788     {
789         case TLoopType::ELoopFor:
790             break;
791         case TLoopType::ELoopWhile:
792         case TLoopType::ELoopDoWhile:
793             ASSERT(cond);
794             ASSERT(!init && !expr);
795             break;
796         default:
797             ASSERT(false);
798             break;
799     }
800 #endif
801 
802     auto *const newBody = traverseAnyAs<TIntermBlock>(*body);
803     GUARD(newBody);
804     TIntermNode *newInit = nullptr;
805     if (init)
806     {
807         GUARD(traverseAnyAs(*init, newInit));
808     }
809     TIntermTyped *newCond = nullptr;
810     if (cond)
811     {
812         GUARD(traverseAnyAs(*cond, newCond));
813     }
814     TIntermTyped *newExpr = nullptr;
815     if (expr)
816     {
817         GUARD(traverseAnyAs(*expr, newExpr));
818     }
819 
820     if (newInit != init || newCond != cond || newExpr != expr || newBody != body)
821     {
822         switch (loopType)
823         {
824             case TLoopType::ELoopFor:
825                 GUARD(newBody);
826                 break;
827             case TLoopType::ELoopWhile:
828             case TLoopType::ELoopDoWhile:
829                 GUARD(newCond && newBody);
830                 GUARD(!newInit && !newExpr);
831                 break;
832             default:
833                 ASSERT(false);
834                 break;
835         }
836         return new TIntermLoop(loopType, newInit, newCond, newExpr, newBody);
837     }
838 
839     return &node;
840 }
841 
traverseBranchChildren(TIntermBranch & node)842 TIntermNode *TIntermRebuild::traverseBranchChildren(TIntermBranch &node)
843 {
844     auto *const expr = node.getExpression();
845 
846     TIntermTyped *newExpr = nullptr;
847     if (expr)
848     {
849         GUARD(traverseAnyAs<TIntermTyped>(*expr, newExpr));
850     }
851 
852     if (newExpr != expr)
853     {
854         return new TIntermBranch(node.getFlowOp(), newExpr);
855     }
856 
857     return &node;
858 }
859 
860 ////////////////////////////////////////////////////////////////////////////////
861 
visitSymbolPre(TIntermSymbol & node)862 PreResult TIntermRebuild::visitSymbolPre(TIntermSymbol &node)
863 {
864     return {node, VisitBits::Both};
865 }
866 
visitConstantUnionPre(TIntermConstantUnion & node)867 PreResult TIntermRebuild::visitConstantUnionPre(TIntermConstantUnion &node)
868 {
869     return {node, VisitBits::Both};
870 }
871 
visitFunctionPrototypePre(TIntermFunctionPrototype & node)872 PreResult TIntermRebuild::visitFunctionPrototypePre(TIntermFunctionPrototype &node)
873 {
874     return {node, VisitBits::Both};
875 }
876 
visitPreprocessorDirectivePre(TIntermPreprocessorDirective & node)877 PreResult TIntermRebuild::visitPreprocessorDirectivePre(TIntermPreprocessorDirective &node)
878 {
879     return {node, VisitBits::Both};
880 }
881 
visitUnaryPre(TIntermUnary & node)882 PreResult TIntermRebuild::visitUnaryPre(TIntermUnary &node)
883 {
884     return {node, VisitBits::Both};
885 }
886 
visitBinaryPre(TIntermBinary & node)887 PreResult TIntermRebuild::visitBinaryPre(TIntermBinary &node)
888 {
889     return {node, VisitBits::Both};
890 }
891 
visitTernaryPre(TIntermTernary & node)892 PreResult TIntermRebuild::visitTernaryPre(TIntermTernary &node)
893 {
894     return {node, VisitBits::Both};
895 }
896 
visitSwizzlePre(TIntermSwizzle & node)897 PreResult TIntermRebuild::visitSwizzlePre(TIntermSwizzle &node)
898 {
899     return {node, VisitBits::Both};
900 }
901 
visitIfElsePre(TIntermIfElse & node)902 PreResult TIntermRebuild::visitIfElsePre(TIntermIfElse &node)
903 {
904     return {node, VisitBits::Both};
905 }
906 
visitSwitchPre(TIntermSwitch & node)907 PreResult TIntermRebuild::visitSwitchPre(TIntermSwitch &node)
908 {
909     return {node, VisitBits::Both};
910 }
911 
visitCasePre(TIntermCase & node)912 PreResult TIntermRebuild::visitCasePre(TIntermCase &node)
913 {
914     return {node, VisitBits::Both};
915 }
916 
visitLoopPre(TIntermLoop & node)917 PreResult TIntermRebuild::visitLoopPre(TIntermLoop &node)
918 {
919     return {node, VisitBits::Both};
920 }
921 
visitBranchPre(TIntermBranch & node)922 PreResult TIntermRebuild::visitBranchPre(TIntermBranch &node)
923 {
924     return {node, VisitBits::Both};
925 }
926 
visitDeclarationPre(TIntermDeclaration & node)927 PreResult TIntermRebuild::visitDeclarationPre(TIntermDeclaration &node)
928 {
929     return {node, VisitBits::Both};
930 }
931 
visitBlockPre(TIntermBlock & node)932 PreResult TIntermRebuild::visitBlockPre(TIntermBlock &node)
933 {
934     return {node, VisitBits::Both};
935 }
936 
visitAggregatePre(TIntermAggregate & node)937 PreResult TIntermRebuild::visitAggregatePre(TIntermAggregate &node)
938 {
939     return {node, VisitBits::Both};
940 }
941 
visitFunctionDefinitionPre(TIntermFunctionDefinition & node)942 PreResult TIntermRebuild::visitFunctionDefinitionPre(TIntermFunctionDefinition &node)
943 {
944     return {node, VisitBits::Both};
945 }
946 
visitGlobalQualifierDeclarationPre(TIntermGlobalQualifierDeclaration & node)947 PreResult TIntermRebuild::visitGlobalQualifierDeclarationPre(
948     TIntermGlobalQualifierDeclaration &node)
949 {
950     return {node, VisitBits::Both};
951 }
952 
953 ////////////////////////////////////////////////////////////////////////////////
954 
visitSymbolPost(TIntermSymbol & node)955 PostResult TIntermRebuild::visitSymbolPost(TIntermSymbol &node)
956 {
957     return node;
958 }
959 
visitConstantUnionPost(TIntermConstantUnion & node)960 PostResult TIntermRebuild::visitConstantUnionPost(TIntermConstantUnion &node)
961 {
962     return node;
963 }
964 
visitFunctionPrototypePost(TIntermFunctionPrototype & node)965 PostResult TIntermRebuild::visitFunctionPrototypePost(TIntermFunctionPrototype &node)
966 {
967     return node;
968 }
969 
visitPreprocessorDirectivePost(TIntermPreprocessorDirective & node)970 PostResult TIntermRebuild::visitPreprocessorDirectivePost(TIntermPreprocessorDirective &node)
971 {
972     return node;
973 }
974 
visitUnaryPost(TIntermUnary & node)975 PostResult TIntermRebuild::visitUnaryPost(TIntermUnary &node)
976 {
977     return node;
978 }
979 
visitBinaryPost(TIntermBinary & node)980 PostResult TIntermRebuild::visitBinaryPost(TIntermBinary &node)
981 {
982     return node;
983 }
984 
visitTernaryPost(TIntermTernary & node)985 PostResult TIntermRebuild::visitTernaryPost(TIntermTernary &node)
986 {
987     return node;
988 }
989 
visitSwizzlePost(TIntermSwizzle & node)990 PostResult TIntermRebuild::visitSwizzlePost(TIntermSwizzle &node)
991 {
992     return node;
993 }
994 
visitIfElsePost(TIntermIfElse & node)995 PostResult TIntermRebuild::visitIfElsePost(TIntermIfElse &node)
996 {
997     return node;
998 }
999 
visitSwitchPost(TIntermSwitch & node)1000 PostResult TIntermRebuild::visitSwitchPost(TIntermSwitch &node)
1001 {
1002     return node;
1003 }
1004 
visitCasePost(TIntermCase & node)1005 PostResult TIntermRebuild::visitCasePost(TIntermCase &node)
1006 {
1007     return node;
1008 }
1009 
visitLoopPost(TIntermLoop & node)1010 PostResult TIntermRebuild::visitLoopPost(TIntermLoop &node)
1011 {
1012     return node;
1013 }
1014 
visitBranchPost(TIntermBranch & node)1015 PostResult TIntermRebuild::visitBranchPost(TIntermBranch &node)
1016 {
1017     return node;
1018 }
1019 
visitDeclarationPost(TIntermDeclaration & node)1020 PostResult TIntermRebuild::visitDeclarationPost(TIntermDeclaration &node)
1021 {
1022     return node;
1023 }
1024 
visitBlockPost(TIntermBlock & node)1025 PostResult TIntermRebuild::visitBlockPost(TIntermBlock &node)
1026 {
1027     return node;
1028 }
1029 
visitAggregatePost(TIntermAggregate & node)1030 PostResult TIntermRebuild::visitAggregatePost(TIntermAggregate &node)
1031 {
1032     return node;
1033 }
1034 
visitFunctionDefinitionPost(TIntermFunctionDefinition & node)1035 PostResult TIntermRebuild::visitFunctionDefinitionPost(TIntermFunctionDefinition &node)
1036 {
1037     return node;
1038 }
1039 
visitGlobalQualifierDeclarationPost(TIntermGlobalQualifierDeclaration & node)1040 PostResult TIntermRebuild::visitGlobalQualifierDeclarationPost(
1041     TIntermGlobalQualifierDeclaration &node)
1042 {
1043     return node;
1044 }
1045 
1046 }  // namespace sh
1047