1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "common/system_utils.h"
#include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
#include "compiler/translator/TranslatorMetalDirect/IntermRebuild.h"
#include "compiler/translator/TranslatorMetalDirect/SeparateCompoundExpressions.h"
#include "compiler/translator/tree_ops/SimplifyLoopConditions.h"
14
15 using namespace sh;
16
17 ////////////////////////////////////////////////////////////////////////////////
18
19 namespace
20 {
21
IsIndex(TOperator op)22 bool IsIndex(TOperator op)
23 {
24 switch (op)
25 {
26 case TOperator::EOpIndexDirect:
27 case TOperator::EOpIndexDirectInterfaceBlock:
28 case TOperator::EOpIndexDirectStruct:
29 case TOperator::EOpIndexIndirect:
30 return true;
31 default:
32 return false;
33 }
34 }
35
IsIndex(TIntermTyped & expr)36 bool IsIndex(TIntermTyped &expr)
37 {
38 if (auto *binary = expr.getAsBinaryNode())
39 {
40 return IsIndex(binary->getOp());
41 }
42 return expr.getAsSwizzleNode();
43 }
44
ViewBinaryChain(TOperator op,TIntermTyped & node,std::vector<TIntermTyped * > & out)45 bool ViewBinaryChain(TOperator op, TIntermTyped &node, std::vector<TIntermTyped *> &out)
46 {
47 TIntermBinary *binary = node.getAsBinaryNode();
48 if (!binary || binary->getOp() != op)
49 {
50 return false;
51 }
52
53 TIntermTyped *left = binary->getLeft();
54 TIntermTyped *right = binary->getRight();
55
56 if (!ViewBinaryChain(op, *left, out))
57 {
58 out.push_back(left);
59 }
60
61 if (!ViewBinaryChain(op, *right, out))
62 {
63 out.push_back(right);
64 }
65
66 return true;
67 }
68
ViewBinaryChain(TIntermBinary & node)69 std::vector<TIntermTyped *> ViewBinaryChain(TIntermBinary &node)
70 {
71 std::vector<TIntermTyped *> chain;
72 ViewBinaryChain(node.getOp(), node, chain);
73 ASSERT(chain.size() >= 2);
74 return chain;
75 }
76
77 class PrePass : public TIntermRebuild
78 {
79 public:
PrePass(TCompiler & compiler)80 PrePass(TCompiler &compiler) : TIntermRebuild(compiler, true, true) {}
81
82 private:
83 // Change chains of
84 // x OP y OP z
85 // to
86 // x OP (y OP z)
87 // regardless of original parenthesization.
reassociateRight(TIntermBinary & node)88 TIntermTyped &reassociateRight(TIntermBinary &node)
89 {
90 const TOperator op = node.getOp();
91 std::vector<TIntermTyped *> chain = ViewBinaryChain(node);
92
93 TIntermTyped *result = chain.back();
94 chain.pop_back();
95 ASSERT(result);
96
97 const auto begin = chain.rbegin();
98 const auto end = chain.rend();
99
100 for (auto iter = begin; iter != end; ++iter)
101 {
102 TIntermTyped *part = *iter;
103 ASSERT(part);
104 TIntermNode *temp = rebuild(*part).single();
105 ASSERT(temp);
106 part = temp->getAsTyped();
107 ASSERT(part);
108 result = new TIntermBinary(op, part, result);
109 }
110 return *result;
111 }
112
113 private:
visitBinaryPre(TIntermBinary & node)114 PreResult visitBinaryPre(TIntermBinary &node) override
115 {
116 const TOperator op = node.getOp();
117 if (op == TOperator::EOpLogicalAnd || op == TOperator::EOpLogicalOr)
118 {
119 return {reassociateRight(node), VisitBits::Neither};
120 }
121 return node;
122 }
123 };
124
125 class Separator : public TIntermRebuild
126 {
127 IdGen &mIdGen;
128 std::vector<std::vector<TIntermNode *>> mStmtsStack;
129 std::vector<std::unordered_map<const TVariable *, TIntermDeclaration *>> mBindingMapStack;
130 std::unordered_map<TIntermTyped *, TIntermTyped *> mExprMap;
131 std::unordered_set<TIntermDeclaration *> mMaskedDecls;
132
133 public:
Separator(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen)134 Separator(TCompiler &compiler, SymbolEnv &symbolEnv, IdGen &idGen)
135 : TIntermRebuild(compiler, true, true), mIdGen(idGen)
136 {}
137
~Separator()138 ~Separator() override
139 {
140 ASSERT(mStmtsStack.empty());
141 ASSERT(mExprMap.empty());
142 ASSERT(mBindingMapStack.empty());
143 }
144
145 private:
getCurrStmts()146 std::vector<TIntermNode *> &getCurrStmts()
147 {
148 ASSERT(!mStmtsStack.empty());
149 return mStmtsStack.back();
150 }
151
getCurrBindingMap()152 std::unordered_map<const TVariable *, TIntermDeclaration *> &getCurrBindingMap()
153 {
154 ASSERT(!mBindingMapStack.empty());
155 return mBindingMapStack.back();
156 }
157
pushStmt(TIntermNode & node)158 void pushStmt(TIntermNode &node) { getCurrStmts().push_back(&node); }
159
isTerminalExpr(TIntermNode & node)160 bool isTerminalExpr(TIntermNode &node)
161 {
162 NodeType nodeType = getNodeType(node);
163 switch (nodeType)
164 {
165 case NodeType::Symbol:
166 case NodeType::ConstantUnion:
167 return true;
168 default:
169 return false;
170 }
171 }
172
pullMappedExpr(TIntermTyped * node,bool allowBacktrack)173 TIntermTyped *pullMappedExpr(TIntermTyped *node, bool allowBacktrack)
174 {
175 TIntermTyped *expr;
176
177 {
178 auto iter = mExprMap.find(node);
179 if (iter == mExprMap.end())
180 {
181 return node;
182 }
183 ASSERT(node);
184 expr = iter->second;
185 ASSERT(expr);
186 mExprMap.erase(iter);
187 }
188
189 if (allowBacktrack)
190 {
191 auto &bindingMap = getCurrBindingMap();
192 while (TIntermSymbol *symbol = expr->getAsSymbolNode())
193 {
194 const TVariable &var = symbol->variable();
195 auto iter = bindingMap.find(&var);
196 if (iter == bindingMap.end())
197 {
198 return expr;
199 }
200 ASSERT(var.symbolType() == SymbolType::AngleInternal);
201 TIntermDeclaration *decl = iter->second;
202 ASSERT(decl);
203 expr = ViewDeclaration(*decl).initExpr;
204 ASSERT(expr);
205 bindingMap.erase(iter);
206 mMaskedDecls.insert(decl);
207 }
208 }
209
210 return expr;
211 }
212
isStandaloneExpr(TIntermTyped & expr)213 bool isStandaloneExpr(TIntermTyped &expr)
214 {
215 if (getParentNode()->getAsBlock())
216 {
217 return true;
218 }
219 ASSERT(expr.getType().getBasicType() != TBasicType::EbtVoid);
220 return false;
221 }
222
pushBinding(TIntermTyped & oldExpr,TIntermTyped & newExpr)223 void pushBinding(TIntermTyped &oldExpr, TIntermTyped &newExpr)
224 {
225 if (isStandaloneExpr(newExpr))
226 {
227 pushStmt(newExpr);
228 return;
229 }
230 if (IsIndex(newExpr))
231 {
232 mExprMap[&oldExpr] = &newExpr;
233 return;
234 }
235 auto &bindingMap = getCurrBindingMap();
236 const Name name = mIdGen.createNewName("");
237 auto *var =
238 new TVariable(&mSymbolTable, name.rawName(), &newExpr.getType(), name.symbolType());
239 auto *decl = new TIntermDeclaration(var, &newExpr);
240 pushStmt(*decl);
241 mExprMap[&oldExpr] = new TIntermSymbol(var);
242 bindingMap[var] = decl;
243 }
244
pushStacks()245 void pushStacks()
246 {
247 mStmtsStack.emplace_back();
248 mBindingMapStack.emplace_back();
249 }
250
popStacks()251 void popStacks()
252 {
253 ASSERT(!mBindingMapStack.empty());
254 ASSERT(!mStmtsStack.empty());
255 ASSERT(mStmtsStack.back().empty());
256 mBindingMapStack.pop_back();
257 mStmtsStack.pop_back();
258 }
259
pushStmtsIntoBlock(TIntermBlock & block,std::vector<TIntermNode * > & stmts)260 void pushStmtsIntoBlock(TIntermBlock &block, std::vector<TIntermNode *> &stmts)
261 {
262 TIntermSequence &seq = *block.getSequence();
263 for (TIntermNode *stmt : stmts)
264 {
265 if (TIntermDeclaration *decl = stmt->getAsDeclarationNode())
266 {
267 auto iter = mMaskedDecls.find(decl);
268 if (iter != mMaskedDecls.end())
269 {
270 mMaskedDecls.erase(iter);
271 continue;
272 }
273 }
274 seq.push_back(stmt);
275 }
276 }
277
buildBlockWithTailAssign(const TVariable & var,TIntermTyped & newExpr)278 TIntermBlock &buildBlockWithTailAssign(const TVariable &var, TIntermTyped &newExpr)
279 {
280 std::vector<TIntermNode *> stmts = std::move(getCurrStmts());
281 popStacks();
282
283 auto &block = *new TIntermBlock();
284 auto &seq = *block.getSequence();
285 seq.reserve(1 + stmts.size());
286 pushStmtsIntoBlock(block, stmts);
287 seq.push_back(new TIntermBinary(TOperator::EOpAssign, new TIntermSymbol(&var), &newExpr));
288
289 return block;
290 }
291
292 private:
visitBlockPre(TIntermBlock & node)293 PreResult visitBlockPre(TIntermBlock &node) override
294 {
295 pushStacks();
296 return node;
297 }
298
visitBlockPost(TIntermBlock & node)299 PostResult visitBlockPost(TIntermBlock &node) override
300 {
301 std::vector<TIntermNode *> stmts = std::move(getCurrStmts());
302 popStacks();
303
304 TIntermSequence &seq = *node.getSequence();
305 seq.clear();
306 seq.reserve(stmts.size());
307 pushStmtsIntoBlock(node, stmts);
308
309 TIntermNode *parent = getParentNode();
310 if (parent && parent->getAsBlock())
311 {
312 pushStmt(node);
313 }
314
315 return node;
316 }
317
visitDeclarationPre(TIntermDeclaration & node)318 PreResult visitDeclarationPre(TIntermDeclaration &node) override
319 {
320 Declaration decl = ViewDeclaration(node);
321 if (!decl.initExpr || isTerminalExpr(*decl.initExpr))
322 {
323 pushStmt(node);
324 return {node, VisitBits::Neither};
325 }
326 return node;
327 }
328
visitDeclarationPost(TIntermDeclaration & node)329 PostResult visitDeclarationPost(TIntermDeclaration &node) override
330 {
331 Declaration decl = ViewDeclaration(node);
332 ASSERT(decl.symbol.variable().symbolType() != SymbolType::Empty);
333 ASSERT(!decl.symbol.variable().getType().isStructSpecifier());
334
335 TIntermTyped *newInitExpr = pullMappedExpr(decl.initExpr, true);
336 if (decl.initExpr == newInitExpr)
337 {
338 pushStmt(node);
339 }
340 else
341 {
342 auto &newNode = *new TIntermDeclaration();
343 newNode.appendDeclarator(
344 new TIntermBinary(TOperator::EOpInitialize, &decl.symbol, newInitExpr));
345 pushStmt(newNode);
346 }
347 return node;
348 }
349
visitUnaryPost(TIntermUnary & node)350 PostResult visitUnaryPost(TIntermUnary &node) override
351 {
352 TIntermTyped *expr = node.getOperand();
353 TIntermTyped *newExpr = pullMappedExpr(expr, false);
354 if (expr == newExpr)
355 {
356 pushBinding(node, node);
357 }
358 else
359 {
360 pushBinding(node, *new TIntermUnary(node.getOp(), newExpr, node.getFunction()));
361 }
362 return node;
363 }
364
visitBinaryPre(TIntermBinary & node)365 PreResult visitBinaryPre(TIntermBinary &node) override
366 {
367 const TOperator op = node.getOp();
368 if (op == TOperator::EOpLogicalAnd || op == TOperator::EOpLogicalOr)
369 {
370 TIntermTyped *left = node.getLeft();
371 TIntermTyped *right = node.getRight();
372
373 PostResult leftResult = rebuild(*left);
374 ASSERT(leftResult.single());
375
376 pushStacks();
377 PostResult rightResult = rebuild(*right);
378 ASSERT(rightResult.single());
379
380 return {node, VisitBits::Post};
381 }
382
383 return node;
384 }
385
visitBinaryPost(TIntermBinary & node)386 PostResult visitBinaryPost(TIntermBinary &node) override
387 {
388 const TOperator op = node.getOp();
389 if (op == TOperator::EOpInitialize && getParentNode()->getAsDeclarationNode())
390 {
391 // Special case is handled by visitDeclarationPost
392 return node;
393 }
394
395 TIntermTyped *left = node.getLeft();
396 TIntermTyped *right = node.getRight();
397
398 if (op == TOperator::EOpLogicalAnd || op == TOperator::EOpLogicalOr)
399 {
400 const Name name = mIdGen.createNewName("");
401 auto *var = new TVariable(&mSymbolTable, name.rawName(), new TType(TBasicType::EbtBool),
402 name.symbolType());
403
404 TIntermTyped *newRight = pullMappedExpr(right, true);
405 TIntermBlock *rightBlock = &buildBlockWithTailAssign(*var, *newRight);
406 TIntermTyped *newLeft = pullMappedExpr(left, true);
407
408 TIntermTyped *cond = new TIntermSymbol(var);
409 if (op == TOperator::EOpLogicalOr)
410 {
411 cond = new TIntermUnary(TOperator::EOpLogicalNot, cond, nullptr);
412 }
413
414 pushStmt(*new TIntermDeclaration(var, newLeft));
415 pushStmt(*new TIntermIfElse(cond, rightBlock, nullptr));
416 if (!isStandaloneExpr(node))
417 {
418 mExprMap[&node] = new TIntermSymbol(var);
419 }
420
421 return node;
422 }
423
424 const bool isAssign = IsAssignment(op);
425 TIntermTyped *newLeft = pullMappedExpr(left, false);
426 TIntermTyped *newRight = pullMappedExpr(right, isAssign);
427
428 if (op == TOperator::EOpComma)
429 {
430 pushBinding(node, *newRight);
431 return node;
432 }
433 else
434 {
435 TIntermBinary *newNode;
436 if (left == newLeft && right == newRight)
437 {
438 newNode = &node;
439 }
440 else
441 {
442 newNode = new TIntermBinary(op, newLeft, newRight);
443 }
444 pushBinding(node, *newNode);
445 return node;
446 }
447 }
448
visitTernaryPre(TIntermTernary & node)449 PreResult visitTernaryPre(TIntermTernary &node) override
450 {
451 PostResult condResult = rebuild(*node.getCondition());
452 ASSERT(condResult.single());
453
454 pushStacks();
455 PostResult thenResult = rebuild(*node.getTrueExpression());
456 ASSERT(thenResult.single());
457
458 pushStacks();
459 PostResult elseResult = rebuild(*node.getFalseExpression());
460 ASSERT(elseResult.single());
461
462 return {node, VisitBits::Post};
463 }
464
visitTernaryPost(TIntermTernary & node)465 PostResult visitTernaryPost(TIntermTernary &node) override
466 {
467 TIntermTyped *cond = node.getCondition();
468 TIntermTyped *then = node.getTrueExpression();
469 TIntermTyped *else_ = node.getFalseExpression();
470
471 const Name name = mIdGen.createNewName("");
472 auto *var =
473 new TVariable(&mSymbolTable, name.rawName(), &node.getType(), name.symbolType());
474
475 TIntermTyped *newElse = pullMappedExpr(else_, false);
476 TIntermBlock *elseBlock = &buildBlockWithTailAssign(*var, *newElse);
477 TIntermTyped *newThen = pullMappedExpr(then, true);
478 TIntermBlock *thenBlock = &buildBlockWithTailAssign(*var, *newThen);
479 TIntermTyped *newCond = pullMappedExpr(cond, true);
480
481 pushStmt(*new TIntermDeclaration{var});
482 pushStmt(*new TIntermIfElse(newCond, thenBlock, elseBlock));
483 if (!isStandaloneExpr(node))
484 {
485 mExprMap[&node] = new TIntermSymbol(var);
486 }
487
488 return node;
489 }
490
visitSwizzlePost(TIntermSwizzle & node)491 PostResult visitSwizzlePost(TIntermSwizzle &node) override
492 {
493 TIntermTyped *expr = node.getOperand();
494 TIntermTyped *newExpr = pullMappedExpr(expr, false);
495 if (expr == newExpr)
496 {
497 pushBinding(node, node);
498 }
499 else
500 {
501 pushBinding(node, *new TIntermSwizzle(newExpr, node.getSwizzleOffsets()));
502 }
503 return node;
504 }
505
visitAggregatePost(TIntermAggregate & node)506 PostResult visitAggregatePost(TIntermAggregate &node) override
507 {
508 TIntermSequence &args = *node.getSequence();
509 for (TIntermNode *&arg : args)
510 {
511 TIntermTyped *targ = arg->getAsTyped();
512 ASSERT(targ);
513 arg = pullMappedExpr(targ, false);
514 }
515 pushBinding(node, node);
516 return node;
517 }
518
visitPreprocessorDirectivePost(TIntermPreprocessorDirective & node)519 PostResult visitPreprocessorDirectivePost(TIntermPreprocessorDirective &node) override
520 {
521 pushStmt(node);
522 return node;
523 }
524
visitFunctionPrototypePost(TIntermFunctionPrototype & node)525 PostResult visitFunctionPrototypePost(TIntermFunctionPrototype &node) override
526 {
527 if (!getParentFunction())
528 {
529 pushStmt(node);
530 }
531 return node;
532 }
533
visitCasePre(TIntermCase & node)534 PreResult visitCasePre(TIntermCase &node) override
535 {
536 if (TIntermTyped *cond = node.getCondition())
537 {
538 ASSERT(isTerminalExpr(*cond));
539 }
540 pushStmt(node);
541 return {node, VisitBits::Neither};
542 }
543
visitSwitchPost(TIntermSwitch & node)544 PostResult visitSwitchPost(TIntermSwitch &node) override
545 {
546 TIntermTyped *init = node.getInit();
547 TIntermTyped *newInit = pullMappedExpr(init, false);
548 if (init == newInit)
549 {
550 pushStmt(node);
551 }
552 else
553 {
554 pushStmt(*new TIntermSwitch(newInit, node.getStatementList()));
555 }
556
557 return node;
558 }
559
visitFunctionDefinitionPost(TIntermFunctionDefinition & node)560 PostResult visitFunctionDefinitionPost(TIntermFunctionDefinition &node) override
561 {
562 pushStmt(node);
563 return node;
564 }
565
visitIfElsePost(TIntermIfElse & node)566 PostResult visitIfElsePost(TIntermIfElse &node) override
567 {
568 TIntermTyped *cond = node.getCondition();
569 TIntermTyped *newCond = pullMappedExpr(cond, false);
570 if (cond == newCond)
571 {
572 pushStmt(node);
573 }
574 else
575 {
576 pushStmt(*new TIntermIfElse(newCond, node.getTrueBlock(), node.getFalseBlock()));
577 }
578 return node;
579 }
580
visitBranchPost(TIntermBranch & node)581 PostResult visitBranchPost(TIntermBranch &node) override
582 {
583 TIntermTyped *expr = node.getExpression();
584 TIntermTyped *newExpr = pullMappedExpr(expr, false);
585 if (expr == newExpr)
586 {
587 pushStmt(node);
588 }
589 else
590 {
591 pushStmt(*new TIntermBranch(node.getFlowOp(), newExpr));
592 }
593 return node;
594 }
595
visitLoopPre(TIntermLoop & node)596 PreResult visitLoopPre(TIntermLoop &node) override
597 {
598 if (!rebuildInPlace(*node.getBody()))
599 {
600 UNREACHABLE();
601 }
602 pushStmt(node);
603 return {node, VisitBits::Neither};
604 }
605
visitConstantUnionPost(TIntermConstantUnion & node)606 PostResult visitConstantUnionPost(TIntermConstantUnion &node) override
607 {
608 const TType &type = node.getType();
609 if (!type.isScalar())
610 {
611 pushBinding(node, node);
612 }
613 return node;
614 }
615
visitGlobalQualifierDeclarationPost(TIntermGlobalQualifierDeclaration & node)616 PostResult visitGlobalQualifierDeclarationPost(TIntermGlobalQualifierDeclaration &node) override
617 {
618 ASSERT(false); // These should be scrubbed from AST before rewriter is called.
619 pushStmt(node);
620 return node;
621 }
622 };
623
624 } // anonymous namespace
625
626 ////////////////////////////////////////////////////////////////////////////////
627
SeparateCompoundExpressions(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen,TIntermBlock & root)628 bool sh::SeparateCompoundExpressions(TCompiler &compiler,
629 SymbolEnv &symbolEnv,
630 IdGen &idGen,
631 TIntermBlock &root)
632 {
633 if (angle::GetBoolEnvironmentVar("GMT_DISABLE_SEPARATE_COMPOUND_EXPRESSIONS"))
634 {
635 return true;
636 }
637
638 if (!SimplifyLoopConditions(&compiler, &root, &compiler.getSymbolTable()))
639 {
640 return false;
641 }
642
643 if (!PrePass(compiler).rebuildRoot(root))
644 {
645 return false;
646 }
647
648 if (!Separator(compiler, symbolEnv, idGen).rebuildRoot(root))
649 {
650 return false;
651 }
652
653 return true;
654 }
655