1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include <algorithm>
8
9 #include "compiler/translator/Compiler.h"
10 #include "compiler/translator/SymbolTable.h"
11 #include "compiler/translator/tree_util/AsNode.h"
12 #include "compiler/translator/tree_util/IntermRebuild.h"
13
// GUARD2(cond, failVal): when `cond` evaluates to false, return `failVal` from the
// enclosing function; otherwise fall through. The do/while(false) wrapper makes the
// macro expand to a single statement, so it is safe inside un-braced if/else.
#define GUARD2(cond, failVal) \
    do                        \
    {                         \
        if (cond)             \
        {                     \
            break;            \
        }                     \
        return failVal;       \
    } while (false)

// GUARD(cond): GUARD2 for functions that report failure by returning nullptr.
#define GUARD(cond) GUARD2(cond, nullptr)
24
25 namespace sh
26 {
27
28 template <typename T, typename U>
AllBits(T haystack,U needle)29 ANGLE_INLINE bool AllBits(T haystack, U needle)
30 {
31 return (haystack & needle) == needle;
32 }
33
34 template <typename T, typename U>
AnyBits(T haystack,U needle)35 ANGLE_INLINE bool AnyBits(T haystack, U needle)
36 {
37 return (haystack & needle) != 0;
38 }
39
40 ////////////////////////////////////////////////////////////////////////////////
41
BaseResult(BaseResult & other)42 TIntermRebuild::BaseResult::BaseResult(BaseResult &other)
43 : mAction(other.mAction),
44 mVisit(other.mVisit),
45 mSingle(other.mSingle),
46 mMulti(std::move(other.mMulti))
47 {}
48
BaseResult(TIntermNode & node,VisitBits visit)49 TIntermRebuild::BaseResult::BaseResult(TIntermNode &node, VisitBits visit)
50 : mAction(Action::ReplaceSingle), mVisit(visit), mSingle(&node)
51 {}
52
BaseResult(TIntermNode * node,VisitBits visit)53 TIntermRebuild::BaseResult::BaseResult(TIntermNode *node, VisitBits visit)
54 : mAction(node ? Action::ReplaceSingle : Action::Drop),
55 mVisit(node ? visit : VisitBits::Neither),
56 mSingle(node)
57 {}
58
BaseResult(std::nullptr_t)59 TIntermRebuild::BaseResult::BaseResult(std::nullptr_t)
60 : mAction(Action::Drop), mVisit(VisitBits::Neither), mSingle(nullptr)
61 {}
62
BaseResult(Fail)63 TIntermRebuild::BaseResult::BaseResult(Fail)
64 : mAction(Action::Fail), mVisit(VisitBits::Neither), mSingle(nullptr)
65 {}
66
BaseResult(std::vector<TIntermNode * > && nodes)67 TIntermRebuild::BaseResult::BaseResult(std::vector<TIntermNode *> &&nodes)
68 : mAction(Action::ReplaceMulti),
69 mVisit(VisitBits::Neither),
70 mSingle(nullptr),
71 mMulti(std::move(nodes))
72 {}
73
moveAssignImpl(BaseResult & other)74 void TIntermRebuild::BaseResult::moveAssignImpl(BaseResult &other)
75 {
76 mAction = other.mAction;
77 mVisit = other.mVisit;
78 mSingle = other.mSingle;
79 mMulti = std::move(other.mMulti);
80 }
81
Multi(std::vector<TIntermNode * > && nodes)82 TIntermRebuild::BaseResult TIntermRebuild::BaseResult::Multi(std::vector<TIntermNode *> &&nodes)
83 {
84 auto it = std::remove(nodes.begin(), nodes.end(), nullptr);
85 nodes.erase(it, nodes.end());
86 return std::move(nodes);
87 }
88
isFail() const89 bool TIntermRebuild::BaseResult::isFail() const
90 {
91 return mAction == Action::Fail;
92 }
93
isDrop() const94 bool TIntermRebuild::BaseResult::isDrop() const
95 {
96 return mAction == Action::Drop;
97 }
98
single() const99 TIntermNode *TIntermRebuild::BaseResult::single() const
100 {
101 return mSingle;
102 }
103
multi() const104 const std::vector<TIntermNode *> *TIntermRebuild::BaseResult::multi() const
105 {
106 if (mAction == Action::ReplaceMulti)
107 {
108 return &mMulti;
109 }
110 return nullptr;
111 }
112
113 ////////////////////////////////////////////////////////////////////////////////
114
115 using PreResult = TIntermRebuild::PreResult;
116
PreResult(TIntermNode & node,VisitBits visit)117 PreResult::PreResult(TIntermNode &node, VisitBits visit) : BaseResult(node, visit) {}
PreResult(TIntermNode * node,VisitBits visit)118 PreResult::PreResult(TIntermNode *node, VisitBits visit) : BaseResult(node, visit) {}
PreResult(std::nullptr_t)119 PreResult::PreResult(std::nullptr_t) : BaseResult(nullptr) {}
PreResult(Fail)120 PreResult::PreResult(Fail) : BaseResult(Fail()) {}
121
PreResult(BaseResult && other)122 PreResult::PreResult(BaseResult &&other) : BaseResult(other) {}
PreResult(PreResult && other)123 PreResult::PreResult(PreResult &&other) : BaseResult(other) {}
124
operator =(PreResult && other)125 void PreResult::operator=(PreResult &&other)
126 {
127 moveAssignImpl(other);
128 }
129
130 ////////////////////////////////////////////////////////////////////////////////
131
132 using PostResult = TIntermRebuild::PostResult;
133
PostResult(TIntermNode & node)134 PostResult::PostResult(TIntermNode &node) : BaseResult(node, VisitBits::Neither) {}
PostResult(TIntermNode * node)135 PostResult::PostResult(TIntermNode *node) : BaseResult(node, VisitBits::Neither) {}
PostResult(std::nullptr_t)136 PostResult::PostResult(std::nullptr_t) : BaseResult(nullptr) {}
PostResult(Fail)137 PostResult::PostResult(Fail) : BaseResult(Fail()) {}
138
PostResult(PostResult && other)139 PostResult::PostResult(PostResult &&other) : BaseResult(other) {}
PostResult(BaseResult && other)140 PostResult::PostResult(BaseResult &&other) : BaseResult(other) {}
141
operator =(PostResult && other)142 void PostResult::operator=(PostResult &&other)
143 {
144 moveAssignImpl(other);
145 }
146
147 ////////////////////////////////////////////////////////////////////////////////
148
TIntermRebuild(TCompiler & compiler,bool preVisit,bool postVisit)149 TIntermRebuild::TIntermRebuild(TCompiler &compiler, bool preVisit, bool postVisit)
150 : mCompiler(compiler),
151 mSymbolTable(compiler.getSymbolTable()),
152 mPreVisit(preVisit),
153 mPostVisit(postVisit)
154 {
155 ASSERT(preVisit || postVisit);
156 }
157
~TIntermRebuild()158 TIntermRebuild::~TIntermRebuild()
159 {
160 ASSERT(!mNodeStack.value);
161 ASSERT(!mNodeStack.tail);
162 }
163
getParentFunction() const164 const TFunction *TIntermRebuild::getParentFunction() const
165 {
166 return mParentFunc;
167 }
168
getParentNode(size_t offset) const169 TIntermNode *TIntermRebuild::getParentNode(size_t offset) const
170 {
171 ASSERT(mNodeStack.tail);
172 auto parent = *mNodeStack.tail;
173 while (offset > 0)
174 {
175 --offset;
176 ASSERT(parent.tail);
177 parent = *parent.tail;
178 }
179 return parent.value;
180 }
181
rebuildRoot(TIntermBlock & root)182 bool TIntermRebuild::rebuildRoot(TIntermBlock &root)
183 {
184 if (!rebuildInPlace(root))
185 {
186 return false;
187 }
188 return mCompiler.validateAST(&root);
189 }
190
rebuildInPlace(TIntermAggregate & node)191 bool TIntermRebuild::rebuildInPlace(TIntermAggregate &node)
192 {
193 return rebuildInPlaceImpl(node);
194 }
195
rebuildInPlace(TIntermBlock & node)196 bool TIntermRebuild::rebuildInPlace(TIntermBlock &node)
197 {
198 return rebuildInPlaceImpl(node);
199 }
200
rebuildInPlace(TIntermDeclaration & node)201 bool TIntermRebuild::rebuildInPlace(TIntermDeclaration &node)
202 {
203 return rebuildInPlaceImpl(node);
204 }
205
206 template <typename Node>
rebuildInPlaceImpl(Node & node)207 bool TIntermRebuild::rebuildInPlaceImpl(Node &node)
208 {
209 auto *newNode = traverseAnyAs<Node>(node);
210 if (!newNode)
211 {
212 return false;
213 }
214
215 if (newNode != &node)
216 {
217 *node.getSequence() = std::move(*newNode->getSequence());
218 }
219
220 return true;
221 }
222
rebuild(TIntermNode & node)223 PostResult TIntermRebuild::rebuild(TIntermNode &node)
224 {
225 return traverseAny(node);
226 }
227
228 ////////////////////////////////////////////////////////////////////////////////
229
230 template <typename Node>
traverseAnyAs(TIntermNode & node)231 Node *TIntermRebuild::traverseAnyAs(TIntermNode &node)
232 {
233 PostResult result(traverseAny(node));
234 if (result.mAction == Action::Fail || !result.mSingle)
235 {
236 return nullptr;
237 }
238 return asNode<Node>(result.mSingle);
239 }
240
241 template <typename Node>
traverseAnyAs(TIntermNode & node,Node * & out)242 bool TIntermRebuild::traverseAnyAs(TIntermNode &node, Node *&out)
243 {
244 PostResult result(traverseAny(node));
245 if (result.mAction == Action::Fail || result.mAction == Action::ReplaceMulti)
246 {
247 return false;
248 }
249 if (!result.mSingle)
250 {
251 return true;
252 }
253 out = asNode<Node>(result.mSingle);
254 return out != nullptr;
255 }
256
traverseAggregateBaseChildren(TIntermAggregateBase & node)257 bool TIntermRebuild::traverseAggregateBaseChildren(TIntermAggregateBase &node)
258 {
259 auto *const children = node.getSequence();
260 ASSERT(children);
261 TIntermSequence newChildren;
262
263 for (TIntermNode *child : *children)
264 {
265 ASSERT(child);
266 PostResult result(traverseAny(*child));
267
268 switch (result.mAction)
269 {
270 case Action::ReplaceSingle:
271 newChildren.push_back(result.mSingle);
272 break;
273
274 case Action::ReplaceMulti:
275 for (TIntermNode *newNode : result.mMulti)
276 {
277 if (newNode)
278 {
279 newChildren.push_back(newNode);
280 }
281 }
282 break;
283
284 case Action::Drop:
285 break;
286
287 case Action::Fail:
288 return false;
289
290 default:
291 ASSERT(false);
292 return false;
293 }
294 }
295
296 *children = std::move(newChildren);
297
298 return true;
299 }
300
301 ////////////////////////////////////////////////////////////////////////////////
302
303 struct TIntermRebuild::NodeStackGuard
304 {
305 ConsList<TIntermNode *> oldNodeStack;
306 ConsList<TIntermNode *> &nodeStack;
NodeStackGuardsh::TIntermRebuild::NodeStackGuard307 NodeStackGuard(ConsList<TIntermNode *> &nodeStack)
308 : oldNodeStack(nodeStack), nodeStack(nodeStack)
309 {}
~NodeStackGuardsh::TIntermRebuild::NodeStackGuard310 ~NodeStackGuard() { nodeStack = oldNodeStack; }
311 };
312
traverseAny(TIntermNode & originalNode)313 PostResult TIntermRebuild::traverseAny(TIntermNode &originalNode)
314 {
315 PreResult preResult = traversePre(originalNode);
316 if (!preResult.mSingle)
317 {
318 ASSERT(preResult.mVisit == VisitBits::Neither);
319 return std::move(preResult);
320 }
321
322 TIntermNode *currNode = preResult.mSingle;
323 const VisitBits visit = preResult.mVisit;
324 const NodeType currNodeType = getNodeType(*currNode);
325
326 currNode = traverseChildren(currNodeType, originalNode, *currNode, visit);
327 if (!currNode)
328 {
329 return Fail();
330 }
331
332 return traversePost(currNodeType, originalNode, *currNode, visit);
333 }
334
// Runs the type-specific visit*Pre hook for `originalNode` and returns its result.
// When pre-visiting is disabled, the node is kept as-is and VisitBits::Both is
// requested. While the hook runs, `originalNode` sits on top of mNodeStack, so
// getParentNode() inside the hook resolves relative to it.
PreResult TIntermRebuild::traversePre(TIntermNode &originalNode)
{
    if (!mPreVisit)
    {
        return {originalNode, VisitBits::Both};
    }

    // Push originalNode for the duration of the hook; the guard restores the stack.
    NodeStackGuard guard(mNodeStack);
    mNodeStack = {&originalNode, &guard.oldNodeStack};

    const NodeType originalNodeType = getNodeType(originalNode);

    // Dispatch on node type. Unknown or unhandled types assert and fail.
    switch (originalNodeType)
    {
        case NodeType::Unknown:
            ASSERT(false);
            return Fail();
        case NodeType::Symbol:
            return visitSymbolPre(*originalNode.getAsSymbolNode());
        case NodeType::ConstantUnion:
            return visitConstantUnionPre(*originalNode.getAsConstantUnion());
        case NodeType::FunctionPrototype:
            return visitFunctionPrototypePre(*originalNode.getAsFunctionPrototypeNode());
        case NodeType::PreprocessorDirective:
            return visitPreprocessorDirectivePre(*originalNode.getAsPreprocessorDirective());
        case NodeType::Unary:
            return visitUnaryPre(*originalNode.getAsUnaryNode());
        case NodeType::Binary:
            return visitBinaryPre(*originalNode.getAsBinaryNode());
        case NodeType::Ternary:
            return visitTernaryPre(*originalNode.getAsTernaryNode());
        case NodeType::Swizzle:
            return visitSwizzlePre(*originalNode.getAsSwizzleNode());
        case NodeType::IfElse:
            return visitIfElsePre(*originalNode.getAsIfElseNode());
        case NodeType::Switch:
            return visitSwitchPre(*originalNode.getAsSwitchNode());
        case NodeType::Case:
            return visitCasePre(*originalNode.getAsCaseNode());
        case NodeType::FunctionDefinition:
            return visitFunctionDefinitionPre(*originalNode.getAsFunctionDefinition());
        case NodeType::Aggregate:
            return visitAggregatePre(*originalNode.getAsAggregate());
        case NodeType::Block:
            return visitBlockPre(*originalNode.getAsBlock());
        case NodeType::GlobalQualifierDeclaration:
            return visitGlobalQualifierDeclarationPre(
                *originalNode.getAsGlobalQualifierDeclarationNode());
        case NodeType::Declaration:
            return visitDeclarationPre(*originalNode.getAsDeclarationNode());
        case NodeType::Loop:
            return visitLoopPre(*originalNode.getAsLoopNode());
        case NodeType::Branch:
            return visitBranchPre(*originalNode.getAsBranchNode());
        default:
            ASSERT(false);
            return Fail();
    }
}
394
traverseChildren(NodeType currNodeType,const TIntermNode & originalNode,TIntermNode & currNode,VisitBits visit)395 TIntermNode *TIntermRebuild::traverseChildren(NodeType currNodeType,
396 const TIntermNode &originalNode,
397 TIntermNode &currNode,
398 VisitBits visit)
399 {
400 if (!AnyBits(visit, VisitBits::Children))
401 {
402 return &currNode;
403 }
404
405 if (AnyBits(visit, VisitBits::ChildrenRequiresSame) && &originalNode != &currNode)
406 {
407 return &currNode;
408 }
409
410 NodeStackGuard guard(mNodeStack);
411 mNodeStack = {&currNode, &guard.oldNodeStack};
412
413 switch (currNodeType)
414 {
415 case NodeType::Unknown:
416 ASSERT(false);
417 return nullptr;
418 case NodeType::Symbol:
419 return &currNode;
420 case NodeType::ConstantUnion:
421 return &currNode;
422 case NodeType::FunctionPrototype:
423 return &currNode;
424 case NodeType::PreprocessorDirective:
425 return &currNode;
426 case NodeType::Unary:
427 return traverseUnaryChildren(*currNode.getAsUnaryNode());
428 case NodeType::Binary:
429 return traverseBinaryChildren(*currNode.getAsBinaryNode());
430 case NodeType::Ternary:
431 return traverseTernaryChildren(*currNode.getAsTernaryNode());
432 case NodeType::Swizzle:
433 return traverseSwizzleChildren(*currNode.getAsSwizzleNode());
434 case NodeType::IfElse:
435 return traverseIfElseChildren(*currNode.getAsIfElseNode());
436 case NodeType::Switch:
437 return traverseSwitchChildren(*currNode.getAsSwitchNode());
438 case NodeType::Case:
439 return traverseCaseChildren(*currNode.getAsCaseNode());
440 case NodeType::FunctionDefinition:
441 return traverseFunctionDefinitionChildren(*currNode.getAsFunctionDefinition());
442 case NodeType::Aggregate:
443 return traverseAggregateChildren(*currNode.getAsAggregate());
444 case NodeType::Block:
445 return traverseBlockChildren(*currNode.getAsBlock());
446 case NodeType::GlobalQualifierDeclaration:
447 return traverseGlobalQualifierDeclarationChildren(
448 *currNode.getAsGlobalQualifierDeclarationNode());
449 case NodeType::Declaration:
450 return traverseDeclarationChildren(*currNode.getAsDeclarationNode());
451 case NodeType::Loop:
452 return traverseLoopChildren(*currNode.getAsLoopNode());
453 case NodeType::Branch:
454 return traverseBranchChildren(*currNode.getAsBranchNode());
455 default:
456 ASSERT(false);
457 return nullptr;
458 }
459 }
460
// Runs the type-specific visit*Post hook on `currNode` (the node left after the
// pre-visit and the child rebuild). The hook is skipped, and `currNode` returned
// unchanged, when post-visiting is disabled, when `visit` lacks VisitBits::Post,
// or when VisitBits::PostRequiresSame is set and `originalNode` was replaced.
PostResult TIntermRebuild::traversePost(NodeType currNodeType,
                                        const TIntermNode &originalNode,
                                        TIntermNode &currNode,
                                        VisitBits visit)
{
    if (!mPostVisit)
    {
        return currNode;
    }

    if (!AnyBits(visit, VisitBits::Post))
    {
        return currNode;
    }

    if (AnyBits(visit, VisitBits::PostRequiresSame) && &originalNode != &currNode)
    {
        return currNode;
    }

    // Push currNode for the duration of the hook; the guard restores the stack.
    NodeStackGuard guard(mNodeStack);
    mNodeStack = {&currNode, &guard.oldNodeStack};

    // Dispatch on node type. Unknown or unhandled types assert and fail.
    switch (currNodeType)
    {
        case NodeType::Unknown:
            ASSERT(false);
            return Fail();
        case NodeType::Symbol:
            return visitSymbolPost(*currNode.getAsSymbolNode());
        case NodeType::ConstantUnion:
            return visitConstantUnionPost(*currNode.getAsConstantUnion());
        case NodeType::FunctionPrototype:
            return visitFunctionPrototypePost(*currNode.getAsFunctionPrototypeNode());
        case NodeType::PreprocessorDirective:
            return visitPreprocessorDirectivePost(*currNode.getAsPreprocessorDirective());
        case NodeType::Unary:
            return visitUnaryPost(*currNode.getAsUnaryNode());
        case NodeType::Binary:
            return visitBinaryPost(*currNode.getAsBinaryNode());
        case NodeType::Ternary:
            return visitTernaryPost(*currNode.getAsTernaryNode());
        case NodeType::Swizzle:
            return visitSwizzlePost(*currNode.getAsSwizzleNode());
        case NodeType::IfElse:
            return visitIfElsePost(*currNode.getAsIfElseNode());
        case NodeType::Switch:
            return visitSwitchPost(*currNode.getAsSwitchNode());
        case NodeType::Case:
            return visitCasePost(*currNode.getAsCaseNode());
        case NodeType::FunctionDefinition:
            return visitFunctionDefinitionPost(*currNode.getAsFunctionDefinition());
        case NodeType::Aggregate:
            return visitAggregatePost(*currNode.getAsAggregate());
        case NodeType::Block:
            return visitBlockPost(*currNode.getAsBlock());
        case NodeType::GlobalQualifierDeclaration:
            return visitGlobalQualifierDeclarationPost(
                *currNode.getAsGlobalQualifierDeclarationNode());
        case NodeType::Declaration:
            return visitDeclarationPost(*currNode.getAsDeclarationNode());
        case NodeType::Loop:
            return visitLoopPost(*currNode.getAsLoopNode());
        case NodeType::Branch:
            return visitBranchPost(*currNode.getAsBranchNode());
        default:
            ASSERT(false);
            return Fail();
    }
}
531
532 ////////////////////////////////////////////////////////////////////////////////
533
traverseAggregateChildren(TIntermAggregate & node)534 TIntermNode *TIntermRebuild::traverseAggregateChildren(TIntermAggregate &node)
535 {
536 if (traverseAggregateBaseChildren(node))
537 {
538 return &node;
539 }
540 return nullptr;
541 }
542
traverseBlockChildren(TIntermBlock & node)543 TIntermNode *TIntermRebuild::traverseBlockChildren(TIntermBlock &node)
544 {
545 if (traverseAggregateBaseChildren(node))
546 {
547 return &node;
548 }
549 return nullptr;
550 }
551
traverseDeclarationChildren(TIntermDeclaration & node)552 TIntermNode *TIntermRebuild::traverseDeclarationChildren(TIntermDeclaration &node)
553 {
554 if (traverseAggregateBaseChildren(node))
555 {
556 return &node;
557 }
558 return nullptr;
559 }
560
traverseSwizzleChildren(TIntermSwizzle & node)561 TIntermNode *TIntermRebuild::traverseSwizzleChildren(TIntermSwizzle &node)
562 {
563 auto *const operand = node.getOperand();
564 ASSERT(operand);
565
566 auto *newOperand = traverseAnyAs<TIntermTyped>(*operand);
567 GUARD(newOperand);
568
569 if (newOperand != operand)
570 {
571 return new TIntermSwizzle(newOperand, node.getSwizzleOffsets());
572 }
573
574 return &node;
575 }
576
traverseBinaryChildren(TIntermBinary & node)577 TIntermNode *TIntermRebuild::traverseBinaryChildren(TIntermBinary &node)
578 {
579 auto *const left = node.getLeft();
580 ASSERT(left);
581 auto *const right = node.getRight();
582 ASSERT(right);
583
584 auto *const newLeft = traverseAnyAs<TIntermTyped>(*left);
585 GUARD(newLeft);
586 auto *const newRight = traverseAnyAs<TIntermTyped>(*right);
587 GUARD(newRight);
588
589 if (newLeft != left || newRight != right)
590 {
591 TOperator op = node.getOp();
592 switch (op)
593 {
594 case TOperator::EOpIndexDirectStruct:
595 {
596 if (newLeft->getType().getInterfaceBlock())
597 {
598 op = TOperator::EOpIndexDirectInterfaceBlock;
599 }
600 }
601 break;
602
603 case TOperator::EOpIndexDirectInterfaceBlock:
604 {
605 if (newLeft->getType().getStruct())
606 {
607 op = TOperator::EOpIndexDirectStruct;
608 }
609 }
610 break;
611
612 case TOperator::EOpComma:
613 return TIntermBinary::CreateComma(newLeft, newRight, mCompiler.getShaderVersion());
614
615 default:
616 break;
617 }
618
619 return new TIntermBinary(op, newLeft, newRight);
620 }
621
622 return &node;
623 }
624
traverseUnaryChildren(TIntermUnary & node)625 TIntermNode *TIntermRebuild::traverseUnaryChildren(TIntermUnary &node)
626 {
627 auto *const operand = node.getOperand();
628 ASSERT(operand);
629
630 auto *const newOperand = traverseAnyAs<TIntermTyped>(*operand);
631 GUARD(newOperand);
632
633 if (newOperand != operand)
634 {
635 return new TIntermUnary(node.getOp(), newOperand, node.getFunction());
636 }
637
638 return &node;
639 }
640
traverseTernaryChildren(TIntermTernary & node)641 TIntermNode *TIntermRebuild::traverseTernaryChildren(TIntermTernary &node)
642 {
643 auto *const cond = node.getCondition();
644 ASSERT(cond);
645 auto *const true_ = node.getTrueExpression();
646 ASSERT(true_);
647 auto *const false_ = node.getFalseExpression();
648 ASSERT(false_);
649
650 auto *const newCond = traverseAnyAs<TIntermTyped>(*cond);
651 GUARD(newCond);
652 auto *const newTrue = traverseAnyAs<TIntermTyped>(*true_);
653 GUARD(newTrue);
654 auto *const newFalse = traverseAnyAs<TIntermTyped>(*false_);
655 GUARD(newFalse);
656
657 if (newCond != cond || newTrue != true_ || newFalse != false_)
658 {
659 return new TIntermTernary(newCond, newTrue, newFalse);
660 }
661
662 return &node;
663 }
664
traverseIfElseChildren(TIntermIfElse & node)665 TIntermNode *TIntermRebuild::traverseIfElseChildren(TIntermIfElse &node)
666 {
667 auto *const cond = node.getCondition();
668 ASSERT(cond);
669 auto *const true_ = node.getTrueBlock();
670 auto *const false_ = node.getFalseBlock();
671
672 auto *const newCond = traverseAnyAs<TIntermTyped>(*cond);
673 GUARD(newCond);
674 TIntermBlock *newTrue = nullptr;
675 if (true_)
676 {
677 GUARD(traverseAnyAs(*true_, newTrue));
678 }
679 TIntermBlock *newFalse = nullptr;
680 if (false_)
681 {
682 GUARD(traverseAnyAs(*false_, newFalse));
683 }
684
685 if (newCond != cond || newTrue != true_ || newFalse != false_)
686 {
687 return new TIntermIfElse(newCond, newTrue, newFalse);
688 }
689
690 return &node;
691 }
692
traverseSwitchChildren(TIntermSwitch & node)693 TIntermNode *TIntermRebuild::traverseSwitchChildren(TIntermSwitch &node)
694 {
695 auto *const init = node.getInit();
696 ASSERT(init);
697 auto *const stmts = node.getStatementList();
698 ASSERT(stmts);
699
700 auto *const newInit = traverseAnyAs<TIntermTyped>(*init);
701 GUARD(newInit);
702 auto *const newStmts = traverseAnyAs<TIntermBlock>(*stmts);
703 GUARD(newStmts);
704
705 if (newInit != init || newStmts != stmts)
706 {
707 return new TIntermSwitch(newInit, newStmts);
708 }
709
710 return &node;
711 }
712
traverseCaseChildren(TIntermCase & node)713 TIntermNode *TIntermRebuild::traverseCaseChildren(TIntermCase &node)
714 {
715 auto *const cond = node.getCondition();
716
717 TIntermTyped *newCond = nullptr;
718 if (cond)
719 {
720 GUARD(traverseAnyAs(*cond, newCond));
721 }
722
723 if (newCond != cond)
724 {
725 return new TIntermCase(newCond);
726 }
727
728 return &node;
729 }
730
traverseFunctionDefinitionChildren(TIntermFunctionDefinition & node)731 TIntermNode *TIntermRebuild::traverseFunctionDefinitionChildren(TIntermFunctionDefinition &node)
732 {
733 GUARD(!mParentFunc); // Function definitions cannot be nested.
734 mParentFunc = node.getFunction();
735 struct OnExit
736 {
737 const TFunction *&parentFunc;
738 OnExit(const TFunction *&parentFunc) : parentFunc(parentFunc) {}
739 ~OnExit() { parentFunc = nullptr; }
740 } onExit(mParentFunc);
741
742 auto *const proto = node.getFunctionPrototype();
743 ASSERT(proto);
744 auto *const body = node.getBody();
745 ASSERT(body);
746
747 auto *const newProto = traverseAnyAs<TIntermFunctionPrototype>(*proto);
748 GUARD(newProto);
749 auto *const newBody = traverseAnyAs<TIntermBlock>(*body);
750 GUARD(newBody);
751
752 if (newProto != proto || newBody != body)
753 {
754 return new TIntermFunctionDefinition(newProto, newBody);
755 }
756
757 return &node;
758 }
759
traverseGlobalQualifierDeclarationChildren(TIntermGlobalQualifierDeclaration & node)760 TIntermNode *TIntermRebuild::traverseGlobalQualifierDeclarationChildren(
761 TIntermGlobalQualifierDeclaration &node)
762 {
763 auto *const symbol = node.getSymbol();
764 ASSERT(symbol);
765
766 auto *const newSymbol = traverseAnyAs<TIntermSymbol>(*symbol);
767 GUARD(newSymbol);
768
769 if (newSymbol != symbol)
770 {
771 return new TIntermGlobalQualifierDeclaration(newSymbol, node.isPrecise(), node.getLine());
772 }
773
774 return &node;
775 }
776
traverseLoopChildren(TIntermLoop & node)777 TIntermNode *TIntermRebuild::traverseLoopChildren(TIntermLoop &node)
778 {
779 const TLoopType loopType = node.getType();
780
781 auto *const init = node.getInit();
782 auto *const cond = node.getCondition();
783 auto *const expr = node.getExpression();
784 auto *const body = node.getBody();
785 ASSERT(body);
786
787 #if defined(ANGLE_ENABLE_ASSERTS)
788 switch (loopType)
789 {
790 case TLoopType::ELoopFor:
791 break;
792 case TLoopType::ELoopWhile:
793 case TLoopType::ELoopDoWhile:
794 ASSERT(cond);
795 ASSERT(!init && !expr);
796 break;
797 default:
798 ASSERT(false);
799 break;
800 }
801 #endif
802
803 auto *const newBody = traverseAnyAs<TIntermBlock>(*body);
804 GUARD(newBody);
805 TIntermNode *newInit = nullptr;
806 if (init)
807 {
808 GUARD(traverseAnyAs(*init, newInit));
809 }
810 TIntermTyped *newCond = nullptr;
811 if (cond)
812 {
813 GUARD(traverseAnyAs(*cond, newCond));
814 }
815 TIntermTyped *newExpr = nullptr;
816 if (expr)
817 {
818 GUARD(traverseAnyAs(*expr, newExpr));
819 }
820
821 if (newInit != init || newCond != cond || newExpr != expr || newBody != body)
822 {
823 switch (loopType)
824 {
825 case TLoopType::ELoopFor:
826 GUARD(newBody);
827 break;
828 case TLoopType::ELoopWhile:
829 case TLoopType::ELoopDoWhile:
830 GUARD(newCond && newBody);
831 GUARD(!newInit && !newExpr);
832 break;
833 default:
834 ASSERT(false);
835 break;
836 }
837 return new TIntermLoop(loopType, newInit, newCond, newExpr, newBody);
838 }
839
840 return &node;
841 }
842
traverseBranchChildren(TIntermBranch & node)843 TIntermNode *TIntermRebuild::traverseBranchChildren(TIntermBranch &node)
844 {
845 auto *const expr = node.getExpression();
846
847 TIntermTyped *newExpr = nullptr;
848 if (expr)
849 {
850 GUARD(traverseAnyAs<TIntermTyped>(*expr, newExpr));
851 }
852
853 if (newExpr != expr)
854 {
855 return new TIntermBranch(node.getFlowOp(), newExpr);
856 }
857
858 return &node;
859 }
860
861 ////////////////////////////////////////////////////////////////////////////////
862
visitSymbolPre(TIntermSymbol & node)863 PreResult TIntermRebuild::visitSymbolPre(TIntermSymbol &node)
864 {
865 return {node, VisitBits::Both};
866 }
867
visitConstantUnionPre(TIntermConstantUnion & node)868 PreResult TIntermRebuild::visitConstantUnionPre(TIntermConstantUnion &node)
869 {
870 return {node, VisitBits::Both};
871 }
872
visitFunctionPrototypePre(TIntermFunctionPrototype & node)873 PreResult TIntermRebuild::visitFunctionPrototypePre(TIntermFunctionPrototype &node)
874 {
875 return {node, VisitBits::Both};
876 }
877
visitPreprocessorDirectivePre(TIntermPreprocessorDirective & node)878 PreResult TIntermRebuild::visitPreprocessorDirectivePre(TIntermPreprocessorDirective &node)
879 {
880 return {node, VisitBits::Both};
881 }
882
visitUnaryPre(TIntermUnary & node)883 PreResult TIntermRebuild::visitUnaryPre(TIntermUnary &node)
884 {
885 return {node, VisitBits::Both};
886 }
887
visitBinaryPre(TIntermBinary & node)888 PreResult TIntermRebuild::visitBinaryPre(TIntermBinary &node)
889 {
890 return {node, VisitBits::Both};
891 }
892
visitTernaryPre(TIntermTernary & node)893 PreResult TIntermRebuild::visitTernaryPre(TIntermTernary &node)
894 {
895 return {node, VisitBits::Both};
896 }
897
visitSwizzlePre(TIntermSwizzle & node)898 PreResult TIntermRebuild::visitSwizzlePre(TIntermSwizzle &node)
899 {
900 return {node, VisitBits::Both};
901 }
902
visitIfElsePre(TIntermIfElse & node)903 PreResult TIntermRebuild::visitIfElsePre(TIntermIfElse &node)
904 {
905 return {node, VisitBits::Both};
906 }
907
visitSwitchPre(TIntermSwitch & node)908 PreResult TIntermRebuild::visitSwitchPre(TIntermSwitch &node)
909 {
910 return {node, VisitBits::Both};
911 }
912
visitCasePre(TIntermCase & node)913 PreResult TIntermRebuild::visitCasePre(TIntermCase &node)
914 {
915 return {node, VisitBits::Both};
916 }
917
visitLoopPre(TIntermLoop & node)918 PreResult TIntermRebuild::visitLoopPre(TIntermLoop &node)
919 {
920 return {node, VisitBits::Both};
921 }
922
visitBranchPre(TIntermBranch & node)923 PreResult TIntermRebuild::visitBranchPre(TIntermBranch &node)
924 {
925 return {node, VisitBits::Both};
926 }
927
visitDeclarationPre(TIntermDeclaration & node)928 PreResult TIntermRebuild::visitDeclarationPre(TIntermDeclaration &node)
929 {
930 return {node, VisitBits::Both};
931 }
932
visitBlockPre(TIntermBlock & node)933 PreResult TIntermRebuild::visitBlockPre(TIntermBlock &node)
934 {
935 return {node, VisitBits::Both};
936 }
937
visitAggregatePre(TIntermAggregate & node)938 PreResult TIntermRebuild::visitAggregatePre(TIntermAggregate &node)
939 {
940 return {node, VisitBits::Both};
941 }
942
visitFunctionDefinitionPre(TIntermFunctionDefinition & node)943 PreResult TIntermRebuild::visitFunctionDefinitionPre(TIntermFunctionDefinition &node)
944 {
945 return {node, VisitBits::Both};
946 }
947
visitGlobalQualifierDeclarationPre(TIntermGlobalQualifierDeclaration & node)948 PreResult TIntermRebuild::visitGlobalQualifierDeclarationPre(
949 TIntermGlobalQualifierDeclaration &node)
950 {
951 return {node, VisitBits::Both};
952 }
953
954 ////////////////////////////////////////////////////////////////////////////////
955
visitSymbolPost(TIntermSymbol & node)956 PostResult TIntermRebuild::visitSymbolPost(TIntermSymbol &node)
957 {
958 return node;
959 }
960
visitConstantUnionPost(TIntermConstantUnion & node)961 PostResult TIntermRebuild::visitConstantUnionPost(TIntermConstantUnion &node)
962 {
963 return node;
964 }
965
visitFunctionPrototypePost(TIntermFunctionPrototype & node)966 PostResult TIntermRebuild::visitFunctionPrototypePost(TIntermFunctionPrototype &node)
967 {
968 return node;
969 }
970
visitPreprocessorDirectivePost(TIntermPreprocessorDirective & node)971 PostResult TIntermRebuild::visitPreprocessorDirectivePost(TIntermPreprocessorDirective &node)
972 {
973 return node;
974 }
975
visitUnaryPost(TIntermUnary & node)976 PostResult TIntermRebuild::visitUnaryPost(TIntermUnary &node)
977 {
978 return node;
979 }
980
visitBinaryPost(TIntermBinary & node)981 PostResult TIntermRebuild::visitBinaryPost(TIntermBinary &node)
982 {
983 return node;
984 }
985
visitTernaryPost(TIntermTernary & node)986 PostResult TIntermRebuild::visitTernaryPost(TIntermTernary &node)
987 {
988 return node;
989 }
990
visitSwizzlePost(TIntermSwizzle & node)991 PostResult TIntermRebuild::visitSwizzlePost(TIntermSwizzle &node)
992 {
993 return node;
994 }
995
visitIfElsePost(TIntermIfElse & node)996 PostResult TIntermRebuild::visitIfElsePost(TIntermIfElse &node)
997 {
998 return node;
999 }
1000
visitSwitchPost(TIntermSwitch & node)1001 PostResult TIntermRebuild::visitSwitchPost(TIntermSwitch &node)
1002 {
1003 return node;
1004 }
1005
visitCasePost(TIntermCase & node)1006 PostResult TIntermRebuild::visitCasePost(TIntermCase &node)
1007 {
1008 return node;
1009 }
1010
visitLoopPost(TIntermLoop & node)1011 PostResult TIntermRebuild::visitLoopPost(TIntermLoop &node)
1012 {
1013 return node;
1014 }
1015
visitBranchPost(TIntermBranch & node)1016 PostResult TIntermRebuild::visitBranchPost(TIntermBranch &node)
1017 {
1018 return node;
1019 }
1020
visitDeclarationPost(TIntermDeclaration & node)1021 PostResult TIntermRebuild::visitDeclarationPost(TIntermDeclaration &node)
1022 {
1023 return node;
1024 }
1025
visitBlockPost(TIntermBlock & node)1026 PostResult TIntermRebuild::visitBlockPost(TIntermBlock &node)
1027 {
1028 return node;
1029 }
1030
visitAggregatePost(TIntermAggregate & node)1031 PostResult TIntermRebuild::visitAggregatePost(TIntermAggregate &node)
1032 {
1033 return node;
1034 }
1035
visitFunctionDefinitionPost(TIntermFunctionDefinition & node)1036 PostResult TIntermRebuild::visitFunctionDefinitionPost(TIntermFunctionDefinition &node)
1037 {
1038 return node;
1039 }
1040
visitGlobalQualifierDeclarationPost(TIntermGlobalQualifierDeclaration & node)1041 PostResult TIntermRebuild::visitGlobalQualifierDeclarationPost(
1042 TIntermGlobalQualifierDeclaration &node)
1043 {
1044 return node;
1045 }
1046
1047 } // namespace sh
1048