1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include <algorithm>
8
9 #include "compiler/translator/Compiler.h"
10 #include "compiler/translator/SymbolTable.h"
11 #include "compiler/translator/TranslatorMetalDirect/IntermRebuild.h"
12 #include "compiler/translator/tree_util/AsNode.h"
13
// GUARD2(cond, failVal): when |cond| converts to false, immediately return |failVal| from the
// enclosing function. The do/while(false) wrapper makes the macro behave as one statement.
#define GUARD2(cond, failVal)         \
    do                                \
    {                                 \
        if (!static_cast<bool>(cond)) \
        {                             \
            return failVal;           \
        }                             \
    } while (false)

// GUARD(cond): GUARD2 for functions that report failure by returning nullptr.
#define GUARD(cond) GUARD2(cond, nullptr)
24
25 namespace sh
26 {
27
28 template <typename T, typename U>
AllBits(T haystack,U needle)29 ANGLE_INLINE bool AllBits(T haystack, U needle)
30 {
31 return (haystack & needle) == needle;
32 }
33
34 template <typename T, typename U>
AnyBits(T haystack,U needle)35 ANGLE_INLINE bool AnyBits(T haystack, U needle)
36 {
37 return (haystack & needle) != 0;
38 }
39
40 ////////////////////////////////////////////////////////////////////////////////
41
BaseResult(BaseResult & other)42 TIntermRebuild::BaseResult::BaseResult(BaseResult &other)
43 : mAction(other.mAction),
44 mVisit(other.mVisit),
45 mSingle(other.mSingle),
46 mMulti(std::move(other.mMulti))
47 {}
48
BaseResult(TIntermNode & node,VisitBits visit)49 TIntermRebuild::BaseResult::BaseResult(TIntermNode &node, VisitBits visit)
50 : mAction(Action::ReplaceSingle), mVisit(visit), mSingle(&node)
51 {}
52
BaseResult(TIntermNode * node,VisitBits visit)53 TIntermRebuild::BaseResult::BaseResult(TIntermNode *node, VisitBits visit)
54 : mAction(node ? Action::ReplaceSingle : Action::Drop),
55 mVisit(node ? visit : VisitBits::Neither),
56 mSingle(node)
57 {}
58
BaseResult(nullptr_t)59 TIntermRebuild::BaseResult::BaseResult(nullptr_t)
60 : mAction(Action::Drop), mVisit(VisitBits::Neither), mSingle(nullptr)
61 {}
62
BaseResult(Fail)63 TIntermRebuild::BaseResult::BaseResult(Fail)
64 : mAction(Action::Fail), mVisit(VisitBits::Neither), mSingle(nullptr)
65 {}
66
BaseResult(std::vector<TIntermNode * > && nodes)67 TIntermRebuild::BaseResult::BaseResult(std::vector<TIntermNode *> &&nodes)
68 : mAction(Action::ReplaceMulti),
69 mVisit(VisitBits::Neither),
70 mSingle(nullptr),
71 mMulti(std::move(nodes))
72 {}
73
moveAssignImpl(BaseResult & other)74 void TIntermRebuild::BaseResult::moveAssignImpl(BaseResult &other)
75 {
76 mAction = other.mAction;
77 mVisit = other.mVisit;
78 mSingle = other.mSingle;
79 mMulti = std::move(other.mMulti);
80 }
81
Multi(std::vector<TIntermNode * > && nodes)82 TIntermRebuild::BaseResult TIntermRebuild::BaseResult::Multi(std::vector<TIntermNode *> &&nodes)
83 {
84 auto it = std::remove(nodes.begin(), nodes.end(), nullptr);
85 nodes.erase(it, nodes.end());
86 return std::move(nodes);
87 }
88
isFail() const89 bool TIntermRebuild::BaseResult::isFail() const
90 {
91 return mAction == Action::Fail;
92 }
93
isDrop() const94 bool TIntermRebuild::BaseResult::isDrop() const
95 {
96 return mAction == Action::Drop;
97 }
98
single() const99 TIntermNode *TIntermRebuild::BaseResult::single() const
100 {
101 return mSingle;
102 }
103
multi() const104 const std::vector<TIntermNode *> *TIntermRebuild::BaseResult::multi() const
105 {
106 if (mAction == Action::ReplaceMulti)
107 {
108 return &mMulti;
109 }
110 return nullptr;
111 }
112
113 ////////////////////////////////////////////////////////////////////////////////
114
115 using PreResult = TIntermRebuild::PreResult;
116
PreResult(TIntermNode & node,VisitBits visit)117 PreResult::PreResult(TIntermNode &node, VisitBits visit) : BaseResult(node, visit) {}
PreResult(TIntermNode * node,VisitBits visit)118 PreResult::PreResult(TIntermNode *node, VisitBits visit) : BaseResult(node, visit) {}
PreResult(nullptr_t)119 PreResult::PreResult(nullptr_t) : BaseResult(nullptr) {}
PreResult(Fail)120 PreResult::PreResult(Fail) : BaseResult(Fail()) {}
121
PreResult(BaseResult && other)122 PreResult::PreResult(BaseResult &&other) : BaseResult(other) {}
PreResult(PreResult && other)123 PreResult::PreResult(PreResult &&other) : BaseResult(other) {}
124
operator =(PreResult && other)125 void PreResult::operator=(PreResult &&other)
126 {
127 moveAssignImpl(other);
128 }
129
130 ////////////////////////////////////////////////////////////////////////////////
131
132 using PostResult = TIntermRebuild::PostResult;
133
PostResult(TIntermNode & node)134 PostResult::PostResult(TIntermNode &node) : BaseResult(node, VisitBits::Neither) {}
PostResult(TIntermNode * node)135 PostResult::PostResult(TIntermNode *node) : BaseResult(node, VisitBits::Neither) {}
PostResult(nullptr_t)136 PostResult::PostResult(nullptr_t) : BaseResult(nullptr) {}
PostResult(Fail)137 PostResult::PostResult(Fail) : BaseResult(Fail()) {}
138
PostResult(PostResult && other)139 PostResult::PostResult(PostResult &&other) : BaseResult(other) {}
PostResult(BaseResult && other)140 PostResult::PostResult(BaseResult &&other) : BaseResult(other) {}
141
operator =(PostResult && other)142 void PostResult::operator=(PostResult &&other)
143 {
144 moveAssignImpl(other);
145 }
146
147 ////////////////////////////////////////////////////////////////////////////////
148
TIntermRebuild(TCompiler & compiler,bool preVisit,bool postVisit)149 TIntermRebuild::TIntermRebuild(TCompiler &compiler, bool preVisit, bool postVisit)
150 : mCompiler(compiler),
151 mSymbolTable(compiler.getSymbolTable()),
152 mPreVisit(preVisit),
153 mPostVisit(postVisit)
154 {
155 ASSERT(preVisit || postVisit);
156 }
157
~TIntermRebuild()158 TIntermRebuild::~TIntermRebuild()
159 {
160 ASSERT(!mNodeStack.value);
161 ASSERT(!mNodeStack.tail);
162 }
163
getParentFunction() const164 const TFunction *TIntermRebuild::getParentFunction() const
165 {
166 return mParentFunc;
167 }
168
getParentNode(size_t offset) const169 TIntermNode *TIntermRebuild::getParentNode(size_t offset) const
170 {
171 ASSERT(mNodeStack.tail);
172 auto parent = *mNodeStack.tail;
173 while (offset > 0)
174 {
175 --offset;
176 ASSERT(parent.tail);
177 parent = *parent.tail;
178 }
179 return parent.value;
180 }
181
rebuildRoot(TIntermBlock & root)182 bool TIntermRebuild::rebuildRoot(TIntermBlock &root)
183 {
184 if (!rebuildInPlace(root))
185 {
186 return false;
187 }
188 return mCompiler.validateAST(&root);
189 }
190
rebuildInPlace(TIntermAggregate & node)191 bool TIntermRebuild::rebuildInPlace(TIntermAggregate &node)
192 {
193 return rebuildInPlaceImpl(node);
194 }
195
rebuildInPlace(TIntermBlock & node)196 bool TIntermRebuild::rebuildInPlace(TIntermBlock &node)
197 {
198 return rebuildInPlaceImpl(node);
199 }
200
rebuildInPlace(TIntermDeclaration & node)201 bool TIntermRebuild::rebuildInPlace(TIntermDeclaration &node)
202 {
203 return rebuildInPlaceImpl(node);
204 }
205
206 template <typename Node>
rebuildInPlaceImpl(Node & node)207 bool TIntermRebuild::rebuildInPlaceImpl(Node &node)
208 {
209 auto *newNode = traverseAnyAs<Node>(node);
210 if (!newNode)
211 {
212 return false;
213 }
214
215 if (newNode != &node)
216 {
217 *node.getSequence() = std::move(*newNode->getSequence());
218 }
219
220 return true;
221 }
222
rebuild(TIntermNode & node)223 PostResult TIntermRebuild::rebuild(TIntermNode &node)
224 {
225 return traverseAny(node);
226 }
227
228 ////////////////////////////////////////////////////////////////////////////////
229
230 template <typename Node>
traverseAnyAs(TIntermNode & node)231 Node *TIntermRebuild::traverseAnyAs(TIntermNode &node)
232 {
233 PostResult result(traverseAny(node));
234 if (result.mAction == Action::Fail || !result.mSingle)
235 {
236 return nullptr;
237 }
238 return asNode<Node>(result.mSingle);
239 }
240
241 template <typename Node>
traverseAnyAs(TIntermNode & node,Node * & out)242 bool TIntermRebuild::traverseAnyAs(TIntermNode &node, Node *&out)
243 {
244 PostResult result(traverseAny(node));
245 if (result.mAction == Action::Fail || result.mAction == Action::ReplaceMulti)
246 {
247 return false;
248 }
249 if (!result.mSingle)
250 {
251 return true;
252 }
253 out = asNode<Node>(result.mSingle);
254 return out;
255 }
256
traverseAggregateBaseChildren(TIntermAggregateBase & node)257 bool TIntermRebuild::traverseAggregateBaseChildren(TIntermAggregateBase &node)
258 {
259 auto *const children = node.getSequence();
260 ASSERT(children);
261 TIntermSequence newChildren;
262
263 for (TIntermNode *child : *children)
264 {
265 ASSERT(child);
266 PostResult result(traverseAny(*child));
267
268 switch (result.mAction)
269 {
270 case Action::ReplaceSingle:
271 newChildren.push_back(result.mSingle);
272 break;
273
274 case Action::ReplaceMulti:
275 for (TIntermNode *newNode : result.mMulti)
276 {
277 if (newNode)
278 {
279 newChildren.push_back(newNode);
280 }
281 }
282 break;
283
284 case Action::Drop:
285 break;
286
287 case Action::Fail:
288 return false;
289 }
290 }
291
292 *children = std::move(newChildren);
293
294 return true;
295 }
296
297 ////////////////////////////////////////////////////////////////////////////////
298
299 struct TIntermRebuild::NodeStackGuard
300 {
301 ConsList<TIntermNode *> oldNodeStack;
302 ConsList<TIntermNode *> &nodeStack;
NodeStackGuardsh::TIntermRebuild::NodeStackGuard303 NodeStackGuard(ConsList<TIntermNode *> &nodeStack)
304 : oldNodeStack(nodeStack), nodeStack(nodeStack)
305 {}
~NodeStackGuardsh::TIntermRebuild::NodeStackGuard306 ~NodeStackGuard() { nodeStack = oldNodeStack; }
307 };
308
traverseAny(TIntermNode & originalNode)309 PostResult TIntermRebuild::traverseAny(TIntermNode &originalNode)
310 {
311 PreResult preResult = traversePre(originalNode);
312 if (!preResult.mSingle)
313 {
314 ASSERT(preResult.mVisit == VisitBits::Neither);
315 return std::move(preResult);
316 }
317
318 TIntermNode *currNode = preResult.mSingle;
319 const VisitBits visit = preResult.mVisit;
320 const NodeType currNodeType = getNodeType(*currNode);
321
322 currNode = traverseChildren(currNodeType, originalNode, *currNode, visit);
323 if (!currNode)
324 {
325 return Fail();
326 }
327
328 return traversePost(currNodeType, originalNode, *currNode, visit);
329 }
330
traversePre(TIntermNode & originalNode)331 PreResult TIntermRebuild::traversePre(TIntermNode &originalNode)
332 {
333 if (!mPreVisit)
334 {
335 return {originalNode, VisitBits::Both};
336 }
337
338 NodeStackGuard guard(mNodeStack);
339 mNodeStack = {&originalNode, &guard.oldNodeStack};
340
341 const NodeType originalNodeType = getNodeType(originalNode);
342
343 switch (originalNodeType)
344 {
345 case NodeType::Unknown:
346 ASSERT(false);
347 return Fail();
348 case NodeType::Symbol:
349 return visitSymbolPre(*originalNode.getAsSymbolNode());
350 case NodeType::ConstantUnion:
351 return visitConstantUnionPre(*originalNode.getAsConstantUnion());
352 case NodeType::FunctionPrototype:
353 return visitFunctionPrototypePre(*originalNode.getAsFunctionPrototypeNode());
354 case NodeType::PreprocessorDirective:
355 return visitPreprocessorDirectivePre(*originalNode.getAsPreprocessorDirective());
356 case NodeType::Unary:
357 return visitUnaryPre(*originalNode.getAsUnaryNode());
358 case NodeType::Binary:
359 return visitBinaryPre(*originalNode.getAsBinaryNode());
360 case NodeType::Ternary:
361 return visitTernaryPre(*originalNode.getAsTernaryNode());
362 case NodeType::Swizzle:
363 return visitSwizzlePre(*originalNode.getAsSwizzleNode());
364 case NodeType::IfElse:
365 return visitIfElsePre(*originalNode.getAsIfElseNode());
366 case NodeType::Switch:
367 return visitSwitchPre(*originalNode.getAsSwitchNode());
368 case NodeType::Case:
369 return visitCasePre(*originalNode.getAsCaseNode());
370 case NodeType::FunctionDefinition:
371 return visitFunctionDefinitionPre(*originalNode.getAsFunctionDefinition());
372 case NodeType::Aggregate:
373 return visitAggregatePre(*originalNode.getAsAggregate());
374 case NodeType::Block:
375 return visitBlockPre(*originalNode.getAsBlock());
376 case NodeType::GlobalQualifierDeclaration:
377 return visitGlobalQualifierDeclarationPre(
378 *originalNode.getAsGlobalQualifierDeclarationNode());
379 case NodeType::Declaration:
380 return visitDeclarationPre(*originalNode.getAsDeclarationNode());
381 case NodeType::Loop:
382 return visitLoopPre(*originalNode.getAsLoopNode());
383 case NodeType::Branch:
384 return visitBranchPre(*originalNode.getAsBranchNode());
385 }
386 }
387
traverseChildren(NodeType currNodeType,const TIntermNode & originalNode,TIntermNode & currNode,VisitBits visit)388 TIntermNode *TIntermRebuild::traverseChildren(NodeType currNodeType,
389 const TIntermNode &originalNode,
390 TIntermNode &currNode,
391 VisitBits visit)
392 {
393 if (!AnyBits(visit, VisitBits::Children))
394 {
395 return &currNode;
396 }
397
398 if (AnyBits(visit, VisitBits::ChildrenRequiresSame) && &originalNode != &currNode)
399 {
400 return &currNode;
401 }
402
403 NodeStackGuard guard(mNodeStack);
404 mNodeStack = {&currNode, &guard.oldNodeStack};
405
406 switch (currNodeType)
407 {
408 case NodeType::Unknown:
409 ASSERT(false);
410 return nullptr;
411 case NodeType::Symbol:
412 return &currNode;
413 case NodeType::ConstantUnion:
414 return &currNode;
415 case NodeType::FunctionPrototype:
416 return &currNode;
417 case NodeType::PreprocessorDirective:
418 return &currNode;
419 case NodeType::Unary:
420 return traverseUnaryChildren(*currNode.getAsUnaryNode());
421 case NodeType::Binary:
422 return traverseBinaryChildren(*currNode.getAsBinaryNode());
423 case NodeType::Ternary:
424 return traverseTernaryChildren(*currNode.getAsTernaryNode());
425 case NodeType::Swizzle:
426 return traverseSwizzleChildren(*currNode.getAsSwizzleNode());
427 case NodeType::IfElse:
428 return traverseIfElseChildren(*currNode.getAsIfElseNode());
429 case NodeType::Switch:
430 return traverseSwitchChildren(*currNode.getAsSwitchNode());
431 case NodeType::Case:
432 return traverseCaseChildren(*currNode.getAsCaseNode());
433 case NodeType::FunctionDefinition:
434 return traverseFunctionDefinitionChildren(*currNode.getAsFunctionDefinition());
435 case NodeType::Aggregate:
436 return traverseAggregateChildren(*currNode.getAsAggregate());
437 case NodeType::Block:
438 return traverseBlockChildren(*currNode.getAsBlock());
439 case NodeType::GlobalQualifierDeclaration:
440 return traverseGlobalQualifierDeclarationChildren(
441 *currNode.getAsGlobalQualifierDeclarationNode());
442 case NodeType::Declaration:
443 return traverseDeclarationChildren(*currNode.getAsDeclarationNode());
444 case NodeType::Loop:
445 return traverseLoopChildren(*currNode.getAsLoopNode());
446 case NodeType::Branch:
447 return traverseBranchChildren(*currNode.getAsBranchNode());
448 }
449 }
450
traversePost(NodeType currNodeType,const TIntermNode & originalNode,TIntermNode & currNode,VisitBits visit)451 PostResult TIntermRebuild::traversePost(NodeType currNodeType,
452 const TIntermNode &originalNode,
453 TIntermNode &currNode,
454 VisitBits visit)
455 {
456 if (!mPostVisit)
457 {
458 return currNode;
459 }
460
461 if (!AnyBits(visit, VisitBits::Post))
462 {
463 return currNode;
464 }
465
466 if (AnyBits(visit, VisitBits::PostRequiresSame) && &originalNode != &currNode)
467 {
468 return currNode;
469 }
470
471 NodeStackGuard guard(mNodeStack);
472 mNodeStack = {&currNode, &guard.oldNodeStack};
473
474 switch (currNodeType)
475 {
476 case NodeType::Unknown:
477 ASSERT(false);
478 return Fail();
479 case NodeType::Symbol:
480 return visitSymbolPost(*currNode.getAsSymbolNode());
481 case NodeType::ConstantUnion:
482 return visitConstantUnionPost(*currNode.getAsConstantUnion());
483 case NodeType::FunctionPrototype:
484 return visitFunctionPrototypePost(*currNode.getAsFunctionPrototypeNode());
485 case NodeType::PreprocessorDirective:
486 return visitPreprocessorDirectivePost(*currNode.getAsPreprocessorDirective());
487 case NodeType::Unary:
488 return visitUnaryPost(*currNode.getAsUnaryNode());
489 case NodeType::Binary:
490 return visitBinaryPost(*currNode.getAsBinaryNode());
491 case NodeType::Ternary:
492 return visitTernaryPost(*currNode.getAsTernaryNode());
493 case NodeType::Swizzle:
494 return visitSwizzlePost(*currNode.getAsSwizzleNode());
495 case NodeType::IfElse:
496 return visitIfElsePost(*currNode.getAsIfElseNode());
497 case NodeType::Switch:
498 return visitSwitchPost(*currNode.getAsSwitchNode());
499 case NodeType::Case:
500 return visitCasePost(*currNode.getAsCaseNode());
501 case NodeType::FunctionDefinition:
502 return visitFunctionDefinitionPost(*currNode.getAsFunctionDefinition());
503 case NodeType::Aggregate:
504 return visitAggregatePost(*currNode.getAsAggregate());
505 case NodeType::Block:
506 return visitBlockPost(*currNode.getAsBlock());
507 case NodeType::GlobalQualifierDeclaration:
508 return visitGlobalQualifierDeclarationPost(
509 *currNode.getAsGlobalQualifierDeclarationNode());
510 case NodeType::Declaration:
511 return visitDeclarationPost(*currNode.getAsDeclarationNode());
512 case NodeType::Loop:
513 return visitLoopPost(*currNode.getAsLoopNode());
514 case NodeType::Branch:
515 return visitBranchPost(*currNode.getAsBranchNode());
516 }
517 }
518
519 ////////////////////////////////////////////////////////////////////////////////
520
traverseAggregateChildren(TIntermAggregate & node)521 TIntermNode *TIntermRebuild::traverseAggregateChildren(TIntermAggregate &node)
522 {
523 if (traverseAggregateBaseChildren(node))
524 {
525 return &node;
526 }
527 return nullptr;
528 }
529
traverseBlockChildren(TIntermBlock & node)530 TIntermNode *TIntermRebuild::traverseBlockChildren(TIntermBlock &node)
531 {
532 if (traverseAggregateBaseChildren(node))
533 {
534 return &node;
535 }
536 return nullptr;
537 }
538
traverseDeclarationChildren(TIntermDeclaration & node)539 TIntermNode *TIntermRebuild::traverseDeclarationChildren(TIntermDeclaration &node)
540 {
541 if (traverseAggregateBaseChildren(node))
542 {
543 return &node;
544 }
545 return nullptr;
546 }
547
traverseSwizzleChildren(TIntermSwizzle & node)548 TIntermNode *TIntermRebuild::traverseSwizzleChildren(TIntermSwizzle &node)
549 {
550 auto *const operand = node.getOperand();
551 ASSERT(operand);
552
553 auto *newOperand = traverseAnyAs<TIntermTyped>(*operand);
554 GUARD(newOperand);
555
556 if (newOperand != operand)
557 {
558 return new TIntermSwizzle(newOperand, node.getSwizzleOffsets());
559 }
560
561 return &node;
562 }
563
traverseBinaryChildren(TIntermBinary & node)564 TIntermNode *TIntermRebuild::traverseBinaryChildren(TIntermBinary &node)
565 {
566 auto *const left = node.getLeft();
567 ASSERT(left);
568 auto *const right = node.getRight();
569 ASSERT(right);
570
571 auto *const newLeft = traverseAnyAs<TIntermTyped>(*left);
572 GUARD(newLeft);
573 auto *const newRight = traverseAnyAs<TIntermTyped>(*right);
574 GUARD(newRight);
575
576 if (newLeft != left || newRight != right)
577 {
578 TOperator op = node.getOp();
579 switch (op)
580 {
581 case TOperator::EOpIndexDirectStruct:
582 {
583 if (newLeft->getType().getInterfaceBlock())
584 {
585 op = TOperator::EOpIndexDirectInterfaceBlock;
586 }
587 }
588 break;
589
590 case TOperator::EOpIndexDirectInterfaceBlock:
591 {
592 if (newLeft->getType().getStruct())
593 {
594 op = TOperator::EOpIndexDirectStruct;
595 }
596 }
597 break;
598
599 case TOperator::EOpComma:
600 return TIntermBinary::CreateComma(newLeft, newRight, mCompiler.getShaderVersion());
601
602 default:
603 break;
604 }
605
606 return new TIntermBinary(op, newLeft, newRight);
607 }
608
609 return &node;
610 }
611
traverseUnaryChildren(TIntermUnary & node)612 TIntermNode *TIntermRebuild::traverseUnaryChildren(TIntermUnary &node)
613 {
614 auto *const operand = node.getOperand();
615 ASSERT(operand);
616
617 auto *const newOperand = traverseAnyAs<TIntermTyped>(*operand);
618 GUARD(newOperand);
619
620 if (newOperand != operand)
621 {
622 return new TIntermUnary(node.getOp(), newOperand, node.getFunction());
623 }
624
625 return &node;
626 }
627
traverseTernaryChildren(TIntermTernary & node)628 TIntermNode *TIntermRebuild::traverseTernaryChildren(TIntermTernary &node)
629 {
630 auto *const cond = node.getCondition();
631 ASSERT(cond);
632 auto *const true_ = node.getTrueExpression();
633 ASSERT(true_);
634 auto *const false_ = node.getFalseExpression();
635 ASSERT(false_);
636
637 auto *const newCond = traverseAnyAs<TIntermTyped>(*cond);
638 GUARD(newCond);
639 auto *const newTrue = traverseAnyAs<TIntermTyped>(*true_);
640 GUARD(newTrue);
641 auto *const newFalse = traverseAnyAs<TIntermTyped>(*false_);
642 GUARD(newFalse);
643
644 if (newCond != cond || newTrue != true_ || newFalse != false_)
645 {
646 return new TIntermTernary(newCond, newTrue, newFalse);
647 }
648
649 return &node;
650 }
651
traverseIfElseChildren(TIntermIfElse & node)652 TIntermNode *TIntermRebuild::traverseIfElseChildren(TIntermIfElse &node)
653 {
654 auto *const cond = node.getCondition();
655 ASSERT(cond);
656 auto *const true_ = node.getTrueBlock();
657 auto *const false_ = node.getFalseBlock();
658
659 auto *const newCond = traverseAnyAs<TIntermTyped>(*cond);
660 GUARD(newCond);
661 TIntermBlock *newTrue = nullptr;
662 if (true_)
663 {
664 GUARD(traverseAnyAs(*true_, newTrue));
665 }
666 TIntermBlock *newFalse = nullptr;
667 if (false_)
668 {
669 GUARD(traverseAnyAs(*false_, newFalse));
670 }
671
672 if (newCond != cond || newTrue != true_ || newFalse != false_)
673 {
674 return new TIntermIfElse(newCond, newTrue, newFalse);
675 }
676
677 return &node;
678 }
679
traverseSwitchChildren(TIntermSwitch & node)680 TIntermNode *TIntermRebuild::traverseSwitchChildren(TIntermSwitch &node)
681 {
682 auto *const init = node.getInit();
683 ASSERT(init);
684 auto *const stmts = node.getStatementList();
685 ASSERT(stmts);
686
687 auto *const newInit = traverseAnyAs<TIntermTyped>(*init);
688 GUARD(newInit);
689 auto *const newStmts = traverseAnyAs<TIntermBlock>(*stmts);
690 GUARD(newStmts);
691
692 if (newInit != init || newStmts != stmts)
693 {
694 return new TIntermSwitch(newInit, newStmts);
695 }
696
697 return &node;
698 }
699
traverseCaseChildren(TIntermCase & node)700 TIntermNode *TIntermRebuild::traverseCaseChildren(TIntermCase &node)
701 {
702 auto *const cond = node.getCondition();
703
704 TIntermTyped *newCond = nullptr;
705 if (cond)
706 {
707 GUARD(traverseAnyAs(*cond, newCond));
708 }
709
710 if (newCond != cond)
711 {
712 return new TIntermCase(newCond);
713 }
714
715 return &node;
716 }
717
traverseFunctionDefinitionChildren(TIntermFunctionDefinition & node)718 TIntermNode *TIntermRebuild::traverseFunctionDefinitionChildren(TIntermFunctionDefinition &node)
719 {
720 GUARD(!mParentFunc); // Function definitions cannot be nested.
721 mParentFunc = node.getFunction();
722 struct OnExit
723 {
724 const TFunction *&parentFunc;
725 OnExit(const TFunction *&parentFunc) : parentFunc(parentFunc) {}
726 ~OnExit() { parentFunc = nullptr; }
727 } onExit(mParentFunc);
728
729 auto *const proto = node.getFunctionPrototype();
730 ASSERT(proto);
731 auto *const body = node.getBody();
732 ASSERT(body);
733
734 auto *const newProto = traverseAnyAs<TIntermFunctionPrototype>(*proto);
735 GUARD(newProto);
736 auto *const newBody = traverseAnyAs<TIntermBlock>(*body);
737 GUARD(newBody);
738
739 if (newProto != proto || newBody != body)
740 {
741 return new TIntermFunctionDefinition(newProto, newBody);
742 }
743
744 return &node;
745 }
746
traverseGlobalQualifierDeclarationChildren(TIntermGlobalQualifierDeclaration & node)747 TIntermNode *TIntermRebuild::traverseGlobalQualifierDeclarationChildren(
748 TIntermGlobalQualifierDeclaration &node)
749 {
750 auto *const symbol = node.getSymbol();
751 ASSERT(symbol);
752
753 auto *const newSymbol = traverseAnyAs<TIntermSymbol>(*symbol);
754 GUARD(newSymbol);
755
756 if (newSymbol != symbol)
757 {
758 return new TIntermGlobalQualifierDeclaration(newSymbol, node.isPrecise(), node.getLine());
759 }
760
761 return &node;
762 }
763
traverseLoopChildren(TIntermLoop & node)764 TIntermNode *TIntermRebuild::traverseLoopChildren(TIntermLoop &node)
765 {
766 const TLoopType loopType = node.getType();
767
768 auto *const init = node.getInit();
769 auto *const cond = node.getCondition();
770 auto *const expr = node.getExpression();
771 auto *const body = node.getBody();
772 ASSERT(body);
773
774 #if defined(ANGLE_ENABLE_ASSERTS)
775 switch (loopType)
776 {
777 case TLoopType::ELoopFor:
778 break;
779 case TLoopType::ELoopWhile:
780 case TLoopType::ELoopDoWhile:
781 ASSERT(cond);
782 ASSERT(!init && !expr);
783 break;
784 }
785 #endif
786
787 auto *const newBody = traverseAnyAs<TIntermBlock>(*body);
788 GUARD(newBody);
789 TIntermNode *newInit = nullptr;
790 if (init)
791 {
792 GUARD(traverseAnyAs(*init, newInit));
793 }
794 TIntermTyped *newCond = nullptr;
795 if (cond)
796 {
797 GUARD(traverseAnyAs(*cond, newCond));
798 }
799 TIntermTyped *newExpr = nullptr;
800 if (expr)
801 {
802 GUARD(traverseAnyAs(*expr, newExpr));
803 }
804
805 if (newInit != init || newCond != cond || newExpr != expr || newBody != body)
806 {
807 switch (loopType)
808 {
809 case TLoopType::ELoopFor:
810 GUARD(newBody);
811 break;
812 case TLoopType::ELoopWhile:
813 case TLoopType::ELoopDoWhile:
814 GUARD(newCond && newBody);
815 GUARD(!newInit && !newExpr);
816 break;
817 }
818 return new TIntermLoop(loopType, newInit, newCond, newExpr, newBody);
819 }
820
821 return &node;
822 }
823
traverseBranchChildren(TIntermBranch & node)824 TIntermNode *TIntermRebuild::traverseBranchChildren(TIntermBranch &node)
825 {
826 auto *const expr = node.getExpression();
827
828 TIntermTyped *newExpr = nullptr;
829 if (expr)
830 {
831 GUARD(traverseAnyAs<TIntermTyped>(*expr, newExpr));
832 }
833
834 if (newExpr != expr)
835 {
836 return new TIntermBranch(node.getFlowOp(), newExpr);
837 }
838
839 return &node;
840 }
841
842 ////////////////////////////////////////////////////////////////////////////////
843
visitSymbolPre(TIntermSymbol & node)844 PreResult TIntermRebuild::visitSymbolPre(TIntermSymbol &node)
845 {
846 return {node, VisitBits::Both};
847 }
848
visitConstantUnionPre(TIntermConstantUnion & node)849 PreResult TIntermRebuild::visitConstantUnionPre(TIntermConstantUnion &node)
850 {
851 return {node, VisitBits::Both};
852 }
853
visitFunctionPrototypePre(TIntermFunctionPrototype & node)854 PreResult TIntermRebuild::visitFunctionPrototypePre(TIntermFunctionPrototype &node)
855 {
856 return {node, VisitBits::Both};
857 }
858
visitPreprocessorDirectivePre(TIntermPreprocessorDirective & node)859 PreResult TIntermRebuild::visitPreprocessorDirectivePre(TIntermPreprocessorDirective &node)
860 {
861 return {node, VisitBits::Both};
862 }
863
visitUnaryPre(TIntermUnary & node)864 PreResult TIntermRebuild::visitUnaryPre(TIntermUnary &node)
865 {
866 return {node, VisitBits::Both};
867 }
868
visitBinaryPre(TIntermBinary & node)869 PreResult TIntermRebuild::visitBinaryPre(TIntermBinary &node)
870 {
871 return {node, VisitBits::Both};
872 }
873
visitTernaryPre(TIntermTernary & node)874 PreResult TIntermRebuild::visitTernaryPre(TIntermTernary &node)
875 {
876 return {node, VisitBits::Both};
877 }
878
visitSwizzlePre(TIntermSwizzle & node)879 PreResult TIntermRebuild::visitSwizzlePre(TIntermSwizzle &node)
880 {
881 return {node, VisitBits::Both};
882 }
883
visitIfElsePre(TIntermIfElse & node)884 PreResult TIntermRebuild::visitIfElsePre(TIntermIfElse &node)
885 {
886 return {node, VisitBits::Both};
887 }
888
visitSwitchPre(TIntermSwitch & node)889 PreResult TIntermRebuild::visitSwitchPre(TIntermSwitch &node)
890 {
891 return {node, VisitBits::Both};
892 }
893
visitCasePre(TIntermCase & node)894 PreResult TIntermRebuild::visitCasePre(TIntermCase &node)
895 {
896 return {node, VisitBits::Both};
897 }
898
visitLoopPre(TIntermLoop & node)899 PreResult TIntermRebuild::visitLoopPre(TIntermLoop &node)
900 {
901 return {node, VisitBits::Both};
902 }
903
visitBranchPre(TIntermBranch & node)904 PreResult TIntermRebuild::visitBranchPre(TIntermBranch &node)
905 {
906 return {node, VisitBits::Both};
907 }
908
visitDeclarationPre(TIntermDeclaration & node)909 PreResult TIntermRebuild::visitDeclarationPre(TIntermDeclaration &node)
910 {
911 return {node, VisitBits::Both};
912 }
913
visitBlockPre(TIntermBlock & node)914 PreResult TIntermRebuild::visitBlockPre(TIntermBlock &node)
915 {
916 return {node, VisitBits::Both};
917 }
918
visitAggregatePre(TIntermAggregate & node)919 PreResult TIntermRebuild::visitAggregatePre(TIntermAggregate &node)
920 {
921 return {node, VisitBits::Both};
922 }
923
visitFunctionDefinitionPre(TIntermFunctionDefinition & node)924 PreResult TIntermRebuild::visitFunctionDefinitionPre(TIntermFunctionDefinition &node)
925 {
926 return {node, VisitBits::Both};
927 }
928
visitGlobalQualifierDeclarationPre(TIntermGlobalQualifierDeclaration & node)929 PreResult TIntermRebuild::visitGlobalQualifierDeclarationPre(
930 TIntermGlobalQualifierDeclaration &node)
931 {
932 return {node, VisitBits::Both};
933 }
934
935 ////////////////////////////////////////////////////////////////////////////////
936
visitSymbolPost(TIntermSymbol & node)937 PostResult TIntermRebuild::visitSymbolPost(TIntermSymbol &node)
938 {
939 return node;
940 }
941
visitConstantUnionPost(TIntermConstantUnion & node)942 PostResult TIntermRebuild::visitConstantUnionPost(TIntermConstantUnion &node)
943 {
944 return node;
945 }
946
visitFunctionPrototypePost(TIntermFunctionPrototype & node)947 PostResult TIntermRebuild::visitFunctionPrototypePost(TIntermFunctionPrototype &node)
948 {
949 return node;
950 }
951
visitPreprocessorDirectivePost(TIntermPreprocessorDirective & node)952 PostResult TIntermRebuild::visitPreprocessorDirectivePost(TIntermPreprocessorDirective &node)
953 {
954 return node;
955 }
956
visitUnaryPost(TIntermUnary & node)957 PostResult TIntermRebuild::visitUnaryPost(TIntermUnary &node)
958 {
959 return node;
960 }
961
visitBinaryPost(TIntermBinary & node)962 PostResult TIntermRebuild::visitBinaryPost(TIntermBinary &node)
963 {
964 return node;
965 }
966
visitTernaryPost(TIntermTernary & node)967 PostResult TIntermRebuild::visitTernaryPost(TIntermTernary &node)
968 {
969 return node;
970 }
971
visitSwizzlePost(TIntermSwizzle & node)972 PostResult TIntermRebuild::visitSwizzlePost(TIntermSwizzle &node)
973 {
974 return node;
975 }
976
visitIfElsePost(TIntermIfElse & node)977 PostResult TIntermRebuild::visitIfElsePost(TIntermIfElse &node)
978 {
979 return node;
980 }
981
visitSwitchPost(TIntermSwitch & node)982 PostResult TIntermRebuild::visitSwitchPost(TIntermSwitch &node)
983 {
984 return node;
985 }
986
visitCasePost(TIntermCase & node)987 PostResult TIntermRebuild::visitCasePost(TIntermCase &node)
988 {
989 return node;
990 }
991
visitLoopPost(TIntermLoop & node)992 PostResult TIntermRebuild::visitLoopPost(TIntermLoop &node)
993 {
994 return node;
995 }
996
visitBranchPost(TIntermBranch & node)997 PostResult TIntermRebuild::visitBranchPost(TIntermBranch &node)
998 {
999 return node;
1000 }
1001
visitDeclarationPost(TIntermDeclaration & node)1002 PostResult TIntermRebuild::visitDeclarationPost(TIntermDeclaration &node)
1003 {
1004 return node;
1005 }
1006
visitBlockPost(TIntermBlock & node)1007 PostResult TIntermRebuild::visitBlockPost(TIntermBlock &node)
1008 {
1009 return node;
1010 }
1011
visitAggregatePost(TIntermAggregate & node)1012 PostResult TIntermRebuild::visitAggregatePost(TIntermAggregate &node)
1013 {
1014 return node;
1015 }
1016
visitFunctionDefinitionPost(TIntermFunctionDefinition & node)1017 PostResult TIntermRebuild::visitFunctionDefinitionPost(TIntermFunctionDefinition &node)
1018 {
1019 return node;
1020 }
1021
visitGlobalQualifierDeclarationPost(TIntermGlobalQualifierDeclaration & node)1022 PostResult TIntermRebuild::visitGlobalQualifierDeclarationPost(
1023 TIntermGlobalQualifierDeclaration &node)
1024 {
1025 return node;
1026 }
1027
1028 } // namespace sh
1029