1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include <algorithm>
8
9 #include "compiler/translator/Compiler.h"
10 #include "compiler/translator/SymbolTable.h"
11 #include "compiler/translator/msl/AsNode.h"
12 #include "compiler/translator/msl/IntermRebuild.h"
13
// GUARD2(cond, failVal): if `cond` evaluates to false, return `failVal` from the
// enclosing function; otherwise execution continues after the macro. The
// do/while(false) wrapper turns the expansion into a single statement that
// needs a trailing semicolon; `break` leaves that wrapper on success.
#define GUARD2(cond, failVal) \
    do                        \
    {                         \
        if (cond)             \
        {                     \
            break;            \
        }                     \
        return failVal;       \
    } while (false)

// GUARD(cond): GUARD2 for pointer-returning functions; fails with nullptr.
#define GUARD(cond) GUARD2(cond, nullptr)
24
25 namespace sh
26 {
27
28 template <typename T, typename U>
AllBits(T haystack,U needle)29 ANGLE_INLINE bool AllBits(T haystack, U needle)
30 {
31 return (haystack & needle) == needle;
32 }
33
34 template <typename T, typename U>
AnyBits(T haystack,U needle)35 ANGLE_INLINE bool AnyBits(T haystack, U needle)
36 {
37 return (haystack & needle) != 0;
38 }
39
40 ////////////////////////////////////////////////////////////////////////////////
41
BaseResult(BaseResult & other)42 TIntermRebuild::BaseResult::BaseResult(BaseResult &other)
43 : mAction(other.mAction),
44 mVisit(other.mVisit),
45 mSingle(other.mSingle),
46 mMulti(std::move(other.mMulti))
47 {}
48
BaseResult(TIntermNode & node,VisitBits visit)49 TIntermRebuild::BaseResult::BaseResult(TIntermNode &node, VisitBits visit)
50 : mAction(Action::ReplaceSingle), mVisit(visit), mSingle(&node)
51 {}
52
BaseResult(TIntermNode * node,VisitBits visit)53 TIntermRebuild::BaseResult::BaseResult(TIntermNode *node, VisitBits visit)
54 : mAction(node ? Action::ReplaceSingle : Action::Drop),
55 mVisit(node ? visit : VisitBits::Neither),
56 mSingle(node)
57 {}
58
BaseResult(std::nullptr_t)59 TIntermRebuild::BaseResult::BaseResult(std::nullptr_t)
60 : mAction(Action::Drop), mVisit(VisitBits::Neither), mSingle(nullptr)
61 {}
62
BaseResult(Fail)63 TIntermRebuild::BaseResult::BaseResult(Fail)
64 : mAction(Action::Fail), mVisit(VisitBits::Neither), mSingle(nullptr)
65 {}
66
BaseResult(std::vector<TIntermNode * > && nodes)67 TIntermRebuild::BaseResult::BaseResult(std::vector<TIntermNode *> &&nodes)
68 : mAction(Action::ReplaceMulti),
69 mVisit(VisitBits::Neither),
70 mSingle(nullptr),
71 mMulti(std::move(nodes))
72 {}
73
moveAssignImpl(BaseResult & other)74 void TIntermRebuild::BaseResult::moveAssignImpl(BaseResult &other)
75 {
76 mAction = other.mAction;
77 mVisit = other.mVisit;
78 mSingle = other.mSingle;
79 mMulti = std::move(other.mMulti);
80 }
81
Multi(std::vector<TIntermNode * > && nodes)82 TIntermRebuild::BaseResult TIntermRebuild::BaseResult::Multi(std::vector<TIntermNode *> &&nodes)
83 {
84 auto it = std::remove(nodes.begin(), nodes.end(), nullptr);
85 nodes.erase(it, nodes.end());
86 return std::move(nodes);
87 }
88
isFail() const89 bool TIntermRebuild::BaseResult::isFail() const
90 {
91 return mAction == Action::Fail;
92 }
93
isDrop() const94 bool TIntermRebuild::BaseResult::isDrop() const
95 {
96 return mAction == Action::Drop;
97 }
98
single() const99 TIntermNode *TIntermRebuild::BaseResult::single() const
100 {
101 return mSingle;
102 }
103
multi() const104 const std::vector<TIntermNode *> *TIntermRebuild::BaseResult::multi() const
105 {
106 if (mAction == Action::ReplaceMulti)
107 {
108 return &mMulti;
109 }
110 return nullptr;
111 }
112
113 ////////////////////////////////////////////////////////////////////////////////
114
115 using PreResult = TIntermRebuild::PreResult;
116
PreResult(TIntermNode & node,VisitBits visit)117 PreResult::PreResult(TIntermNode &node, VisitBits visit) : BaseResult(node, visit) {}
PreResult(TIntermNode * node,VisitBits visit)118 PreResult::PreResult(TIntermNode *node, VisitBits visit) : BaseResult(node, visit) {}
PreResult(std::nullptr_t)119 PreResult::PreResult(std::nullptr_t) : BaseResult(nullptr) {}
PreResult(Fail)120 PreResult::PreResult(Fail) : BaseResult(Fail()) {}
121
PreResult(BaseResult && other)122 PreResult::PreResult(BaseResult &&other) : BaseResult(other) {}
PreResult(PreResult && other)123 PreResult::PreResult(PreResult &&other) : BaseResult(other) {}
124
operator =(PreResult && other)125 void PreResult::operator=(PreResult &&other)
126 {
127 moveAssignImpl(other);
128 }
129
130 ////////////////////////////////////////////////////////////////////////////////
131
132 using PostResult = TIntermRebuild::PostResult;
133
PostResult(TIntermNode & node)134 PostResult::PostResult(TIntermNode &node) : BaseResult(node, VisitBits::Neither) {}
PostResult(TIntermNode * node)135 PostResult::PostResult(TIntermNode *node) : BaseResult(node, VisitBits::Neither) {}
PostResult(std::nullptr_t)136 PostResult::PostResult(std::nullptr_t) : BaseResult(nullptr) {}
PostResult(Fail)137 PostResult::PostResult(Fail) : BaseResult(Fail()) {}
138
PostResult(PostResult && other)139 PostResult::PostResult(PostResult &&other) : BaseResult(other) {}
PostResult(BaseResult && other)140 PostResult::PostResult(BaseResult &&other) : BaseResult(other) {}
141
operator =(PostResult && other)142 void PostResult::operator=(PostResult &&other)
143 {
144 moveAssignImpl(other);
145 }
146
147 ////////////////////////////////////////////////////////////////////////////////
148
TIntermRebuild(TCompiler & compiler,bool preVisit,bool postVisit)149 TIntermRebuild::TIntermRebuild(TCompiler &compiler, bool preVisit, bool postVisit)
150 : mCompiler(compiler),
151 mSymbolTable(compiler.getSymbolTable()),
152 mPreVisit(preVisit),
153 mPostVisit(postVisit)
154 {
155 ASSERT(preVisit || postVisit);
156 }
157
~TIntermRebuild()158 TIntermRebuild::~TIntermRebuild()
159 {
160 ASSERT(!mNodeStack.value);
161 ASSERT(!mNodeStack.tail);
162 }
163
getParentFunction() const164 const TFunction *TIntermRebuild::getParentFunction() const
165 {
166 return mParentFunc;
167 }
168
getParentNode(size_t offset) const169 TIntermNode *TIntermRebuild::getParentNode(size_t offset) const
170 {
171 ASSERT(mNodeStack.tail);
172 auto parent = *mNodeStack.tail;
173 while (offset > 0)
174 {
175 --offset;
176 ASSERT(parent.tail);
177 parent = *parent.tail;
178 }
179 return parent.value;
180 }
181
rebuildRoot(TIntermBlock & root)182 bool TIntermRebuild::rebuildRoot(TIntermBlock &root)
183 {
184 if (!rebuildInPlace(root))
185 {
186 return false;
187 }
188 return mCompiler.validateAST(&root);
189 }
190
rebuildInPlace(TIntermAggregate & node)191 bool TIntermRebuild::rebuildInPlace(TIntermAggregate &node)
192 {
193 return rebuildInPlaceImpl(node);
194 }
195
rebuildInPlace(TIntermBlock & node)196 bool TIntermRebuild::rebuildInPlace(TIntermBlock &node)
197 {
198 return rebuildInPlaceImpl(node);
199 }
200
rebuildInPlace(TIntermDeclaration & node)201 bool TIntermRebuild::rebuildInPlace(TIntermDeclaration &node)
202 {
203 return rebuildInPlaceImpl(node);
204 }
205
206 template <typename Node>
rebuildInPlaceImpl(Node & node)207 bool TIntermRebuild::rebuildInPlaceImpl(Node &node)
208 {
209 auto *newNode = traverseAnyAs<Node>(node);
210 if (!newNode)
211 {
212 return false;
213 }
214
215 if (newNode != &node)
216 {
217 *node.getSequence() = std::move(*newNode->getSequence());
218 }
219
220 return true;
221 }
222
rebuild(TIntermNode & node)223 PostResult TIntermRebuild::rebuild(TIntermNode &node)
224 {
225 return traverseAny(node);
226 }
227
228 ////////////////////////////////////////////////////////////////////////////////
229
230 template <typename Node>
traverseAnyAs(TIntermNode & node)231 Node *TIntermRebuild::traverseAnyAs(TIntermNode &node)
232 {
233 PostResult result(traverseAny(node));
234 if (result.mAction == Action::Fail || !result.mSingle)
235 {
236 return nullptr;
237 }
238 return asNode<Node>(result.mSingle);
239 }
240
241 template <typename Node>
traverseAnyAs(TIntermNode & node,Node * & out)242 bool TIntermRebuild::traverseAnyAs(TIntermNode &node, Node *&out)
243 {
244 PostResult result(traverseAny(node));
245 if (result.mAction == Action::Fail || result.mAction == Action::ReplaceMulti)
246 {
247 return false;
248 }
249 if (!result.mSingle)
250 {
251 return true;
252 }
253 out = asNode<Node>(result.mSingle);
254 return out != nullptr;
255 }
256
traverseAggregateBaseChildren(TIntermAggregateBase & node)257 bool TIntermRebuild::traverseAggregateBaseChildren(TIntermAggregateBase &node)
258 {
259 auto *const children = node.getSequence();
260 ASSERT(children);
261 TIntermSequence newChildren;
262
263 for (TIntermNode *child : *children)
264 {
265 ASSERT(child);
266 PostResult result(traverseAny(*child));
267
268 switch (result.mAction)
269 {
270 case Action::ReplaceSingle:
271 newChildren.push_back(result.mSingle);
272 break;
273
274 case Action::ReplaceMulti:
275 for (TIntermNode *newNode : result.mMulti)
276 {
277 if (newNode)
278 {
279 newChildren.push_back(newNode);
280 }
281 }
282 break;
283
284 case Action::Drop:
285 break;
286
287 case Action::Fail:
288 return false;
289
290 default:
291 ASSERT(false);
292 return false;
293 }
294 }
295
296 *children = std::move(newChildren);
297
298 return true;
299 }
300
301 ////////////////////////////////////////////////////////////////////////////////
302
303 struct TIntermRebuild::NodeStackGuard
304 {
305 ConsList<TIntermNode *> oldNodeStack;
306 ConsList<TIntermNode *> &nodeStack;
NodeStackGuardsh::TIntermRebuild::NodeStackGuard307 NodeStackGuard(ConsList<TIntermNode *> &nodeStack, TIntermNode *node)
308 : oldNodeStack(nodeStack), nodeStack(nodeStack)
309 {
310 nodeStack = {node, &oldNodeStack};
311 }
~NodeStackGuardsh::TIntermRebuild::NodeStackGuard312 ~NodeStackGuard() { nodeStack = oldNodeStack; }
313 };
314
traverseAny(TIntermNode & originalNode)315 PostResult TIntermRebuild::traverseAny(TIntermNode &originalNode)
316 {
317 PreResult preResult = traversePre(originalNode);
318 if (!preResult.mSingle)
319 {
320 ASSERT(preResult.mVisit == VisitBits::Neither);
321 return std::move(preResult);
322 }
323
324 TIntermNode *currNode = preResult.mSingle;
325 const VisitBits visit = preResult.mVisit;
326 const NodeType currNodeType = getNodeType(*currNode);
327
328 currNode = traverseChildren(currNodeType, originalNode, *currNode, visit);
329 if (!currNode)
330 {
331 return Fail();
332 }
333
334 return traversePost(currNodeType, originalNode, *currNode, visit);
335 }
336
traversePre(TIntermNode & originalNode)337 PreResult TIntermRebuild::traversePre(TIntermNode &originalNode)
338 {
339 if (!mPreVisit)
340 {
341 return {originalNode, VisitBits::Both};
342 }
343
344 NodeStackGuard guard(mNodeStack, &originalNode);
345
346 const NodeType originalNodeType = getNodeType(originalNode);
347
348 switch (originalNodeType)
349 {
350 case NodeType::Unknown:
351 ASSERT(false);
352 return Fail();
353 case NodeType::Symbol:
354 return visitSymbolPre(*originalNode.getAsSymbolNode());
355 case NodeType::ConstantUnion:
356 return visitConstantUnionPre(*originalNode.getAsConstantUnion());
357 case NodeType::FunctionPrototype:
358 return visitFunctionPrototypePre(*originalNode.getAsFunctionPrototypeNode());
359 case NodeType::PreprocessorDirective:
360 return visitPreprocessorDirectivePre(*originalNode.getAsPreprocessorDirective());
361 case NodeType::Unary:
362 return visitUnaryPre(*originalNode.getAsUnaryNode());
363 case NodeType::Binary:
364 return visitBinaryPre(*originalNode.getAsBinaryNode());
365 case NodeType::Ternary:
366 return visitTernaryPre(*originalNode.getAsTernaryNode());
367 case NodeType::Swizzle:
368 return visitSwizzlePre(*originalNode.getAsSwizzleNode());
369 case NodeType::IfElse:
370 return visitIfElsePre(*originalNode.getAsIfElseNode());
371 case NodeType::Switch:
372 return visitSwitchPre(*originalNode.getAsSwitchNode());
373 case NodeType::Case:
374 return visitCasePre(*originalNode.getAsCaseNode());
375 case NodeType::FunctionDefinition:
376 return visitFunctionDefinitionPre(*originalNode.getAsFunctionDefinition());
377 case NodeType::Aggregate:
378 return visitAggregatePre(*originalNode.getAsAggregate());
379 case NodeType::Block:
380 return visitBlockPre(*originalNode.getAsBlock());
381 case NodeType::GlobalQualifierDeclaration:
382 return visitGlobalQualifierDeclarationPre(
383 *originalNode.getAsGlobalQualifierDeclarationNode());
384 case NodeType::Declaration:
385 return visitDeclarationPre(*originalNode.getAsDeclarationNode());
386 case NodeType::Loop:
387 return visitLoopPre(*originalNode.getAsLoopNode());
388 case NodeType::Branch:
389 return visitBranchPre(*originalNode.getAsBranchNode());
390 default:
391 ASSERT(false);
392 return Fail();
393 }
394 }
395
traverseChildren(NodeType currNodeType,const TIntermNode & originalNode,TIntermNode & currNode,VisitBits visit)396 TIntermNode *TIntermRebuild::traverseChildren(NodeType currNodeType,
397 const TIntermNode &originalNode,
398 TIntermNode &currNode,
399 VisitBits visit)
400 {
401 if (!AnyBits(visit, VisitBits::Children))
402 {
403 return &currNode;
404 }
405
406 if (AnyBits(visit, VisitBits::ChildrenRequiresSame) && &originalNode != &currNode)
407 {
408 return &currNode;
409 }
410
411 NodeStackGuard guard(mNodeStack, &currNode);
412
413 switch (currNodeType)
414 {
415 case NodeType::Unknown:
416 ASSERT(false);
417 return nullptr;
418 case NodeType::Symbol:
419 return &currNode;
420 case NodeType::ConstantUnion:
421 return &currNode;
422 case NodeType::FunctionPrototype:
423 return &currNode;
424 case NodeType::PreprocessorDirective:
425 return &currNode;
426 case NodeType::Unary:
427 return traverseUnaryChildren(*currNode.getAsUnaryNode());
428 case NodeType::Binary:
429 return traverseBinaryChildren(*currNode.getAsBinaryNode());
430 case NodeType::Ternary:
431 return traverseTernaryChildren(*currNode.getAsTernaryNode());
432 case NodeType::Swizzle:
433 return traverseSwizzleChildren(*currNode.getAsSwizzleNode());
434 case NodeType::IfElse:
435 return traverseIfElseChildren(*currNode.getAsIfElseNode());
436 case NodeType::Switch:
437 return traverseSwitchChildren(*currNode.getAsSwitchNode());
438 case NodeType::Case:
439 return traverseCaseChildren(*currNode.getAsCaseNode());
440 case NodeType::FunctionDefinition:
441 return traverseFunctionDefinitionChildren(*currNode.getAsFunctionDefinition());
442 case NodeType::Aggregate:
443 return traverseAggregateChildren(*currNode.getAsAggregate());
444 case NodeType::Block:
445 return traverseBlockChildren(*currNode.getAsBlock());
446 case NodeType::GlobalQualifierDeclaration:
447 return traverseGlobalQualifierDeclarationChildren(
448 *currNode.getAsGlobalQualifierDeclarationNode());
449 case NodeType::Declaration:
450 return traverseDeclarationChildren(*currNode.getAsDeclarationNode());
451 case NodeType::Loop:
452 return traverseLoopChildren(*currNode.getAsLoopNode());
453 case NodeType::Branch:
454 return traverseBranchChildren(*currNode.getAsBranchNode());
455 default:
456 ASSERT(false);
457 return nullptr;
458 }
459 }
460
traversePost(NodeType currNodeType,const TIntermNode & originalNode,TIntermNode & currNode,VisitBits visit)461 PostResult TIntermRebuild::traversePost(NodeType currNodeType,
462 const TIntermNode &originalNode,
463 TIntermNode &currNode,
464 VisitBits visit)
465 {
466 if (!mPostVisit)
467 {
468 return currNode;
469 }
470
471 if (!AnyBits(visit, VisitBits::Post))
472 {
473 return currNode;
474 }
475
476 if (AnyBits(visit, VisitBits::PostRequiresSame) && &originalNode != &currNode)
477 {
478 return currNode;
479 }
480
481 NodeStackGuard guard(mNodeStack, &currNode);
482
483 switch (currNodeType)
484 {
485 case NodeType::Unknown:
486 ASSERT(false);
487 return Fail();
488 case NodeType::Symbol:
489 return visitSymbolPost(*currNode.getAsSymbolNode());
490 case NodeType::ConstantUnion:
491 return visitConstantUnionPost(*currNode.getAsConstantUnion());
492 case NodeType::FunctionPrototype:
493 return visitFunctionPrototypePost(*currNode.getAsFunctionPrototypeNode());
494 case NodeType::PreprocessorDirective:
495 return visitPreprocessorDirectivePost(*currNode.getAsPreprocessorDirective());
496 case NodeType::Unary:
497 return visitUnaryPost(*currNode.getAsUnaryNode());
498 case NodeType::Binary:
499 return visitBinaryPost(*currNode.getAsBinaryNode());
500 case NodeType::Ternary:
501 return visitTernaryPost(*currNode.getAsTernaryNode());
502 case NodeType::Swizzle:
503 return visitSwizzlePost(*currNode.getAsSwizzleNode());
504 case NodeType::IfElse:
505 return visitIfElsePost(*currNode.getAsIfElseNode());
506 case NodeType::Switch:
507 return visitSwitchPost(*currNode.getAsSwitchNode());
508 case NodeType::Case:
509 return visitCasePost(*currNode.getAsCaseNode());
510 case NodeType::FunctionDefinition:
511 return visitFunctionDefinitionPost(*currNode.getAsFunctionDefinition());
512 case NodeType::Aggregate:
513 return visitAggregatePost(*currNode.getAsAggregate());
514 case NodeType::Block:
515 return visitBlockPost(*currNode.getAsBlock());
516 case NodeType::GlobalQualifierDeclaration:
517 return visitGlobalQualifierDeclarationPost(
518 *currNode.getAsGlobalQualifierDeclarationNode());
519 case NodeType::Declaration:
520 return visitDeclarationPost(*currNode.getAsDeclarationNode());
521 case NodeType::Loop:
522 return visitLoopPost(*currNode.getAsLoopNode());
523 case NodeType::Branch:
524 return visitBranchPost(*currNode.getAsBranchNode());
525 default:
526 ASSERT(false);
527 return Fail();
528 }
529 }
530
531 ////////////////////////////////////////////////////////////////////////////////
532
traverseAggregateChildren(TIntermAggregate & node)533 TIntermNode *TIntermRebuild::traverseAggregateChildren(TIntermAggregate &node)
534 {
535 if (traverseAggregateBaseChildren(node))
536 {
537 return &node;
538 }
539 return nullptr;
540 }
541
traverseBlockChildren(TIntermBlock & node)542 TIntermNode *TIntermRebuild::traverseBlockChildren(TIntermBlock &node)
543 {
544 if (traverseAggregateBaseChildren(node))
545 {
546 return &node;
547 }
548 return nullptr;
549 }
550
traverseDeclarationChildren(TIntermDeclaration & node)551 TIntermNode *TIntermRebuild::traverseDeclarationChildren(TIntermDeclaration &node)
552 {
553 if (traverseAggregateBaseChildren(node))
554 {
555 return &node;
556 }
557 return nullptr;
558 }
559
traverseSwizzleChildren(TIntermSwizzle & node)560 TIntermNode *TIntermRebuild::traverseSwizzleChildren(TIntermSwizzle &node)
561 {
562 auto *const operand = node.getOperand();
563 ASSERT(operand);
564
565 auto *newOperand = traverseAnyAs<TIntermTyped>(*operand);
566 GUARD(newOperand);
567
568 if (newOperand != operand)
569 {
570 return new TIntermSwizzle(newOperand, node.getSwizzleOffsets());
571 }
572
573 return &node;
574 }
575
traverseBinaryChildren(TIntermBinary & node)576 TIntermNode *TIntermRebuild::traverseBinaryChildren(TIntermBinary &node)
577 {
578 auto *const left = node.getLeft();
579 ASSERT(left);
580 auto *const right = node.getRight();
581 ASSERT(right);
582
583 auto *const newLeft = traverseAnyAs<TIntermTyped>(*left);
584 GUARD(newLeft);
585 auto *const newRight = traverseAnyAs<TIntermTyped>(*right);
586 GUARD(newRight);
587
588 if (newLeft != left || newRight != right)
589 {
590 TOperator op = node.getOp();
591 switch (op)
592 {
593 case TOperator::EOpIndexDirectStruct:
594 {
595 if (newLeft->getType().getInterfaceBlock())
596 {
597 op = TOperator::EOpIndexDirectInterfaceBlock;
598 }
599 }
600 break;
601
602 case TOperator::EOpIndexDirectInterfaceBlock:
603 {
604 if (newLeft->getType().getStruct())
605 {
606 op = TOperator::EOpIndexDirectStruct;
607 }
608 }
609 break;
610
611 case TOperator::EOpComma:
612 return TIntermBinary::CreateComma(newLeft, newRight, mCompiler.getShaderVersion());
613
614 default:
615 break;
616 }
617
618 return new TIntermBinary(op, newLeft, newRight);
619 }
620
621 return &node;
622 }
623
traverseUnaryChildren(TIntermUnary & node)624 TIntermNode *TIntermRebuild::traverseUnaryChildren(TIntermUnary &node)
625 {
626 auto *const operand = node.getOperand();
627 ASSERT(operand);
628
629 auto *const newOperand = traverseAnyAs<TIntermTyped>(*operand);
630 GUARD(newOperand);
631
632 if (newOperand != operand)
633 {
634 return new TIntermUnary(node.getOp(), newOperand, node.getFunction());
635 }
636
637 return &node;
638 }
639
traverseTernaryChildren(TIntermTernary & node)640 TIntermNode *TIntermRebuild::traverseTernaryChildren(TIntermTernary &node)
641 {
642 auto *const cond = node.getCondition();
643 ASSERT(cond);
644 auto *const true_ = node.getTrueExpression();
645 ASSERT(true_);
646 auto *const false_ = node.getFalseExpression();
647 ASSERT(false_);
648
649 auto *const newCond = traverseAnyAs<TIntermTyped>(*cond);
650 GUARD(newCond);
651 auto *const newTrue = traverseAnyAs<TIntermTyped>(*true_);
652 GUARD(newTrue);
653 auto *const newFalse = traverseAnyAs<TIntermTyped>(*false_);
654 GUARD(newFalse);
655
656 if (newCond != cond || newTrue != true_ || newFalse != false_)
657 {
658 return new TIntermTernary(newCond, newTrue, newFalse);
659 }
660
661 return &node;
662 }
663
traverseIfElseChildren(TIntermIfElse & node)664 TIntermNode *TIntermRebuild::traverseIfElseChildren(TIntermIfElse &node)
665 {
666 auto *const cond = node.getCondition();
667 ASSERT(cond);
668 auto *const true_ = node.getTrueBlock();
669 auto *const false_ = node.getFalseBlock();
670
671 auto *const newCond = traverseAnyAs<TIntermTyped>(*cond);
672 GUARD(newCond);
673 TIntermBlock *newTrue = nullptr;
674 if (true_)
675 {
676 GUARD(traverseAnyAs(*true_, newTrue));
677 }
678 TIntermBlock *newFalse = nullptr;
679 if (false_)
680 {
681 GUARD(traverseAnyAs(*false_, newFalse));
682 }
683
684 if (newCond != cond || newTrue != true_ || newFalse != false_)
685 {
686 return new TIntermIfElse(newCond, newTrue, newFalse);
687 }
688
689 return &node;
690 }
691
traverseSwitchChildren(TIntermSwitch & node)692 TIntermNode *TIntermRebuild::traverseSwitchChildren(TIntermSwitch &node)
693 {
694 auto *const init = node.getInit();
695 ASSERT(init);
696 auto *const stmts = node.getStatementList();
697 ASSERT(stmts);
698
699 auto *const newInit = traverseAnyAs<TIntermTyped>(*init);
700 GUARD(newInit);
701 auto *const newStmts = traverseAnyAs<TIntermBlock>(*stmts);
702 GUARD(newStmts);
703
704 if (newInit != init || newStmts != stmts)
705 {
706 return new TIntermSwitch(newInit, newStmts);
707 }
708
709 return &node;
710 }
711
traverseCaseChildren(TIntermCase & node)712 TIntermNode *TIntermRebuild::traverseCaseChildren(TIntermCase &node)
713 {
714 auto *const cond = node.getCondition();
715
716 TIntermTyped *newCond = nullptr;
717 if (cond)
718 {
719 GUARD(traverseAnyAs(*cond, newCond));
720 }
721
722 if (newCond != cond)
723 {
724 return new TIntermCase(newCond);
725 }
726
727 return &node;
728 }
729
traverseFunctionDefinitionChildren(TIntermFunctionDefinition & node)730 TIntermNode *TIntermRebuild::traverseFunctionDefinitionChildren(TIntermFunctionDefinition &node)
731 {
732 GUARD(!mParentFunc); // Function definitions cannot be nested.
733 mParentFunc = node.getFunction();
734 struct OnExit
735 {
736 const TFunction *&parentFunc;
737 OnExit(const TFunction *&parentFunc) : parentFunc(parentFunc) {}
738 ~OnExit() { parentFunc = nullptr; }
739 } onExit(mParentFunc);
740
741 auto *const proto = node.getFunctionPrototype();
742 ASSERT(proto);
743 auto *const body = node.getBody();
744 ASSERT(body);
745
746 auto *const newProto = traverseAnyAs<TIntermFunctionPrototype>(*proto);
747 GUARD(newProto);
748 auto *const newBody = traverseAnyAs<TIntermBlock>(*body);
749 GUARD(newBody);
750
751 if (newProto != proto || newBody != body)
752 {
753 return new TIntermFunctionDefinition(newProto, newBody);
754 }
755
756 return &node;
757 }
758
traverseGlobalQualifierDeclarationChildren(TIntermGlobalQualifierDeclaration & node)759 TIntermNode *TIntermRebuild::traverseGlobalQualifierDeclarationChildren(
760 TIntermGlobalQualifierDeclaration &node)
761 {
762 auto *const symbol = node.getSymbol();
763 ASSERT(symbol);
764
765 auto *const newSymbol = traverseAnyAs<TIntermSymbol>(*symbol);
766 GUARD(newSymbol);
767
768 if (newSymbol != symbol)
769 {
770 return new TIntermGlobalQualifierDeclaration(newSymbol, node.isPrecise(), node.getLine());
771 }
772
773 return &node;
774 }
775
traverseLoopChildren(TIntermLoop & node)776 TIntermNode *TIntermRebuild::traverseLoopChildren(TIntermLoop &node)
777 {
778 const TLoopType loopType = node.getType();
779
780 auto *const init = node.getInit();
781 auto *const cond = node.getCondition();
782 auto *const expr = node.getExpression();
783 auto *const body = node.getBody();
784 ASSERT(body);
785
786 #if defined(ANGLE_ENABLE_ASSERTS)
787 switch (loopType)
788 {
789 case TLoopType::ELoopFor:
790 break;
791 case TLoopType::ELoopWhile:
792 case TLoopType::ELoopDoWhile:
793 ASSERT(cond);
794 ASSERT(!init && !expr);
795 break;
796 default:
797 ASSERT(false);
798 break;
799 }
800 #endif
801
802 auto *const newBody = traverseAnyAs<TIntermBlock>(*body);
803 GUARD(newBody);
804 TIntermNode *newInit = nullptr;
805 if (init)
806 {
807 GUARD(traverseAnyAs(*init, newInit));
808 }
809 TIntermTyped *newCond = nullptr;
810 if (cond)
811 {
812 GUARD(traverseAnyAs(*cond, newCond));
813 }
814 TIntermTyped *newExpr = nullptr;
815 if (expr)
816 {
817 GUARD(traverseAnyAs(*expr, newExpr));
818 }
819
820 if (newInit != init || newCond != cond || newExpr != expr || newBody != body)
821 {
822 switch (loopType)
823 {
824 case TLoopType::ELoopFor:
825 GUARD(newBody);
826 break;
827 case TLoopType::ELoopWhile:
828 case TLoopType::ELoopDoWhile:
829 GUARD(newCond && newBody);
830 GUARD(!newInit && !newExpr);
831 break;
832 default:
833 ASSERT(false);
834 break;
835 }
836 return new TIntermLoop(loopType, newInit, newCond, newExpr, newBody);
837 }
838
839 return &node;
840 }
841
traverseBranchChildren(TIntermBranch & node)842 TIntermNode *TIntermRebuild::traverseBranchChildren(TIntermBranch &node)
843 {
844 auto *const expr = node.getExpression();
845
846 TIntermTyped *newExpr = nullptr;
847 if (expr)
848 {
849 GUARD(traverseAnyAs<TIntermTyped>(*expr, newExpr));
850 }
851
852 if (newExpr != expr)
853 {
854 return new TIntermBranch(node.getFlowOp(), newExpr);
855 }
856
857 return &node;
858 }
859
860 ////////////////////////////////////////////////////////////////////////////////
861
visitSymbolPre(TIntermSymbol & node)862 PreResult TIntermRebuild::visitSymbolPre(TIntermSymbol &node)
863 {
864 return {node, VisitBits::Both};
865 }
866
visitConstantUnionPre(TIntermConstantUnion & node)867 PreResult TIntermRebuild::visitConstantUnionPre(TIntermConstantUnion &node)
868 {
869 return {node, VisitBits::Both};
870 }
871
visitFunctionPrototypePre(TIntermFunctionPrototype & node)872 PreResult TIntermRebuild::visitFunctionPrototypePre(TIntermFunctionPrototype &node)
873 {
874 return {node, VisitBits::Both};
875 }
876
visitPreprocessorDirectivePre(TIntermPreprocessorDirective & node)877 PreResult TIntermRebuild::visitPreprocessorDirectivePre(TIntermPreprocessorDirective &node)
878 {
879 return {node, VisitBits::Both};
880 }
881
visitUnaryPre(TIntermUnary & node)882 PreResult TIntermRebuild::visitUnaryPre(TIntermUnary &node)
883 {
884 return {node, VisitBits::Both};
885 }
886
visitBinaryPre(TIntermBinary & node)887 PreResult TIntermRebuild::visitBinaryPre(TIntermBinary &node)
888 {
889 return {node, VisitBits::Both};
890 }
891
visitTernaryPre(TIntermTernary & node)892 PreResult TIntermRebuild::visitTernaryPre(TIntermTernary &node)
893 {
894 return {node, VisitBits::Both};
895 }
896
visitSwizzlePre(TIntermSwizzle & node)897 PreResult TIntermRebuild::visitSwizzlePre(TIntermSwizzle &node)
898 {
899 return {node, VisitBits::Both};
900 }
901
visitIfElsePre(TIntermIfElse & node)902 PreResult TIntermRebuild::visitIfElsePre(TIntermIfElse &node)
903 {
904 return {node, VisitBits::Both};
905 }
906
visitSwitchPre(TIntermSwitch & node)907 PreResult TIntermRebuild::visitSwitchPre(TIntermSwitch &node)
908 {
909 return {node, VisitBits::Both};
910 }
911
visitCasePre(TIntermCase & node)912 PreResult TIntermRebuild::visitCasePre(TIntermCase &node)
913 {
914 return {node, VisitBits::Both};
915 }
916
visitLoopPre(TIntermLoop & node)917 PreResult TIntermRebuild::visitLoopPre(TIntermLoop &node)
918 {
919 return {node, VisitBits::Both};
920 }
921
visitBranchPre(TIntermBranch & node)922 PreResult TIntermRebuild::visitBranchPre(TIntermBranch &node)
923 {
924 return {node, VisitBits::Both};
925 }
926
visitDeclarationPre(TIntermDeclaration & node)927 PreResult TIntermRebuild::visitDeclarationPre(TIntermDeclaration &node)
928 {
929 return {node, VisitBits::Both};
930 }
931
visitBlockPre(TIntermBlock & node)932 PreResult TIntermRebuild::visitBlockPre(TIntermBlock &node)
933 {
934 return {node, VisitBits::Both};
935 }
936
visitAggregatePre(TIntermAggregate & node)937 PreResult TIntermRebuild::visitAggregatePre(TIntermAggregate &node)
938 {
939 return {node, VisitBits::Both};
940 }
941
visitFunctionDefinitionPre(TIntermFunctionDefinition & node)942 PreResult TIntermRebuild::visitFunctionDefinitionPre(TIntermFunctionDefinition &node)
943 {
944 return {node, VisitBits::Both};
945 }
946
visitGlobalQualifierDeclarationPre(TIntermGlobalQualifierDeclaration & node)947 PreResult TIntermRebuild::visitGlobalQualifierDeclarationPre(
948 TIntermGlobalQualifierDeclaration &node)
949 {
950 return {node, VisitBits::Both};
951 }
952
953 ////////////////////////////////////////////////////////////////////////////////
954
visitSymbolPost(TIntermSymbol & node)955 PostResult TIntermRebuild::visitSymbolPost(TIntermSymbol &node)
956 {
957 return node;
958 }
959
visitConstantUnionPost(TIntermConstantUnion & node)960 PostResult TIntermRebuild::visitConstantUnionPost(TIntermConstantUnion &node)
961 {
962 return node;
963 }
964
visitFunctionPrototypePost(TIntermFunctionPrototype & node)965 PostResult TIntermRebuild::visitFunctionPrototypePost(TIntermFunctionPrototype &node)
966 {
967 return node;
968 }
969
visitPreprocessorDirectivePost(TIntermPreprocessorDirective & node)970 PostResult TIntermRebuild::visitPreprocessorDirectivePost(TIntermPreprocessorDirective &node)
971 {
972 return node;
973 }
974
visitUnaryPost(TIntermUnary & node)975 PostResult TIntermRebuild::visitUnaryPost(TIntermUnary &node)
976 {
977 return node;
978 }
979
visitBinaryPost(TIntermBinary & node)980 PostResult TIntermRebuild::visitBinaryPost(TIntermBinary &node)
981 {
982 return node;
983 }
984
visitTernaryPost(TIntermTernary & node)985 PostResult TIntermRebuild::visitTernaryPost(TIntermTernary &node)
986 {
987 return node;
988 }
989
visitSwizzlePost(TIntermSwizzle & node)990 PostResult TIntermRebuild::visitSwizzlePost(TIntermSwizzle &node)
991 {
992 return node;
993 }
994
visitIfElsePost(TIntermIfElse & node)995 PostResult TIntermRebuild::visitIfElsePost(TIntermIfElse &node)
996 {
997 return node;
998 }
999
visitSwitchPost(TIntermSwitch & node)1000 PostResult TIntermRebuild::visitSwitchPost(TIntermSwitch &node)
1001 {
1002 return node;
1003 }
1004
visitCasePost(TIntermCase & node)1005 PostResult TIntermRebuild::visitCasePost(TIntermCase &node)
1006 {
1007 return node;
1008 }
1009
visitLoopPost(TIntermLoop & node)1010 PostResult TIntermRebuild::visitLoopPost(TIntermLoop &node)
1011 {
1012 return node;
1013 }
1014
visitBranchPost(TIntermBranch & node)1015 PostResult TIntermRebuild::visitBranchPost(TIntermBranch &node)
1016 {
1017 return node;
1018 }
1019
visitDeclarationPost(TIntermDeclaration & node)1020 PostResult TIntermRebuild::visitDeclarationPost(TIntermDeclaration &node)
1021 {
1022 return node;
1023 }
1024
visitBlockPost(TIntermBlock & node)1025 PostResult TIntermRebuild::visitBlockPost(TIntermBlock &node)
1026 {
1027 return node;
1028 }
1029
visitAggregatePost(TIntermAggregate & node)1030 PostResult TIntermRebuild::visitAggregatePost(TIntermAggregate &node)
1031 {
1032 return node;
1033 }
1034
visitFunctionDefinitionPost(TIntermFunctionDefinition & node)1035 PostResult TIntermRebuild::visitFunctionDefinitionPost(TIntermFunctionDefinition &node)
1036 {
1037 return node;
1038 }
1039
visitGlobalQualifierDeclarationPost(TIntermGlobalQualifierDeclaration & node)1040 PostResult TIntermRebuild::visitGlobalQualifierDeclarationPost(
1041 TIntermGlobalQualifierDeclaration &node)
1042 {
1043 return node;
1044 }
1045
1046 } // namespace sh
1047