• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "common/system_utils.h"
#include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
#include "compiler/translator/TranslatorMetalDirect/IntermRebuild.h"
#include "compiler/translator/TranslatorMetalDirect/SeparateCompoundExpressions.h"
#include "compiler/translator/tree_ops/SimplifyLoopConditions.h"
14 
15 using namespace sh;
16 
17 ////////////////////////////////////////////////////////////////////////////////
18 
19 namespace
20 {
21 
IsIndex(TOperator op)22 bool IsIndex(TOperator op)
23 {
24     switch (op)
25     {
26         case TOperator::EOpIndexDirect:
27         case TOperator::EOpIndexDirectInterfaceBlock:
28         case TOperator::EOpIndexDirectStruct:
29         case TOperator::EOpIndexIndirect:
30             return true;
31         default:
32             return false;
33     }
34 }
35 
IsIndex(TIntermTyped & expr)36 bool IsIndex(TIntermTyped &expr)
37 {
38     if (auto *binary = expr.getAsBinaryNode())
39     {
40         return IsIndex(binary->getOp());
41     }
42     return expr.getAsSwizzleNode();
43 }
44 
ViewBinaryChain(TOperator op,TIntermTyped & node,std::vector<TIntermTyped * > & out)45 bool ViewBinaryChain(TOperator op, TIntermTyped &node, std::vector<TIntermTyped *> &out)
46 {
47     TIntermBinary *binary = node.getAsBinaryNode();
48     if (!binary || binary->getOp() != op)
49     {
50         return false;
51     }
52 
53     TIntermTyped *left  = binary->getLeft();
54     TIntermTyped *right = binary->getRight();
55 
56     if (!ViewBinaryChain(op, *left, out))
57     {
58         out.push_back(left);
59     }
60 
61     if (!ViewBinaryChain(op, *right, out))
62     {
63         out.push_back(right);
64     }
65 
66     return true;
67 }
68 
ViewBinaryChain(TIntermBinary & node)69 std::vector<TIntermTyped *> ViewBinaryChain(TIntermBinary &node)
70 {
71     std::vector<TIntermTyped *> chain;
72     ViewBinaryChain(node.getOp(), node, chain);
73     ASSERT(chain.size() >= 2);
74     return chain;
75 }
76 
77 class PrePass : public TIntermRebuild
78 {
79   public:
PrePass(TCompiler & compiler)80     PrePass(TCompiler &compiler) : TIntermRebuild(compiler, true, true) {}
81 
82   private:
83     // Change chains of
84     //      x OP y OP z
85     // to
86     //      x OP (y OP z)
87     // regardless of original parenthesization.
reassociateRight(TIntermBinary & node)88     TIntermTyped &reassociateRight(TIntermBinary &node)
89     {
90         const TOperator op                = node.getOp();
91         std::vector<TIntermTyped *> chain = ViewBinaryChain(node);
92 
93         TIntermTyped *result = chain.back();
94         chain.pop_back();
95         ASSERT(result);
96 
97         const auto begin = chain.rbegin();
98         const auto end   = chain.rend();
99 
100         for (auto iter = begin; iter != end; ++iter)
101         {
102             TIntermTyped *part = *iter;
103             ASSERT(part);
104             TIntermNode *temp = rebuild(*part).single();
105             ASSERT(temp);
106             part = temp->getAsTyped();
107             ASSERT(part);
108             result = new TIntermBinary(op, part, result);
109         }
110         return *result;
111     }
112 
113   private:
visitBinaryPre(TIntermBinary & node)114     PreResult visitBinaryPre(TIntermBinary &node) override
115     {
116         const TOperator op = node.getOp();
117         if (op == TOperator::EOpLogicalAnd || op == TOperator::EOpLogicalOr)
118         {
119             return {reassociateRight(node), VisitBits::Neither};
120         }
121         return node;
122     }
123 };
124 
125 class Separator : public TIntermRebuild
126 {
127     IdGen &mIdGen;
128     std::vector<std::vector<TIntermNode *>> mStmtsStack;
129     std::vector<std::unordered_map<const TVariable *, TIntermDeclaration *>> mBindingMapStack;
130     std::unordered_map<TIntermTyped *, TIntermTyped *> mExprMap;
131     std::unordered_set<TIntermDeclaration *> mMaskedDecls;
132 
133   public:
Separator(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen)134     Separator(TCompiler &compiler, SymbolEnv &symbolEnv, IdGen &idGen)
135         : TIntermRebuild(compiler, true, true), mIdGen(idGen)
136     {}
137 
~Separator()138     ~Separator() override
139     {
140         ASSERT(mStmtsStack.empty());
141         ASSERT(mExprMap.empty());
142         ASSERT(mBindingMapStack.empty());
143     }
144 
145   private:
getCurrStmts()146     std::vector<TIntermNode *> &getCurrStmts()
147     {
148         ASSERT(!mStmtsStack.empty());
149         return mStmtsStack.back();
150     }
151 
getCurrBindingMap()152     std::unordered_map<const TVariable *, TIntermDeclaration *> &getCurrBindingMap()
153     {
154         ASSERT(!mBindingMapStack.empty());
155         return mBindingMapStack.back();
156     }
157 
pushStmt(TIntermNode & node)158     void pushStmt(TIntermNode &node) { getCurrStmts().push_back(&node); }
159 
isTerminalExpr(TIntermNode & node)160     bool isTerminalExpr(TIntermNode &node)
161     {
162         NodeType nodeType = getNodeType(node);
163         switch (nodeType)
164         {
165             case NodeType::Symbol:
166             case NodeType::ConstantUnion:
167                 return true;
168             default:
169                 return false;
170         }
171     }
172 
pullMappedExpr(TIntermTyped * node,bool allowBacktrack)173     TIntermTyped *pullMappedExpr(TIntermTyped *node, bool allowBacktrack)
174     {
175         TIntermTyped *expr;
176 
177         {
178             auto iter = mExprMap.find(node);
179             if (iter == mExprMap.end())
180             {
181                 return node;
182             }
183             ASSERT(node);
184             expr = iter->second;
185             ASSERT(expr);
186             mExprMap.erase(iter);
187         }
188 
189         if (allowBacktrack)
190         {
191             auto &bindingMap = getCurrBindingMap();
192             while (TIntermSymbol *symbol = expr->getAsSymbolNode())
193             {
194                 const TVariable &var = symbol->variable();
195                 auto iter            = bindingMap.find(&var);
196                 if (iter == bindingMap.end())
197                 {
198                     return expr;
199                 }
200                 ASSERT(var.symbolType() == SymbolType::AngleInternal);
201                 TIntermDeclaration *decl = iter->second;
202                 ASSERT(decl);
203                 expr = ViewDeclaration(*decl).initExpr;
204                 ASSERT(expr);
205                 bindingMap.erase(iter);
206                 mMaskedDecls.insert(decl);
207             }
208         }
209 
210         return expr;
211     }
212 
isStandaloneExpr(TIntermTyped & expr)213     bool isStandaloneExpr(TIntermTyped &expr)
214     {
215         if (getParentNode()->getAsBlock())
216         {
217             return true;
218         }
219         ASSERT(expr.getType().getBasicType() != TBasicType::EbtVoid);
220         return false;
221     }
222 
pushBinding(TIntermTyped & oldExpr,TIntermTyped & newExpr)223     void pushBinding(TIntermTyped &oldExpr, TIntermTyped &newExpr)
224     {
225         if (isStandaloneExpr(newExpr))
226         {
227             pushStmt(newExpr);
228             return;
229         }
230         if (IsIndex(newExpr))
231         {
232             mExprMap[&oldExpr] = &newExpr;
233             return;
234         }
235         auto &bindingMap = getCurrBindingMap();
236         const Name name  = mIdGen.createNewName("");
237         auto *var =
238             new TVariable(&mSymbolTable, name.rawName(), &newExpr.getType(), name.symbolType());
239         auto *decl = new TIntermDeclaration(var, &newExpr);
240         pushStmt(*decl);
241         mExprMap[&oldExpr] = new TIntermSymbol(var);
242         bindingMap[var]    = decl;
243     }
244 
pushStacks()245     void pushStacks()
246     {
247         mStmtsStack.emplace_back();
248         mBindingMapStack.emplace_back();
249     }
250 
popStacks()251     void popStacks()
252     {
253         ASSERT(!mBindingMapStack.empty());
254         ASSERT(!mStmtsStack.empty());
255         ASSERT(mStmtsStack.back().empty());
256         mBindingMapStack.pop_back();
257         mStmtsStack.pop_back();
258     }
259 
pushStmtsIntoBlock(TIntermBlock & block,std::vector<TIntermNode * > & stmts)260     void pushStmtsIntoBlock(TIntermBlock &block, std::vector<TIntermNode *> &stmts)
261     {
262         TIntermSequence &seq = *block.getSequence();
263         for (TIntermNode *stmt : stmts)
264         {
265             if (TIntermDeclaration *decl = stmt->getAsDeclarationNode())
266             {
267                 auto iter = mMaskedDecls.find(decl);
268                 if (iter != mMaskedDecls.end())
269                 {
270                     mMaskedDecls.erase(iter);
271                     continue;
272                 }
273             }
274             seq.push_back(stmt);
275         }
276     }
277 
buildBlockWithTailAssign(const TVariable & var,TIntermTyped & newExpr)278     TIntermBlock &buildBlockWithTailAssign(const TVariable &var, TIntermTyped &newExpr)
279     {
280         std::vector<TIntermNode *> stmts = std::move(getCurrStmts());
281         popStacks();
282 
283         auto &block = *new TIntermBlock();
284         auto &seq   = *block.getSequence();
285         seq.reserve(1 + stmts.size());
286         pushStmtsIntoBlock(block, stmts);
287         seq.push_back(new TIntermBinary(TOperator::EOpAssign, new TIntermSymbol(&var), &newExpr));
288 
289         return block;
290     }
291 
292   private:
visitBlockPre(TIntermBlock & node)293     PreResult visitBlockPre(TIntermBlock &node) override
294     {
295         pushStacks();
296         return node;
297     }
298 
visitBlockPost(TIntermBlock & node)299     PostResult visitBlockPost(TIntermBlock &node) override
300     {
301         std::vector<TIntermNode *> stmts = std::move(getCurrStmts());
302         popStacks();
303 
304         TIntermSequence &seq = *node.getSequence();
305         seq.clear();
306         seq.reserve(stmts.size());
307         pushStmtsIntoBlock(node, stmts);
308 
309         TIntermNode *parent = getParentNode();
310         if (parent && parent->getAsBlock())
311         {
312             pushStmt(node);
313         }
314 
315         return node;
316     }
317 
visitDeclarationPre(TIntermDeclaration & node)318     PreResult visitDeclarationPre(TIntermDeclaration &node) override
319     {
320         Declaration decl = ViewDeclaration(node);
321         if (!decl.initExpr || isTerminalExpr(*decl.initExpr))
322         {
323             pushStmt(node);
324             return {node, VisitBits::Neither};
325         }
326         return node;
327     }
328 
visitDeclarationPost(TIntermDeclaration & node)329     PostResult visitDeclarationPost(TIntermDeclaration &node) override
330     {
331         Declaration decl = ViewDeclaration(node);
332         ASSERT(decl.symbol.variable().symbolType() != SymbolType::Empty);
333         ASSERT(!decl.symbol.variable().getType().isStructSpecifier());
334 
335         TIntermTyped *newInitExpr = pullMappedExpr(decl.initExpr, true);
336         if (decl.initExpr == newInitExpr)
337         {
338             pushStmt(node);
339         }
340         else
341         {
342             auto &newNode = *new TIntermDeclaration();
343             newNode.appendDeclarator(
344                 new TIntermBinary(TOperator::EOpInitialize, &decl.symbol, newInitExpr));
345             pushStmt(newNode);
346         }
347         return node;
348     }
349 
visitUnaryPost(TIntermUnary & node)350     PostResult visitUnaryPost(TIntermUnary &node) override
351     {
352         TIntermTyped *expr    = node.getOperand();
353         TIntermTyped *newExpr = pullMappedExpr(expr, false);
354         if (expr == newExpr)
355         {
356             pushBinding(node, node);
357         }
358         else
359         {
360             pushBinding(node, *new TIntermUnary(node.getOp(), newExpr, node.getFunction()));
361         }
362         return node;
363     }
364 
visitBinaryPre(TIntermBinary & node)365     PreResult visitBinaryPre(TIntermBinary &node) override
366     {
367         const TOperator op = node.getOp();
368         if (op == TOperator::EOpLogicalAnd || op == TOperator::EOpLogicalOr)
369         {
370             TIntermTyped *left  = node.getLeft();
371             TIntermTyped *right = node.getRight();
372 
373             PostResult leftResult = rebuild(*left);
374             ASSERT(leftResult.single());
375 
376             pushStacks();
377             PostResult rightResult = rebuild(*right);
378             ASSERT(rightResult.single());
379 
380             return {node, VisitBits::Post};
381         }
382 
383         return node;
384     }
385 
visitBinaryPost(TIntermBinary & node)386     PostResult visitBinaryPost(TIntermBinary &node) override
387     {
388         const TOperator op = node.getOp();
389         if (op == TOperator::EOpInitialize && getParentNode()->getAsDeclarationNode())
390         {
391             // Special case is handled by visitDeclarationPost
392             return node;
393         }
394 
395         TIntermTyped *left  = node.getLeft();
396         TIntermTyped *right = node.getRight();
397 
398         if (op == TOperator::EOpLogicalAnd || op == TOperator::EOpLogicalOr)
399         {
400             const Name name = mIdGen.createNewName("");
401             auto *var = new TVariable(&mSymbolTable, name.rawName(), new TType(TBasicType::EbtBool),
402                                       name.symbolType());
403 
404             TIntermTyped *newRight   = pullMappedExpr(right, true);
405             TIntermBlock *rightBlock = &buildBlockWithTailAssign(*var, *newRight);
406             TIntermTyped *newLeft    = pullMappedExpr(left, true);
407 
408             TIntermTyped *cond = new TIntermSymbol(var);
409             if (op == TOperator::EOpLogicalOr)
410             {
411                 cond = new TIntermUnary(TOperator::EOpLogicalNot, cond, nullptr);
412             }
413 
414             pushStmt(*new TIntermDeclaration(var, newLeft));
415             pushStmt(*new TIntermIfElse(cond, rightBlock, nullptr));
416             if (!isStandaloneExpr(node))
417             {
418                 mExprMap[&node] = new TIntermSymbol(var);
419             }
420 
421             return node;
422         }
423 
424         const bool isAssign    = IsAssignment(op);
425         TIntermTyped *newLeft  = pullMappedExpr(left, false);
426         TIntermTyped *newRight = pullMappedExpr(right, isAssign);
427 
428         if (op == TOperator::EOpComma)
429         {
430             pushBinding(node, *newRight);
431             return node;
432         }
433         else
434         {
435             TIntermBinary *newNode;
436             if (left == newLeft && right == newRight)
437             {
438                 newNode = &node;
439             }
440             else
441             {
442                 newNode = new TIntermBinary(op, newLeft, newRight);
443             }
444             pushBinding(node, *newNode);
445             return node;
446         }
447     }
448 
visitTernaryPre(TIntermTernary & node)449     PreResult visitTernaryPre(TIntermTernary &node) override
450     {
451         PostResult condResult = rebuild(*node.getCondition());
452         ASSERT(condResult.single());
453 
454         pushStacks();
455         PostResult thenResult = rebuild(*node.getTrueExpression());
456         ASSERT(thenResult.single());
457 
458         pushStacks();
459         PostResult elseResult = rebuild(*node.getFalseExpression());
460         ASSERT(elseResult.single());
461 
462         return {node, VisitBits::Post};
463     }
464 
visitTernaryPost(TIntermTernary & node)465     PostResult visitTernaryPost(TIntermTernary &node) override
466     {
467         TIntermTyped *cond  = node.getCondition();
468         TIntermTyped *then  = node.getTrueExpression();
469         TIntermTyped *else_ = node.getFalseExpression();
470 
471         const Name name = mIdGen.createNewName("");
472         auto *var =
473             new TVariable(&mSymbolTable, name.rawName(), &node.getType(), name.symbolType());
474 
475         TIntermTyped *newElse   = pullMappedExpr(else_, false);
476         TIntermBlock *elseBlock = &buildBlockWithTailAssign(*var, *newElse);
477         TIntermTyped *newThen   = pullMappedExpr(then, true);
478         TIntermBlock *thenBlock = &buildBlockWithTailAssign(*var, *newThen);
479         TIntermTyped *newCond   = pullMappedExpr(cond, true);
480 
481         pushStmt(*new TIntermDeclaration{var});
482         pushStmt(*new TIntermIfElse(newCond, thenBlock, elseBlock));
483         if (!isStandaloneExpr(node))
484         {
485             mExprMap[&node] = new TIntermSymbol(var);
486         }
487 
488         return node;
489     }
490 
visitSwizzlePost(TIntermSwizzle & node)491     PostResult visitSwizzlePost(TIntermSwizzle &node) override
492     {
493         TIntermTyped *expr    = node.getOperand();
494         TIntermTyped *newExpr = pullMappedExpr(expr, false);
495         if (expr == newExpr)
496         {
497             pushBinding(node, node);
498         }
499         else
500         {
501             pushBinding(node, *new TIntermSwizzle(newExpr, node.getSwizzleOffsets()));
502         }
503         return node;
504     }
505 
visitAggregatePost(TIntermAggregate & node)506     PostResult visitAggregatePost(TIntermAggregate &node) override
507     {
508         TIntermSequence &args = *node.getSequence();
509         for (TIntermNode *&arg : args)
510         {
511             TIntermTyped *targ = arg->getAsTyped();
512             ASSERT(targ);
513             arg = pullMappedExpr(targ, false);
514         }
515         pushBinding(node, node);
516         return node;
517     }
518 
visitPreprocessorDirectivePost(TIntermPreprocessorDirective & node)519     PostResult visitPreprocessorDirectivePost(TIntermPreprocessorDirective &node) override
520     {
521         pushStmt(node);
522         return node;
523     }
524 
visitFunctionPrototypePost(TIntermFunctionPrototype & node)525     PostResult visitFunctionPrototypePost(TIntermFunctionPrototype &node) override
526     {
527         if (!getParentFunction())
528         {
529             pushStmt(node);
530         }
531         return node;
532     }
533 
visitCasePre(TIntermCase & node)534     PreResult visitCasePre(TIntermCase &node) override
535     {
536         if (TIntermTyped *cond = node.getCondition())
537         {
538             ASSERT(isTerminalExpr(*cond));
539         }
540         pushStmt(node);
541         return {node, VisitBits::Neither};
542     }
543 
visitSwitchPost(TIntermSwitch & node)544     PostResult visitSwitchPost(TIntermSwitch &node) override
545     {
546         TIntermTyped *init    = node.getInit();
547         TIntermTyped *newInit = pullMappedExpr(init, false);
548         if (init == newInit)
549         {
550             pushStmt(node);
551         }
552         else
553         {
554             pushStmt(*new TIntermSwitch(newInit, node.getStatementList()));
555         }
556 
557         return node;
558     }
559 
visitFunctionDefinitionPost(TIntermFunctionDefinition & node)560     PostResult visitFunctionDefinitionPost(TIntermFunctionDefinition &node) override
561     {
562         pushStmt(node);
563         return node;
564     }
565 
visitIfElsePost(TIntermIfElse & node)566     PostResult visitIfElsePost(TIntermIfElse &node) override
567     {
568         TIntermTyped *cond    = node.getCondition();
569         TIntermTyped *newCond = pullMappedExpr(cond, false);
570         if (cond == newCond)
571         {
572             pushStmt(node);
573         }
574         else
575         {
576             pushStmt(*new TIntermIfElse(newCond, node.getTrueBlock(), node.getFalseBlock()));
577         }
578         return node;
579     }
580 
visitBranchPost(TIntermBranch & node)581     PostResult visitBranchPost(TIntermBranch &node) override
582     {
583         TIntermTyped *expr    = node.getExpression();
584         TIntermTyped *newExpr = pullMappedExpr(expr, false);
585         if (expr == newExpr)
586         {
587             pushStmt(node);
588         }
589         else
590         {
591             pushStmt(*new TIntermBranch(node.getFlowOp(), newExpr));
592         }
593         return node;
594     }
595 
visitLoopPre(TIntermLoop & node)596     PreResult visitLoopPre(TIntermLoop &node) override
597     {
598         if (!rebuildInPlace(*node.getBody()))
599         {
600             UNREACHABLE();
601         }
602         pushStmt(node);
603         return {node, VisitBits::Neither};
604     }
605 
visitConstantUnionPost(TIntermConstantUnion & node)606     PostResult visitConstantUnionPost(TIntermConstantUnion &node) override
607     {
608         const TType &type = node.getType();
609         if (!type.isScalar())
610         {
611             pushBinding(node, node);
612         }
613         return node;
614     }
615 
visitGlobalQualifierDeclarationPost(TIntermGlobalQualifierDeclaration & node)616     PostResult visitGlobalQualifierDeclarationPost(TIntermGlobalQualifierDeclaration &node) override
617     {
618         ASSERT(false);  // These should be scrubbed from AST before rewriter is called.
619         pushStmt(node);
620         return node;
621     }
622 };
623 
624 }  // anonymous namespace
625 
626 ////////////////////////////////////////////////////////////////////////////////
627 
SeparateCompoundExpressions(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen,TIntermBlock & root)628 bool sh::SeparateCompoundExpressions(TCompiler &compiler,
629                                      SymbolEnv &symbolEnv,
630                                      IdGen &idGen,
631                                      TIntermBlock &root)
632 {
633     if (angle::GetBoolEnvironmentVar("GMT_DISABLE_SEPARATE_COMPOUND_EXPRESSIONS"))
634     {
635         return true;
636     }
637 
638     if (!SimplifyLoopConditions(&compiler, &root, &compiler.getSymbolTable()))
639     {
640         return false;
641     }
642 
643     if (!PrePass(compiler).rebuildRoot(root))
644     {
645         return false;
646     }
647 
648     if (!Separator(compiler, symbolEnv, idGen).rebuildRoot(root))
649     {
650         return false;
651     }
652 
653     return true;
654 }
655