• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "common/system_utils.h"
#include "compiler/translator/msl/AstHelpers.h"
#include "compiler/translator/msl/IntermRebuild.h"
#include "compiler/translator/tree_ops/SimplifyLoopConditions.h"
#include "compiler/translator/tree_ops/msl/SeparateCompoundExpressions.h"
14 
15 using namespace sh;
16 
17 ////////////////////////////////////////////////////////////////////////////////
18 
19 namespace
20 {
21 
IsIndex(TOperator op)22 bool IsIndex(TOperator op)
23 {
24     switch (op)
25     {
26         case TOperator::EOpIndexDirect:
27         case TOperator::EOpIndexDirectInterfaceBlock:
28         case TOperator::EOpIndexDirectStruct:
29         case TOperator::EOpIndexIndirect:
30             return true;
31         default:
32             return false;
33     }
34 }
35 
IsIndex(TIntermTyped & expr)36 bool IsIndex(TIntermTyped &expr)
37 {
38     if (auto *binary = expr.getAsBinaryNode())
39     {
40         return IsIndex(binary->getOp());
41     }
42     return expr.getAsSwizzleNode();
43 }
44 
IsCompoundAssignment(TOperator op)45 bool IsCompoundAssignment(TOperator op)
46 {
47     switch (op)
48     {
49         case EOpAddAssign:
50         case EOpSubAssign:
51         case EOpMulAssign:
52         case EOpVectorTimesMatrixAssign:
53         case EOpVectorTimesScalarAssign:
54         case EOpMatrixTimesScalarAssign:
55         case EOpMatrixTimesMatrixAssign:
56         case EOpDivAssign:
57         case EOpIModAssign:
58         case EOpBitShiftLeftAssign:
59         case EOpBitShiftRightAssign:
60         case EOpBitwiseAndAssign:
61         case EOpBitwiseXorAssign:
62         case EOpBitwiseOrAssign:
63             return true;
64         default:
65             return false;
66     }
67 }
68 
ViewBinaryChain(TOperator op,TIntermTyped & node,std::vector<TIntermTyped * > & out)69 bool ViewBinaryChain(TOperator op, TIntermTyped &node, std::vector<TIntermTyped *> &out)
70 {
71     TIntermBinary *binary = node.getAsBinaryNode();
72     if (!binary || binary->getOp() != op)
73     {
74         return false;
75     }
76 
77     TIntermTyped *left  = binary->getLeft();
78     TIntermTyped *right = binary->getRight();
79 
80     if (!ViewBinaryChain(op, *left, out))
81     {
82         out.push_back(left);
83     }
84 
85     if (!ViewBinaryChain(op, *right, out))
86     {
87         out.push_back(right);
88     }
89 
90     return true;
91 }
92 
ViewBinaryChain(TIntermBinary & node)93 std::vector<TIntermTyped *> ViewBinaryChain(TIntermBinary &node)
94 {
95     std::vector<TIntermTyped *> chain;
96     ViewBinaryChain(node.getOp(), node, chain);
97     ASSERT(chain.size() >= 2);
98     return chain;
99 }
100 
101 class PrePass : public TIntermRebuild
102 {
103   public:
PrePass(TCompiler & compiler)104     PrePass(TCompiler &compiler) : TIntermRebuild(compiler, true, true) {}
105 
106   private:
107     // Change chains of
108     //      x OP y OP z
109     // to
110     //      x OP (y OP z)
111     // regardless of original parenthesization.
reassociateRight(TIntermBinary & node)112     TIntermTyped &reassociateRight(TIntermBinary &node)
113     {
114         const TOperator op                = node.getOp();
115         std::vector<TIntermTyped *> chain = ViewBinaryChain(node);
116 
117         TIntermTyped *result = chain.back();
118         chain.pop_back();
119         ASSERT(result);
120 
121         const auto begin = chain.rbegin();
122         const auto end   = chain.rend();
123 
124         for (auto iter = begin; iter != end; ++iter)
125         {
126             TIntermTyped *part = *iter;
127             ASSERT(part);
128             TIntermNode *temp = rebuild(*part).single();
129             ASSERT(temp);
130             part = temp->getAsTyped();
131             ASSERT(part);
132             result = new TIntermBinary(op, part, result);
133         }
134         return *result;
135     }
136 
137   private:
visitBinaryPre(TIntermBinary & node)138     PreResult visitBinaryPre(TIntermBinary &node) override
139     {
140         const TOperator op = node.getOp();
141         if (op == TOperator::EOpLogicalAnd || op == TOperator::EOpLogicalOr)
142         {
143             return {reassociateRight(node), VisitBits::Neither};
144         }
145         return node;
146     }
147 };
148 
149 class Separator : public TIntermRebuild
150 {
151     IdGen &mIdGen;
152     std::vector<std::vector<TIntermNode *>> mStmtsStack;
153     std::vector<std::unordered_map<const TVariable *, TIntermDeclaration *>> mBindingMapStack;
154     std::unordered_map<TIntermTyped *, TIntermTyped *> mExprMap;
155     std::unordered_set<TIntermDeclaration *> mMaskedDecls;
156 
157   public:
Separator(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen)158     Separator(TCompiler &compiler, SymbolEnv &symbolEnv, IdGen &idGen)
159         : TIntermRebuild(compiler, true, true), mIdGen(idGen)
160     {}
161 
~Separator()162     ~Separator() override
163     {
164         ASSERT(mStmtsStack.empty());
165         ASSERT(mExprMap.empty());
166         ASSERT(mBindingMapStack.empty());
167     }
168 
169   private:
getCurrStmts()170     std::vector<TIntermNode *> &getCurrStmts()
171     {
172         ASSERT(!mStmtsStack.empty());
173         return mStmtsStack.back();
174     }
175 
getCurrBindingMap()176     std::unordered_map<const TVariable *, TIntermDeclaration *> &getCurrBindingMap()
177     {
178         ASSERT(!mBindingMapStack.empty());
179         return mBindingMapStack.back();
180     }
181 
pushStmt(TIntermNode & node)182     void pushStmt(TIntermNode &node) { getCurrStmts().push_back(&node); }
183 
isTerminalExpr(TIntermNode & node)184     bool isTerminalExpr(TIntermNode &node)
185     {
186         NodeType nodeType = getNodeType(node);
187         switch (nodeType)
188         {
189             case NodeType::Symbol:
190             case NodeType::ConstantUnion:
191                 return true;
192             default:
193                 return false;
194         }
195     }
196 
pullMappedExpr(TIntermTyped * node,bool allowBacktrack)197     TIntermTyped *pullMappedExpr(TIntermTyped *node, bool allowBacktrack)
198     {
199         TIntermTyped *expr;
200 
201         {
202             auto iter = mExprMap.find(node);
203             if (iter == mExprMap.end())
204             {
205                 return node;
206             }
207             ASSERT(node);
208             expr = iter->second;
209             ASSERT(expr);
210             mExprMap.erase(iter);
211         }
212 
213         if (allowBacktrack)
214         {
215             auto &bindingMap = getCurrBindingMap();
216             while (TIntermSymbol *symbol = expr->getAsSymbolNode())
217             {
218                 const TVariable &var = symbol->variable();
219                 auto iter            = bindingMap.find(&var);
220                 if (iter == bindingMap.end())
221                 {
222                     return expr;
223                 }
224                 ASSERT(var.symbolType() == SymbolType::AngleInternal);
225                 TIntermDeclaration *decl = iter->second;
226                 ASSERT(decl);
227                 expr = ViewDeclaration(*decl).initExpr;
228                 ASSERT(expr);
229                 bindingMap.erase(iter);
230                 mMaskedDecls.insert(decl);
231             }
232         }
233 
234         return expr;
235     }
236 
isStandaloneExpr(TIntermTyped & expr)237     bool isStandaloneExpr(TIntermTyped &expr)
238     {
239         if (getParentNode()->getAsBlock())
240         {
241             return true;
242         }
243         // https://bugs.webkit.org/show_bug.cgi?id=227723: Fix for sequence operator.
244         if ((expr.getType().getBasicType() == TBasicType::EbtVoid))
245         {
246             return true;
247         }
248         return false;
249     }
250 
pushBinding(TIntermTyped & oldExpr,TIntermTyped & newExpr)251     void pushBinding(TIntermTyped &oldExpr, TIntermTyped &newExpr)
252     {
253         if (isStandaloneExpr(newExpr))
254         {
255             pushStmt(newExpr);
256             return;
257         }
258         if (IsIndex(newExpr))
259         {
260             mExprMap[&oldExpr] = &newExpr;
261             return;
262         }
263         auto &bindingMap = getCurrBindingMap();
264         const Name name  = mIdGen.createNewName();
265         auto *var =
266             new TVariable(&mSymbolTable, name.rawName(), &newExpr.getType(), name.symbolType());
267         auto *decl = new TIntermDeclaration(var, &newExpr);
268         pushStmt(*decl);
269         mExprMap[&oldExpr] = new TIntermSymbol(var);
270         bindingMap[var]    = decl;
271     }
272 
pushStacks()273     void pushStacks()
274     {
275         mStmtsStack.emplace_back();
276         mBindingMapStack.emplace_back();
277     }
278 
popStacks()279     void popStacks()
280     {
281         ASSERT(!mBindingMapStack.empty());
282         ASSERT(!mStmtsStack.empty());
283         ASSERT(mStmtsStack.back().empty());
284         mBindingMapStack.pop_back();
285         mStmtsStack.pop_back();
286     }
287 
pushStmtsIntoBlock(TIntermBlock & block,std::vector<TIntermNode * > & stmts)288     void pushStmtsIntoBlock(TIntermBlock &block, std::vector<TIntermNode *> &stmts)
289     {
290         TIntermSequence &seq = *block.getSequence();
291         for (TIntermNode *stmt : stmts)
292         {
293             if (TIntermDeclaration *decl = stmt->getAsDeclarationNode())
294             {
295                 auto iter = mMaskedDecls.find(decl);
296                 if (iter != mMaskedDecls.end())
297                 {
298                     mMaskedDecls.erase(iter);
299                     continue;
300                 }
301             }
302             seq.push_back(stmt);
303         }
304     }
305 
buildBlockWithTailAssign(const TVariable & var,TIntermTyped & newExpr)306     TIntermBlock &buildBlockWithTailAssign(const TVariable &var, TIntermTyped &newExpr)
307     {
308         std::vector<TIntermNode *> stmts = std::move(getCurrStmts());
309         popStacks();
310 
311         auto &block = *new TIntermBlock();
312         auto &seq   = *block.getSequence();
313         seq.reserve(1 + stmts.size());
314         pushStmtsIntoBlock(block, stmts);
315         seq.push_back(new TIntermBinary(TOperator::EOpAssign, new TIntermSymbol(&var), &newExpr));
316 
317         return block;
318     }
319 
320   private:
visitBlockPre(TIntermBlock & node)321     PreResult visitBlockPre(TIntermBlock &node) override
322     {
323         pushStacks();
324         return node;
325     }
326 
visitBlockPost(TIntermBlock & node)327     PostResult visitBlockPost(TIntermBlock &node) override
328     {
329         std::vector<TIntermNode *> stmts = std::move(getCurrStmts());
330         popStacks();
331 
332         TIntermSequence &seq = *node.getSequence();
333         seq.clear();
334         seq.reserve(stmts.size());
335         pushStmtsIntoBlock(node, stmts);
336 
337         TIntermNode *parent = getParentNode();
338         if (parent && parent->getAsBlock())
339         {
340             pushStmt(node);
341         }
342 
343         return node;
344     }
345 
visitDeclarationPre(TIntermDeclaration & node)346     PreResult visitDeclarationPre(TIntermDeclaration &node) override
347     {
348         Declaration decl = ViewDeclaration(node);
349         if (!decl.initExpr || isTerminalExpr(*decl.initExpr))
350         {
351             pushStmt(node);
352             return {node, VisitBits::Neither};
353         }
354         return node;
355     }
356 
visitDeclarationPost(TIntermDeclaration & node)357     PostResult visitDeclarationPost(TIntermDeclaration &node) override
358     {
359         Declaration decl = ViewDeclaration(node);
360         ASSERT(decl.symbol.variable().symbolType() != SymbolType::Empty);
361         ASSERT(!decl.symbol.variable().getType().isStructSpecifier());
362 
363         TIntermTyped *newInitExpr = pullMappedExpr(decl.initExpr, true);
364         if (decl.initExpr == newInitExpr)
365         {
366             pushStmt(node);
367         }
368         else
369         {
370             auto &newNode = *new TIntermDeclaration();
371             newNode.appendDeclarator(
372                 new TIntermBinary(TOperator::EOpInitialize, &decl.symbol, newInitExpr));
373             pushStmt(newNode);
374         }
375         return node;
376     }
377 
visitUnaryPost(TIntermUnary & node)378     PostResult visitUnaryPost(TIntermUnary &node) override
379     {
380         TIntermTyped *expr    = node.getOperand();
381         TIntermTyped *newExpr = pullMappedExpr(expr, false);
382         if (expr == newExpr)
383         {
384             pushBinding(node, node);
385         }
386         else
387         {
388             pushBinding(node, *new TIntermUnary(node.getOp(), newExpr, node.getFunction()));
389         }
390         return node;
391     }
392 
visitBinaryPre(TIntermBinary & node)393     PreResult visitBinaryPre(TIntermBinary &node) override
394     {
395         const TOperator op = node.getOp();
396         if (op == TOperator::EOpLogicalAnd || op == TOperator::EOpLogicalOr)
397         {
398             TIntermTyped *left  = node.getLeft();
399             TIntermTyped *right = node.getRight();
400 
401             PostResult leftResult = rebuild(*left);
402             ASSERT(leftResult.single());
403 
404             pushStacks();
405             PostResult rightResult = rebuild(*right);
406             ASSERT(rightResult.single());
407 
408             return {node, VisitBits::Post};
409         }
410 
411         return node;
412     }
413 
visitBinaryPost(TIntermBinary & node)414     PostResult visitBinaryPost(TIntermBinary &node) override
415     {
416         const TOperator op = node.getOp();
417         if (op == TOperator::EOpInitialize && getParentNode()->getAsDeclarationNode())
418         {
419             // Special case is handled by visitDeclarationPost
420             return node;
421         }
422 
423         TIntermTyped *left  = node.getLeft();
424         TIntermTyped *right = node.getRight();
425 
426         if (op == TOperator::EOpLogicalAnd || op == TOperator::EOpLogicalOr)
427         {
428             const Name name = mIdGen.createNewName();
429             auto *var = new TVariable(&mSymbolTable, name.rawName(), new TType(TBasicType::EbtBool),
430                                       name.symbolType());
431 
432             TIntermTyped *newRight   = pullMappedExpr(right, true);
433             TIntermBlock *rightBlock = &buildBlockWithTailAssign(*var, *newRight);
434             TIntermTyped *newLeft    = pullMappedExpr(left, true);
435 
436             TIntermTyped *cond = new TIntermSymbol(var);
437             if (op == TOperator::EOpLogicalOr)
438             {
439                 cond = new TIntermUnary(TOperator::EOpLogicalNot, cond, nullptr);
440             }
441 
442             pushStmt(*new TIntermDeclaration(var, newLeft));
443             pushStmt(*new TIntermIfElse(cond, rightBlock, nullptr));
444             if (!isStandaloneExpr(node))
445             {
446                 mExprMap[&node] = new TIntermSymbol(var);
447             }
448 
449             return node;
450         }
451 
452         const bool isAssign         = IsAssignment(op);
453         const bool isCompoundAssign = IsCompoundAssignment(op);
454         TIntermTyped *newLeft       = pullMappedExpr(left, false);
455         TIntermTyped *newRight      = pullMappedExpr(right, isAssign && !isCompoundAssign);
456         if (op == TOperator::EOpComma)
457         {
458             pushBinding(node, *newRight);
459             return node;
460         }
461         else
462         {
463             TIntermBinary *newNode;
464             if (left == newLeft && right == newRight)
465             {
466                 newNode = &node;
467             }
468             else
469             {
470                 newNode = new TIntermBinary(op, newLeft, newRight);
471             }
472             pushBinding(node, *newNode);
473             return node;
474         }
475     }
476 
visitTernaryPre(TIntermTernary & node)477     PreResult visitTernaryPre(TIntermTernary &node) override
478     {
479         PostResult condResult = rebuild(*node.getCondition());
480         ASSERT(condResult.single());
481 
482         pushStacks();
483         PostResult thenResult = rebuild(*node.getTrueExpression());
484         ASSERT(thenResult.single());
485 
486         pushStacks();
487         PostResult elseResult = rebuild(*node.getFalseExpression());
488         ASSERT(elseResult.single());
489 
490         return {node, VisitBits::Post};
491     }
492 
visitTernaryPost(TIntermTernary & node)493     PostResult visitTernaryPost(TIntermTernary &node) override
494     {
495         TIntermTyped *cond  = node.getCondition();
496         TIntermTyped *then  = node.getTrueExpression();
497         TIntermTyped *else_ = node.getFalseExpression();
498 
499         const Name name = mIdGen.createNewName();
500         TType *newType  = new TType(node.getType());
501         newType->setInterfaceBlock(nullptr);
502         auto *var = new TVariable(&mSymbolTable, name.rawName(), newType, name.symbolType());
503 
504         TIntermTyped *newElse   = pullMappedExpr(else_, false);
505         TIntermBlock *elseBlock = &buildBlockWithTailAssign(*var, *newElse);
506         TIntermTyped *newThen   = pullMappedExpr(then, true);
507         TIntermBlock *thenBlock = &buildBlockWithTailAssign(*var, *newThen);
508         TIntermTyped *newCond   = pullMappedExpr(cond, true);
509 
510         pushStmt(*new TIntermDeclaration{var});
511         pushStmt(*new TIntermIfElse(newCond, thenBlock, elseBlock));
512         if (!isStandaloneExpr(node))
513         {
514             mExprMap[&node] = new TIntermSymbol(var);
515         }
516 
517         return node;
518     }
519 
visitSwizzlePost(TIntermSwizzle & node)520     PostResult visitSwizzlePost(TIntermSwizzle &node) override
521     {
522         TIntermTyped *expr    = node.getOperand();
523         TIntermTyped *newExpr = pullMappedExpr(expr, false);
524         if (expr == newExpr)
525         {
526             pushBinding(node, node);
527         }
528         else
529         {
530             pushBinding(node, *new TIntermSwizzle(newExpr, node.getSwizzleOffsets()));
531         }
532         return node;
533     }
534 
visitAggregatePost(TIntermAggregate & node)535     PostResult visitAggregatePost(TIntermAggregate &node) override
536     {
537         TIntermSequence &args = *node.getSequence();
538         for (TIntermNode *&arg : args)
539         {
540             TIntermTyped *targ = arg->getAsTyped();
541             ASSERT(targ);
542             arg = pullMappedExpr(targ, false);
543         }
544         pushBinding(node, node);
545         return node;
546     }
547 
visitPreprocessorDirectivePost(TIntermPreprocessorDirective & node)548     PostResult visitPreprocessorDirectivePost(TIntermPreprocessorDirective &node) override
549     {
550         pushStmt(node);
551         return node;
552     }
553 
visitFunctionPrototypePost(TIntermFunctionPrototype & node)554     PostResult visitFunctionPrototypePost(TIntermFunctionPrototype &node) override
555     {
556         if (!getParentFunction())
557         {
558             pushStmt(node);
559         }
560         return node;
561     }
562 
visitCasePre(TIntermCase & node)563     PreResult visitCasePre(TIntermCase &node) override
564     {
565         if (TIntermTyped *cond = node.getCondition())
566         {
567             ASSERT(isTerminalExpr(*cond));
568         }
569         pushStmt(node);
570         return {node, VisitBits::Neither};
571     }
572 
visitSwitchPost(TIntermSwitch & node)573     PostResult visitSwitchPost(TIntermSwitch &node) override
574     {
575         TIntermTyped *init    = node.getInit();
576         TIntermTyped *newInit = pullMappedExpr(init, false);
577         if (init == newInit)
578         {
579             pushStmt(node);
580         }
581         else
582         {
583             pushStmt(*new TIntermSwitch(newInit, node.getStatementList()));
584         }
585 
586         return node;
587     }
588 
visitFunctionDefinitionPost(TIntermFunctionDefinition & node)589     PostResult visitFunctionDefinitionPost(TIntermFunctionDefinition &node) override
590     {
591         pushStmt(node);
592         return node;
593     }
594 
visitIfElsePost(TIntermIfElse & node)595     PostResult visitIfElsePost(TIntermIfElse &node) override
596     {
597         TIntermTyped *cond    = node.getCondition();
598         TIntermTyped *newCond = pullMappedExpr(cond, false);
599         if (cond == newCond)
600         {
601             pushStmt(node);
602         }
603         else
604         {
605             pushStmt(*new TIntermIfElse(newCond, node.getTrueBlock(), node.getFalseBlock()));
606         }
607         return node;
608     }
609 
visitBranchPost(TIntermBranch & node)610     PostResult visitBranchPost(TIntermBranch &node) override
611     {
612         TIntermTyped *expr    = node.getExpression();
613         TIntermTyped *newExpr = pullMappedExpr(expr, false);
614         if (expr == newExpr)
615         {
616             pushStmt(node);
617         }
618         else
619         {
620             pushStmt(*new TIntermBranch(node.getFlowOp(), newExpr));
621         }
622         return node;
623     }
624 
visitLoopPre(TIntermLoop & node)625     PreResult visitLoopPre(TIntermLoop &node) override
626     {
627         if (!rebuildInPlace(*node.getBody()))
628         {
629             UNREACHABLE();
630         }
631         pushStmt(node);
632         return {node, VisitBits::Neither};
633     }
634 
visitConstantUnionPost(TIntermConstantUnion & node)635     PostResult visitConstantUnionPost(TIntermConstantUnion &node) override
636     {
637         const TType &type = node.getType();
638         if (!type.isScalar())
639         {
640             pushBinding(node, node);
641         }
642         return node;
643     }
644 
visitGlobalQualifierDeclarationPost(TIntermGlobalQualifierDeclaration & node)645     PostResult visitGlobalQualifierDeclarationPost(TIntermGlobalQualifierDeclaration &node) override
646     {
647         // With the removal of RewriteGlobalQualifierDecls, we may encounter globals while
648         // seperating compound expressions.
649         pushStmt(node);
650         return node;
651     }
652 };
653 
654 }  // anonymous namespace
655 
656 ////////////////////////////////////////////////////////////////////////////////
657 
SeparateCompoundExpressions(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen,TIntermBlock & root)658 bool sh::SeparateCompoundExpressions(TCompiler &compiler,
659                                      SymbolEnv &symbolEnv,
660                                      IdGen &idGen,
661                                      TIntermBlock &root)
662 {
663     if (angle::GetBoolEnvironmentVar("GMT_DISABLE_SEPARATE_COMPOUND_EXPRESSIONS"))
664     {
665         return true;
666     }
667 
668     if (!SimplifyLoopConditions(&compiler, &root, &compiler.getSymbolTable()))
669     {
670         return false;
671     }
672 
673     if (!PrePass(compiler).rebuildRoot(root))
674     {
675         return false;
676     }
677 
678     if (!Separator(compiler, symbolEnv, idGen).rebuildRoot(root))
679     {
680         return false;
681     }
682 
683     return true;
684 }
685