1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "common/system_utils.h"
#include "compiler/translator/msl/AstHelpers.h"
#include "compiler/translator/msl/IntermRebuild.h"
#include "compiler/translator/tree_ops/SimplifyLoopConditions.h"
#include "compiler/translator/tree_ops/msl/SeparateCompoundExpressions.h"
14
15 using namespace sh;
16
17 ////////////////////////////////////////////////////////////////////////////////
18
19 namespace
20 {
21
IsIndex(TOperator op)22 bool IsIndex(TOperator op)
23 {
24 switch (op)
25 {
26 case TOperator::EOpIndexDirect:
27 case TOperator::EOpIndexDirectInterfaceBlock:
28 case TOperator::EOpIndexDirectStruct:
29 case TOperator::EOpIndexIndirect:
30 return true;
31 default:
32 return false;
33 }
34 }
35
IsIndex(TIntermTyped & expr)36 bool IsIndex(TIntermTyped &expr)
37 {
38 if (auto *binary = expr.getAsBinaryNode())
39 {
40 return IsIndex(binary->getOp());
41 }
42 return expr.getAsSwizzleNode();
43 }
44
IsCompoundAssignment(TOperator op)45 bool IsCompoundAssignment(TOperator op)
46 {
47 switch (op)
48 {
49 case EOpAddAssign:
50 case EOpSubAssign:
51 case EOpMulAssign:
52 case EOpVectorTimesMatrixAssign:
53 case EOpVectorTimesScalarAssign:
54 case EOpMatrixTimesScalarAssign:
55 case EOpMatrixTimesMatrixAssign:
56 case EOpDivAssign:
57 case EOpIModAssign:
58 case EOpBitShiftLeftAssign:
59 case EOpBitShiftRightAssign:
60 case EOpBitwiseAndAssign:
61 case EOpBitwiseXorAssign:
62 case EOpBitwiseOrAssign:
63 return true;
64 default:
65 return false;
66 }
67 }
68
ViewBinaryChain(TOperator op,TIntermTyped & node,std::vector<TIntermTyped * > & out)69 bool ViewBinaryChain(TOperator op, TIntermTyped &node, std::vector<TIntermTyped *> &out)
70 {
71 TIntermBinary *binary = node.getAsBinaryNode();
72 if (!binary || binary->getOp() != op)
73 {
74 return false;
75 }
76
77 TIntermTyped *left = binary->getLeft();
78 TIntermTyped *right = binary->getRight();
79
80 if (!ViewBinaryChain(op, *left, out))
81 {
82 out.push_back(left);
83 }
84
85 if (!ViewBinaryChain(op, *right, out))
86 {
87 out.push_back(right);
88 }
89
90 return true;
91 }
92
ViewBinaryChain(TIntermBinary & node)93 std::vector<TIntermTyped *> ViewBinaryChain(TIntermBinary &node)
94 {
95 std::vector<TIntermTyped *> chain;
96 ViewBinaryChain(node.getOp(), node, chain);
97 ASSERT(chain.size() >= 2);
98 return chain;
99 }
100
101 class PrePass : public TIntermRebuild
102 {
103 public:
PrePass(TCompiler & compiler)104 PrePass(TCompiler &compiler) : TIntermRebuild(compiler, true, true) {}
105
106 private:
107 // Change chains of
108 // x OP y OP z
109 // to
110 // x OP (y OP z)
111 // regardless of original parenthesization.
reassociateRight(TIntermBinary & node)112 TIntermTyped &reassociateRight(TIntermBinary &node)
113 {
114 const TOperator op = node.getOp();
115 std::vector<TIntermTyped *> chain = ViewBinaryChain(node);
116
117 TIntermTyped *result = chain.back();
118 chain.pop_back();
119 ASSERT(result);
120
121 const auto begin = chain.rbegin();
122 const auto end = chain.rend();
123
124 for (auto iter = begin; iter != end; ++iter)
125 {
126 TIntermTyped *part = *iter;
127 ASSERT(part);
128 TIntermNode *temp = rebuild(*part).single();
129 ASSERT(temp);
130 part = temp->getAsTyped();
131 ASSERT(part);
132 result = new TIntermBinary(op, part, result);
133 }
134 return *result;
135 }
136
137 private:
visitBinaryPre(TIntermBinary & node)138 PreResult visitBinaryPre(TIntermBinary &node) override
139 {
140 const TOperator op = node.getOp();
141 if (op == TOperator::EOpLogicalAnd || op == TOperator::EOpLogicalOr)
142 {
143 return {reassociateRight(node), VisitBits::Neither};
144 }
145 return node;
146 }
147 };
148
// Second pass run by SeparateCompoundExpressions: splits compound expressions
// into a sequence of statements. Results of non-trivial sub-expressions are
// bound to freshly named temporaries that are declared immediately before the
// statement using them (see pushBinding). Logical-and/or and ternary operators
// are lowered into if-statements (see visitBinaryPost / visitTernaryPost) so
// that the statements emitted for a conditionally evaluated operand only run
// when that operand would have been evaluated.
class Separator : public TIntermRebuild
{
    // Source of names for the temporaries introduced by this pass.
    IdGen &mIdGen;
    // One statement list per open level. A level is opened for each block and
    // for each conditionally evaluated operand (right side of &&/||, both
    // branches of ?:). Separated statements are appended to the innermost list.
    std::vector<std::vector<TIntermNode *>> mStmtsStack;
    // Parallel to mStmtsStack: for each temporary bound by pushBinding at that
    // level, its declaration, so pullMappedExpr can undo ("backtrack") it.
    std::vector<std::unordered_map<const TVariable *, TIntermDeclaration *>> mBindingMapStack;
    // Maps an original expression node to the expression that must replace it
    // when its parent is rebuilt. Entries are consumed by pullMappedExpr.
    std::unordered_map<TIntermTyped *, TIntermTyped *> mExprMap;
    // Declarations of temporaries whose initializer was inlined back into the
    // use site by backtracking; they are dropped by pushStmtsIntoBlock.
    std::unordered_set<TIntermDeclaration *> mMaskedDecls;

  public:
    // NOTE(review): |symbolEnv| is accepted but not used by this class.
    Separator(TCompiler &compiler, SymbolEnv &symbolEnv, IdGen &idGen)
        : TIntermRebuild(compiler, true, true), mIdGen(idGen)
    {}

    // Every level that was opened must have been closed, and every mapped
    // expression must have been consumed by its parent.
    ~Separator() override
    {
        ASSERT(mStmtsStack.empty());
        ASSERT(mExprMap.empty());
        ASSERT(mBindingMapStack.empty());
    }

  private:
    // Statement list of the innermost open level.
    std::vector<TIntermNode *> &getCurrStmts()
    {
        ASSERT(!mStmtsStack.empty());
        return mStmtsStack.back();
    }

    // Binding map of the innermost open level.
    std::unordered_map<const TVariable *, TIntermDeclaration *> &getCurrBindingMap()
    {
        ASSERT(!mBindingMapStack.empty());
        return mBindingMapStack.back();
    }

    // Emits |node| as the next statement of the innermost level.
    void pushStmt(TIntermNode &node) { getCurrStmts().push_back(&node); }

    // True for symbol and constant-union nodes.
    bool isTerminalExpr(TIntermNode &node)
    {
        NodeType nodeType = getNodeType(node);
        switch (nodeType)
        {
            case NodeType::Symbol:
            case NodeType::ConstantUnion:
                return true;
            default:
                return false;
        }
    }

    // Returns the replacement recorded for |node| in mExprMap (and removes the
    // entry), or |node| itself when there is none (this includes a null |node|).
    //
    // With |allowBacktrack|, if the replacement is a temporary bound at the
    // innermost level, the temporary's initializer is returned instead and its
    // declaration is masked (dropped when the level is flushed). This repeats
    // while the result is still such a temporary. Callers enable it only for
    // declaration initializers, the right side of plain assignments, the
    // operands of lowered &&/||, and the true branch and condition of a
    // lowered ternary.
    TIntermTyped *pullMappedExpr(TIntermTyped *node, bool allowBacktrack)
    {
        TIntermTyped *expr;

        {
            auto iter = mExprMap.find(node);
            if (iter == mExprMap.end())
            {
                return node;
            }
            ASSERT(node);
            expr = iter->second;
            ASSERT(expr);
            mExprMap.erase(iter);
        }

        if (allowBacktrack)
        {
            auto &bindingMap = getCurrBindingMap();
            while (TIntermSymbol *symbol = expr->getAsSymbolNode())
            {
                const TVariable &var = symbol->variable();
                auto iter = bindingMap.find(&var);
                if (iter == bindingMap.end())
                {
                    return expr;
                }
                // Only temporaries created by pushBinding are in the map.
                ASSERT(var.symbolType() == SymbolType::AngleInternal);
                TIntermDeclaration *decl = iter->second;
                ASSERT(decl);
                expr = ViewDeclaration(*decl).initExpr;
                ASSERT(expr);
                bindingMap.erase(iter);
                mMaskedDecls.insert(decl);
            }
        }

        return expr;
    }

    // True when the value being produced should be emitted as a statement
    // instead of bound to a temporary. This happens when the node currently
    // being visited sits directly in a block, or when |expr| has void type.
    // Note that the block check uses getParentNode(), i.e. the parent of the
    // node being visited, not of |expr| itself.
    bool isStandaloneExpr(TIntermTyped &expr)
    {
        if (getParentNode()->getAsBlock())
        {
            return true;
        }
        // https://bugs.webkit.org/show_bug.cgi?id=227723: Fix for sequence operator.
        if ((expr.getType().getBasicType() == TBasicType::EbtVoid))
        {
            return true;
        }
        return false;
    }

    // Records |newExpr| as the rebuilt form of |oldExpr|:
    //  - standalone expressions are emitted directly as statements;
    //  - index/swizzle expressions are mapped as-is without a temporary
    //    (presumably so they keep designating the original storage, e.g. as
    //    an assignment target -- NOTE(review): confirm);
    //  - anything else is stored in a new temporary declared at the innermost
    //    level, and |oldExpr| is mapped to a reference to that temporary.
    void pushBinding(TIntermTyped &oldExpr, TIntermTyped &newExpr)
    {
        if (isStandaloneExpr(newExpr))
        {
            pushStmt(newExpr);
            return;
        }
        if (IsIndex(newExpr))
        {
            mExprMap[&oldExpr] = &newExpr;
            return;
        }
        auto &bindingMap = getCurrBindingMap();
        const Name name = mIdGen.createNewName();
        auto *var =
            new TVariable(&mSymbolTable, name.rawName(), &newExpr.getType(), name.symbolType());
        auto *decl = new TIntermDeclaration(var, &newExpr);
        pushStmt(*decl);
        mExprMap[&oldExpr] = new TIntermSymbol(var);
        bindingMap[var] = decl;
    }

    // Opens a new level (statement list + binding map).
    void pushStacks()
    {
        mStmtsStack.emplace_back();
        mBindingMapStack.emplace_back();
    }

    // Closes the innermost level. Its statements must already have been moved
    // out by the caller.
    void popStacks()
    {
        ASSERT(!mBindingMapStack.empty());
        ASSERT(!mStmtsStack.empty());
        ASSERT(mStmtsStack.back().empty());
        mBindingMapStack.pop_back();
        mStmtsStack.pop_back();
    }

    // Appends |stmts| to |block|, skipping (and forgetting) declarations that
    // were masked by backtracking.
    void pushStmtsIntoBlock(TIntermBlock &block, std::vector<TIntermNode *> &stmts)
    {
        TIntermSequence &seq = *block.getSequence();
        for (TIntermNode *stmt : stmts)
        {
            if (TIntermDeclaration *decl = stmt->getAsDeclarationNode())
            {
                auto iter = mMaskedDecls.find(decl);
                if (iter != mMaskedDecls.end())
                {
                    mMaskedDecls.erase(iter);
                    continue;
                }
            }
            seq.push_back(stmt);
        }
    }

    // Closes the innermost level and returns a new block holding its
    // statements followed by `var = newExpr;`. Used for the branches of
    // lowered &&/|| and ?: operators.
    TIntermBlock &buildBlockWithTailAssign(const TVariable &var, TIntermTyped &newExpr)
    {
        std::vector<TIntermNode *> stmts = std::move(getCurrStmts());
        popStacks();

        auto &block = *new TIntermBlock();
        auto &seq = *block.getSequence();
        seq.reserve(1 + stmts.size());
        pushStmtsIntoBlock(block, stmts);
        seq.push_back(new TIntermBinary(TOperator::EOpAssign, new TIntermSymbol(&var), &newExpr));

        return block;
    }

  private:
    // Each block collects its separated statements in its own level.
    PreResult visitBlockPre(TIntermBlock &node) override
    {
        pushStacks();
        return node;
    }

    // Replaces the block's contents with the separated statements. A block
    // nested directly in another block is itself emitted as a statement of the
    // enclosing level; blocks under other parents (if/loop/switch/function) are
    // referenced by that parent instead.
    PostResult visitBlockPost(TIntermBlock &node) override
    {
        std::vector<TIntermNode *> stmts = std::move(getCurrStmts());
        popStacks();

        TIntermSequence &seq = *node.getSequence();
        seq.clear();
        seq.reserve(stmts.size());
        pushStmtsIntoBlock(node, stmts);

        TIntermNode *parent = getParentNode();
        if (parent && parent->getAsBlock())
        {
            pushStmt(node);
        }

        return node;
    }

    // Declarations without an initializer, or whose initializer is a symbol or
    // constant, are emitted unchanged and their children are not visited.
    PreResult visitDeclarationPre(TIntermDeclaration &node) override
    {
        Declaration decl = ViewDeclaration(node);
        if (!decl.initExpr || isTerminalExpr(*decl.initExpr))
        {
            pushStmt(node);
            return {node, VisitBits::Neither};
        }
        return node;
    }

    // Emits the declaration, replacing its initializer with the mapped (and,
    // where possible, backtracked) expression. The EOpInitialize binary inside
    // the declaration is deliberately skipped by visitBinaryPost.
    PostResult visitDeclarationPost(TIntermDeclaration &node) override
    {
        Declaration decl = ViewDeclaration(node);
        ASSERT(decl.symbol.variable().symbolType() != SymbolType::Empty);
        ASSERT(!decl.symbol.variable().getType().isStructSpecifier());

        TIntermTyped *newInitExpr = pullMappedExpr(decl.initExpr, true);
        if (decl.initExpr == newInitExpr)
        {
            pushStmt(node);
        }
        else
        {
            auto &newNode = *new TIntermDeclaration();
            newNode.appendDeclarator(
                new TIntermBinary(TOperator::EOpInitialize, &decl.symbol, newInitExpr));
            pushStmt(newNode);
        }
        return node;
    }

    // Rebuilds the unary node around its mapped operand and binds the result.
    PostResult visitUnaryPost(TIntermUnary &node) override
    {
        TIntermTyped *expr = node.getOperand();
        TIntermTyped *newExpr = pullMappedExpr(expr, false);
        if (expr == newExpr)
        {
            pushBinding(node, node);
        }
        else
        {
            pushBinding(node, *new TIntermUnary(node.getOp(), newExpr, node.getFunction()));
        }
        return node;
    }

    // For && and ||, the left operand is separated at the current level and
    // the right operand in a new level, so that the statements generated for
    // the right operand can be placed inside the conditional branch. The
    // children are rebuilt manually here, so only the post-visit is requested.
    PreResult visitBinaryPre(TIntermBinary &node) override
    {
        const TOperator op = node.getOp();
        if (op == TOperator::EOpLogicalAnd || op == TOperator::EOpLogicalOr)
        {
            TIntermTyped *left = node.getLeft();
            TIntermTyped *right = node.getRight();

            PostResult leftResult = rebuild(*left);
            ASSERT(leftResult.single());

            // Closed again by buildBlockWithTailAssign in visitBinaryPost.
            pushStacks();
            PostResult rightResult = rebuild(*right);
            ASSERT(rightResult.single());

            return {node, VisitBits::Post};
        }

        return node;
    }

    PostResult visitBinaryPost(TIntermBinary &node) override
    {
        const TOperator op = node.getOp();
        if (op == TOperator::EOpInitialize && getParentNode()->getAsDeclarationNode())
        {
            // Special case is handled by visitDeclarationPost
            return node;
        }

        TIntermTyped *left = node.getLeft();
        TIntermTyped *right = node.getRight();

        if (op == TOperator::EOpLogicalAnd || op == TOperator::EOpLogicalOr)
        {
            // Lowered to:
            //   bool t = left;
            //   if (t)  { <right stmts>; t = right; }   // for &&
            //   if (!t) { <right stmts>; t = right; }   // for ||
            // and the node is replaced by `t` unless it is standalone.
            const Name name = mIdGen.createNewName();
            auto *var = new TVariable(&mSymbolTable, name.rawName(), new TType(TBasicType::EbtBool),
                                      name.symbolType());

            // The right operand's level is innermost, so it is pulled and
            // closed first; the left operand's bindings live one level out.
            TIntermTyped *newRight = pullMappedExpr(right, true);
            TIntermBlock *rightBlock = &buildBlockWithTailAssign(*var, *newRight);
            TIntermTyped *newLeft = pullMappedExpr(left, true);

            TIntermTyped *cond = new TIntermSymbol(var);
            if (op == TOperator::EOpLogicalOr)
            {
                cond = new TIntermUnary(TOperator::EOpLogicalNot, cond, nullptr);
            }

            pushStmt(*new TIntermDeclaration(var, newLeft));
            pushStmt(*new TIntermIfElse(cond, rightBlock, nullptr));
            if (!isStandaloneExpr(node))
            {
                mExprMap[&node] = new TIntermSymbol(var);
            }

            return node;
        }

        // The left operand is never backtracked. The right operand is
        // backtracked only for plain (non-compound) assignments, so e.g.
        // `x = a + b` keeps `a + b` inline instead of going through a temporary.
        const bool isAssign = IsAssignment(op);
        const bool isCompoundAssign = IsCompoundAssignment(op);
        TIntermTyped *newLeft = pullMappedExpr(left, false);
        TIntermTyped *newRight = pullMappedExpr(right, isAssign && !isCompoundAssign);
        if (op == TOperator::EOpComma)
        {
            // The comma's value is its right operand; the left operand is not
            // re-emitted here.
            pushBinding(node, *newRight);
            return node;
        }
        else
        {
            TIntermBinary *newNode;
            if (left == newLeft && right == newRight)
            {
                newNode = &node;
            }
            else
            {
                newNode = new TIntermBinary(op, newLeft, newRight);
            }
            pushBinding(node, *newNode);
            return node;
        }
    }

    // The condition is separated at the current level; the true and false
    // expressions each get their own level (false innermost). Children are
    // rebuilt manually, so only the post-visit is requested.
    PreResult visitTernaryPre(TIntermTernary &node) override
    {
        PostResult condResult = rebuild(*node.getCondition());
        ASSERT(condResult.single());

        pushStacks();
        PostResult thenResult = rebuild(*node.getTrueExpression());
        ASSERT(thenResult.single());

        pushStacks();
        PostResult elseResult = rebuild(*node.getFalseExpression());
        ASSERT(elseResult.single());

        return {node, VisitBits::Post};
    }

    // Lowered to:
    //   T t;
    //   if (cond) { <then stmts>; t = then; } else { <else stmts>; t = else; }
    // and the node is replaced by `t` unless it is standalone. The levels are
    // closed innermost first: else, then then.
    PostResult visitTernaryPost(TIntermTernary &node) override
    {
        TIntermTyped *cond = node.getCondition();
        TIntermTyped *then = node.getTrueExpression();
        TIntermTyped *else_ = node.getFalseExpression();

        const Name name = mIdGen.createNewName();
        // The temporary copies the ternary's type with any interface-block
        // association cleared.
        TType *newType = new TType(node.getType());
        newType->setInterfaceBlock(nullptr);
        auto *var = new TVariable(&mSymbolTable, name.rawName(), newType, name.symbolType());

        // NOTE(review): the false branch is pulled without backtracking while
        // the true branch is pulled with it; confirm the asymmetry is intended.
        TIntermTyped *newElse = pullMappedExpr(else_, false);
        TIntermBlock *elseBlock = &buildBlockWithTailAssign(*var, *newElse);
        TIntermTyped *newThen = pullMappedExpr(then, true);
        TIntermBlock *thenBlock = &buildBlockWithTailAssign(*var, *newThen);
        TIntermTyped *newCond = pullMappedExpr(cond, true);

        pushStmt(*new TIntermDeclaration{var});
        pushStmt(*new TIntermIfElse(newCond, thenBlock, elseBlock));
        if (!isStandaloneExpr(node))
        {
            mExprMap[&node] = new TIntermSymbol(var);
        }

        return node;
    }

    // Rebuilds the swizzle around its mapped operand and binds the result.
    PostResult visitSwizzlePost(TIntermSwizzle &node) override
    {
        TIntermTyped *expr = node.getOperand();
        TIntermTyped *newExpr = pullMappedExpr(expr, false);
        if (expr == newExpr)
        {
            pushBinding(node, node);
        }
        else
        {
            pushBinding(node, *new TIntermSwizzle(newExpr, node.getSwizzleOffsets()));
        }
        return node;
    }

    // Replaces each argument in place by its mapped form (without
    // backtracking), then binds the aggregate itself.
    PostResult visitAggregatePost(TIntermAggregate &node) override
    {
        TIntermSequence &args = *node.getSequence();
        for (TIntermNode *&arg : args)
        {
            TIntermTyped *targ = arg->getAsTyped();
            ASSERT(targ);
            arg = pullMappedExpr(targ, false);
        }
        pushBinding(node, node);
        return node;
    }

    PostResult visitPreprocessorDirectivePost(TIntermPreprocessorDirective &node) override
    {
        pushStmt(node);
        return node;
    }

    // Only prototypes outside a function definition are emitted; a definition's
    // own prototype is emitted as part of the definition.
    PostResult visitFunctionPrototypePost(TIntermFunctionPrototype &node) override
    {
        if (!getParentFunction())
        {
            pushStmt(node);
        }
        return node;
    }

    // Case labels are emitted unchanged; their conditions are expected to be
    // symbols or constants.
    PreResult visitCasePre(TIntermCase &node) override
    {
        if (TIntermTyped *cond = node.getCondition())
        {
            ASSERT(isTerminalExpr(*cond));
        }
        pushStmt(node);
        return {node, VisitBits::Neither};
    }

    // Emits the switch with its init expression replaced by the mapped form.
    PostResult visitSwitchPost(TIntermSwitch &node) override
    {
        TIntermTyped *init = node.getInit();
        TIntermTyped *newInit = pullMappedExpr(init, false);
        if (init == newInit)
        {
            pushStmt(node);
        }
        else
        {
            pushStmt(*new TIntermSwitch(newInit, node.getStatementList()));
        }

        return node;
    }

    PostResult visitFunctionDefinitionPost(TIntermFunctionDefinition &node) override
    {
        pushStmt(node);
        return node;
    }

    // Emits the if-statement with its condition replaced by the mapped form;
    // the branch blocks were already separated by visitBlockPost.
    PostResult visitIfElsePost(TIntermIfElse &node) override
    {
        TIntermTyped *cond = node.getCondition();
        TIntermTyped *newCond = pullMappedExpr(cond, false);
        if (cond == newCond)
        {
            pushStmt(node);
        }
        else
        {
            pushStmt(*new TIntermIfElse(newCond, node.getTrueBlock(), node.getFalseBlock()));
        }
        return node;
    }

    // Emits the branch with its (possibly null) expression replaced by the
    // mapped form.
    PostResult visitBranchPost(TIntermBranch &node) override
    {
        TIntermTyped *expr = node.getExpression();
        TIntermTyped *newExpr = pullMappedExpr(expr, false);
        if (expr == newExpr)
        {
            pushStmt(node);
        }
        else
        {
            pushStmt(*new TIntermBranch(node.getFlowOp(), newExpr));
        }
        return node;
    }

    // Only the loop body is separated (in place); the loop's other children
    // are not visited. This presumably relies on SimplifyLoopConditions, which
    // SeparateCompoundExpressions runs before this pass -- NOTE(review): confirm.
    PreResult visitLoopPre(TIntermLoop &node) override
    {
        if (!rebuildInPlace(*node.getBody()))
        {
            UNREACHABLE();
        }
        pushStmt(node);
        return {node, VisitBits::Neither};
    }

    // Non-scalar constants are bound like other expressions; scalar constants
    // are left in place.
    PostResult visitConstantUnionPost(TIntermConstantUnion &node) override
    {
        const TType &type = node.getType();
        if (!type.isScalar())
        {
            pushBinding(node, node);
        }
        return node;
    }

    PostResult visitGlobalQualifierDeclarationPost(TIntermGlobalQualifierDeclaration &node) override
    {
        // With the removal of RewriteGlobalQualifierDecls, we may encounter globals while
        // separating compound expressions.
        pushStmt(node);
        return node;
    }
};
653
654 } // anonymous namespace
655
656 ////////////////////////////////////////////////////////////////////////////////
657
SeparateCompoundExpressions(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen,TIntermBlock & root)658 bool sh::SeparateCompoundExpressions(TCompiler &compiler,
659 SymbolEnv &symbolEnv,
660 IdGen &idGen,
661 TIntermBlock &root)
662 {
663 if (angle::GetBoolEnvironmentVar("GMT_DISABLE_SEPARATE_COMPOUND_EXPRESSIONS"))
664 {
665 return true;
666 }
667
668 if (!SimplifyLoopConditions(&compiler, &root, &compiler.getSymbolTable()))
669 {
670 return false;
671 }
672
673 if (!PrePass(compiler).rebuildRoot(root))
674 {
675 return false;
676 }
677
678 if (!Separator(compiler, symbolEnv, idGen).rebuildRoot(root))
679 {
680 return false;
681 }
682
683 return true;
684 }
685