• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2016 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // SimplifyLoopConditions is an AST traverser that converts loop conditions and loop expressions
7 // to regular statements inside the loop. This way further transformations that generate statements
8 // from loop conditions and loop expressions work correctly.
9 //
10 
11 #include "compiler/translator/tree_ops/SimplifyLoopConditions.h"
12 
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/tree_util/IntermNodePatternMatcher.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 
18 namespace sh
19 {
20 
21 namespace
22 {
23 
24 struct LoopInfo
25 {
26     const TVariable *conditionVariable = nullptr;
27     TIntermTyped *condition            = nullptr;
28     TIntermTyped *expression           = nullptr;
29 };
30 
31 class SimplifyLoopConditionsTraverser : public TLValueTrackingTraverser
32 {
33   public:
34     SimplifyLoopConditionsTraverser(const IntermNodePatternMatcher *conditionsToSimplify,
35                                     TSymbolTable *symbolTable);
36 
37     void traverseLoop(TIntermLoop *node) override;
38 
39     bool visitUnary(Visit visit, TIntermUnary *node) override;
40     bool visitBinary(Visit visit, TIntermBinary *node) override;
41     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
42     bool visitTernary(Visit visit, TIntermTernary *node) override;
43     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
44     bool visitBranch(Visit visit, TIntermBranch *node) override;
45 
foundLoopToChange() const46     bool foundLoopToChange() const { return mFoundLoopToChange; }
47 
48   protected:
49     // Marked to true once an operation that needs to be hoisted out of a loop expression has been
50     // found.
51     bool mFoundLoopToChange;
52     bool mInsideLoopInitConditionOrExpression;
53     const IntermNodePatternMatcher *mConditionsToSimplify;
54 
55   private:
56     LoopInfo mLoop;
57 };
58 
SimplifyLoopConditionsTraverser(const IntermNodePatternMatcher * conditionsToSimplify,TSymbolTable * symbolTable)59 SimplifyLoopConditionsTraverser::SimplifyLoopConditionsTraverser(
60     const IntermNodePatternMatcher *conditionsToSimplify,
61     TSymbolTable *symbolTable)
62     : TLValueTrackingTraverser(true, false, false, symbolTable),
63       mFoundLoopToChange(false),
64       mInsideLoopInitConditionOrExpression(false),
65       mConditionsToSimplify(conditionsToSimplify)
66 {}
67 
68 // If we're inside a loop initialization, condition, or expression, we check for expressions that
69 // should be moved out of the loop condition or expression. If one is found, the loop is
70 // transformed.
71 // If we're not inside loop initialization, condition, or expression, we only need to traverse nodes
72 // that may contain loops.
73 
visitUnary(Visit visit,TIntermUnary * node)74 bool SimplifyLoopConditionsTraverser::visitUnary(Visit visit, TIntermUnary *node)
75 {
76     if (!mInsideLoopInitConditionOrExpression)
77         return false;
78 
79     if (mFoundLoopToChange)
80         return false;  // Already decided to change this loop.
81 
82     ASSERT(mConditionsToSimplify);
83     mFoundLoopToChange = mConditionsToSimplify->match(node);
84     return !mFoundLoopToChange;
85 }
86 
visitBinary(Visit visit,TIntermBinary * node)87 bool SimplifyLoopConditionsTraverser::visitBinary(Visit visit, TIntermBinary *node)
88 {
89     if (!mInsideLoopInitConditionOrExpression)
90         return false;
91 
92     if (mFoundLoopToChange)
93         return false;  // Already decided to change this loop.
94 
95     ASSERT(mConditionsToSimplify);
96     mFoundLoopToChange =
97         mConditionsToSimplify->match(node, getParentNode(), isLValueRequiredHere());
98     return !mFoundLoopToChange;
99 }
100 
visitAggregate(Visit visit,TIntermAggregate * node)101 bool SimplifyLoopConditionsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
102 {
103     if (!mInsideLoopInitConditionOrExpression)
104         return false;
105 
106     if (mFoundLoopToChange)
107         return false;  // Already decided to change this loop.
108 
109     ASSERT(mConditionsToSimplify);
110     mFoundLoopToChange = mConditionsToSimplify->match(node, getParentNode());
111     return !mFoundLoopToChange;
112 }
113 
visitTernary(Visit visit,TIntermTernary * node)114 bool SimplifyLoopConditionsTraverser::visitTernary(Visit visit, TIntermTernary *node)
115 {
116     if (!mInsideLoopInitConditionOrExpression)
117         return false;
118 
119     if (mFoundLoopToChange)
120         return false;  // Already decided to change this loop.
121 
122     ASSERT(mConditionsToSimplify);
123     mFoundLoopToChange = mConditionsToSimplify->match(node);
124     return !mFoundLoopToChange;
125 }
126 
visitDeclaration(Visit visit,TIntermDeclaration * node)127 bool SimplifyLoopConditionsTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
128 {
129     if (!mInsideLoopInitConditionOrExpression)
130         return false;
131 
132     if (mFoundLoopToChange)
133         return false;  // Already decided to change this loop.
134 
135     ASSERT(mConditionsToSimplify);
136     mFoundLoopToChange = mConditionsToSimplify->match(node);
137     return !mFoundLoopToChange;
138 }
139 
visitBranch(Visit visit,TIntermBranch * node)140 bool SimplifyLoopConditionsTraverser::visitBranch(Visit visit, TIntermBranch *node)
141 {
142     if (node->getFlowOp() == EOpContinue && (mLoop.condition || mLoop.expression))
143     {
144         TIntermBlock *parent = getParentNode()->getAsBlock();
145         ASSERT(parent);
146         TIntermSequence seq;
147         if (mLoop.expression)
148         {
149             seq.push_back(mLoop.expression->deepCopy());
150         }
151         if (mLoop.condition)
152         {
153             ASSERT(mLoop.conditionVariable);
154             seq.push_back(
155                 CreateTempAssignmentNode(mLoop.conditionVariable, mLoop.condition->deepCopy()));
156         }
157         seq.push_back(node);
158         mMultiReplacements.push_back(NodeReplaceWithMultipleEntry(parent, node, std::move(seq)));
159     }
160 
161     return true;
162 }
163 
CreateFromBody(TIntermLoop * node,bool * bodyEndsInBranchOut)164 TIntermBlock *CreateFromBody(TIntermLoop *node, bool *bodyEndsInBranchOut)
165 {
166     TIntermBlock *newBody = new TIntermBlock();
167     *bodyEndsInBranchOut  = false;
168 
169     TIntermBlock *nodeBody = node->getBody();
170     if (nodeBody != nullptr)
171     {
172         newBody->getSequence()->push_back(nodeBody);
173         *bodyEndsInBranchOut = EndsInBranch(nodeBody);
174     }
175     return newBody;
176 }
177 
traverseLoop(TIntermLoop * node)178 void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node)
179 {
180     // Mark that we're inside a loop condition or expression, and determine if the loop needs to be
181     // transformed.
182 
183     ScopedNodeInTraversalPath addToPath(this, node);
184 
185     mInsideLoopInitConditionOrExpression = true;
186     mFoundLoopToChange                   = !mConditionsToSimplify;
187 
188     if (!mFoundLoopToChange && node->getInit())
189     {
190         node->getInit()->traverse(this);
191     }
192 
193     if (!mFoundLoopToChange && node->getCondition())
194     {
195         node->getCondition()->traverse(this);
196     }
197 
198     if (!mFoundLoopToChange && node->getExpression())
199     {
200         node->getExpression()->traverse(this);
201     }
202 
203     mInsideLoopInitConditionOrExpression = false;
204 
205     const LoopInfo prevLoop = mLoop;
206 
207     if (mFoundLoopToChange)
208     {
209         const TType *boolType   = StaticType::Get<EbtBool, EbpUndefined, EvqTemporary, 1, 1>();
210         mLoop.conditionVariable = CreateTempVariable(mSymbolTable, boolType);
211         mLoop.condition         = node->getCondition();
212         mLoop.expression        = node->getExpression();
213 
214         // Replace the loop condition with a boolean variable that's updated on each iteration.
215         TLoopType loopType = node->getType();
216         if (loopType == ELoopWhile)
217         {
218             ASSERT(!mLoop.expression);
219 
220             if (mLoop.condition->getAsSymbolNode())
221             {
222                 // Mask continue statement condition variable update.
223                 mLoop.condition = nullptr;
224             }
225             else if (mLoop.condition->getAsConstantUnion())
226             {
227                 // Transform:
228                 //   while (expr) { body; }
229                 // into
230                 //   bool s0 = expr;
231                 //   while (s0) { body; }
232                 TIntermDeclaration *tempInitDeclaration =
233                     CreateTempInitDeclarationNode(mLoop.conditionVariable, mLoop.condition);
234                 insertStatementInParentBlock(tempInitDeclaration);
235 
236                 node->setCondition(CreateTempSymbolNode(mLoop.conditionVariable));
237 
238                 // Mask continue statement condition variable update.
239                 mLoop.condition = nullptr;
240             }
241             else
242             {
243                 // Transform:
244                 //   while (expr) { body; }
245                 // into
246                 //   bool s0 = expr;
247                 //   while (s0) { { body; } s0 = expr; }
248                 //
249                 // Local case statements are transformed into:
250                 //   s0 = expr; continue;
251                 TIntermDeclaration *tempInitDeclaration =
252                     CreateTempInitDeclarationNode(mLoop.conditionVariable, mLoop.condition);
253                 insertStatementInParentBlock(tempInitDeclaration);
254 
255                 bool bodyEndsInBranch;
256                 TIntermBlock *newBody = CreateFromBody(node, &bodyEndsInBranch);
257                 if (!bodyEndsInBranch)
258                 {
259                     newBody->getSequence()->push_back(CreateTempAssignmentNode(
260                         mLoop.conditionVariable, mLoop.condition->deepCopy()));
261                 }
262 
263                 // Can't use queueReplacement to replace old body, since it may have been nullptr.
264                 // It's safe to do the replacements in place here - the new body will still be
265                 // traversed, but that won't create any problems.
266                 node->setBody(newBody);
267                 node->setCondition(CreateTempSymbolNode(mLoop.conditionVariable));
268             }
269         }
270         else if (loopType == ELoopDoWhile)
271         {
272             ASSERT(!mLoop.expression);
273 
274             if (mLoop.condition->getAsSymbolNode())
275             {
276                 // Mask continue statement condition variable update.
277                 mLoop.condition = nullptr;
278             }
279             else if (mLoop.condition->getAsConstantUnion())
280             {
281                 // Transform:
282                 //   do {
283                 //     body;
284                 //   } while (expr);
285                 // into
286                 //   bool s0 = expr;
287                 //   do {
288                 //     body;
289                 //   } while (s0);
290                 TIntermDeclaration *tempInitDeclaration =
291                     CreateTempInitDeclarationNode(mLoop.conditionVariable, mLoop.condition);
292                 insertStatementInParentBlock(tempInitDeclaration);
293 
294                 node->setCondition(CreateTempSymbolNode(mLoop.conditionVariable));
295 
296                 // Mask continue statement condition variable update.
297                 mLoop.condition = nullptr;
298             }
299             else
300             {
301                 // Transform:
302                 //   do {
303                 //     body;
304                 //   } while (expr);
305                 // into
306                 //   bool s0;
307                 //   do {
308                 //     { body; }
309                 //     s0 = expr;
310                 //   } while (s0);
311                 // Local case statements are transformed into:
312                 //   s0 = expr; continue;
313                 TIntermDeclaration *tempInitDeclaration =
314                     CreateTempDeclarationNode(mLoop.conditionVariable);
315                 insertStatementInParentBlock(tempInitDeclaration);
316 
317                 bool bodyEndsInBranch;
318                 TIntermBlock *newBody = CreateFromBody(node, &bodyEndsInBranch);
319                 if (!bodyEndsInBranch)
320                 {
321                     newBody->getSequence()->push_back(
322                         CreateTempAssignmentNode(mLoop.conditionVariable, mLoop.condition));
323                 }
324 
325                 // Can't use queueReplacement to replace old body, since it may have been nullptr.
326                 // It's safe to do the replacements in place here - the new body will still be
327                 // traversed, but that won't create any problems.
328                 node->setBody(newBody);
329                 node->setCondition(CreateTempSymbolNode(mLoop.conditionVariable));
330             }
331         }
332         else if (loopType == ELoopFor)
333         {
334             if (!mLoop.condition)
335             {
336                 mLoop.condition = CreateBoolNode(true);
337             }
338 
339             TIntermLoop *whileLoop;
340             TIntermBlock *loopScope            = new TIntermBlock();
341             TIntermSequence *loopScopeSequence = loopScope->getSequence();
342 
343             // Insert "init;"
344             if (node->getInit())
345             {
346                 loopScopeSequence->push_back(node->getInit());
347             }
348 
349             if (mLoop.condition->getAsSymbolNode())
350             {
351                 // Move the loop condition inside the loop.
352                 // Transform:
353                 //   for (init; expr; exprB) { body; }
354                 // into
355                 //   {
356                 //     init;
357                 //     while (expr) {
358                 //       { body; }
359                 //       exprB;
360                 //     }
361                 //   }
362                 //
363                 // Local case statements are transformed into:
364                 //   exprB; continue;
365 
366                 // Insert "{ body; }" in the while loop
367                 bool bodyEndsInBranch;
368                 TIntermBlock *whileLoopBody = CreateFromBody(node, &bodyEndsInBranch);
369                 // Insert "exprB;" in the while loop
370                 if (!bodyEndsInBranch && node->getExpression())
371                 {
372                     whileLoopBody->getSequence()->push_back(node->getExpression());
373                 }
374                 // Create "while(expr) { whileLoopBody }"
375                 whileLoop =
376                     new TIntermLoop(ELoopWhile, nullptr, mLoop.condition, nullptr, whileLoopBody);
377 
378                 // Mask continue statement condition variable update.
379                 mLoop.condition = nullptr;
380             }
381             else if (mLoop.condition->getAsConstantUnion())
382             {
383                 // Move the loop condition inside the loop.
384                 // Transform:
385                 //   for (init; expr; exprB) { body; }
386                 // into
387                 //   {
388                 //     init;
389                 //     bool s0 = expr;
390                 //     while (s0) {
391                 //       { body; }
392                 //       exprB;
393                 //     }
394                 //   }
395                 //
396                 // Local case statements are transformed into:
397                 //   exprB; continue;
398 
399                 // Insert "bool s0 = expr;"
400                 loopScopeSequence->push_back(
401                     CreateTempInitDeclarationNode(mLoop.conditionVariable, mLoop.condition));
402                 // Insert "{ body; }" in the while loop
403                 bool bodyEndsInBranch;
404                 TIntermBlock *whileLoopBody = CreateFromBody(node, &bodyEndsInBranch);
405                 // Insert "exprB;" in the while loop
406                 if (!bodyEndsInBranch && node->getExpression())
407                 {
408                     whileLoopBody->getSequence()->push_back(node->getExpression());
409                 }
410                 // Create "while(s0) { whileLoopBody }"
411                 whileLoop = new TIntermLoop(ELoopWhile, nullptr,
412                                             CreateTempSymbolNode(mLoop.conditionVariable), nullptr,
413                                             whileLoopBody);
414 
415                 // Mask continue statement condition variable update.
416                 mLoop.condition = nullptr;
417             }
418             else
419             {
420                 // Move the loop condition inside the loop.
421                 // Transform:
422                 //   for (init; expr; exprB) { body; }
423                 // into
424                 //   {
425                 //     init;
426                 //     bool s0 = expr;
427                 //     while (s0) {
428                 //       { body; }
429                 //       exprB;
430                 //       s0 = expr;
431                 //     }
432                 //   }
433                 //
434                 // Local case statements are transformed into:
435                 //   exprB; s0 = expr; continue;
436 
437                 // Insert "bool s0 = expr;"
438                 loopScopeSequence->push_back(
439                     CreateTempInitDeclarationNode(mLoop.conditionVariable, mLoop.condition));
440                 // Insert "{ body; }" in the while loop
441                 bool bodyEndsInBranch;
442                 TIntermBlock *whileLoopBody = CreateFromBody(node, &bodyEndsInBranch);
443                 // Insert "exprB;" in the while loop
444                 if (!bodyEndsInBranch && node->getExpression())
445                 {
446                     whileLoopBody->getSequence()->push_back(node->getExpression());
447                 }
448                 // Insert "s0 = expr;" in the while loop
449                 if (!bodyEndsInBranch)
450                 {
451                     whileLoopBody->getSequence()->push_back(CreateTempAssignmentNode(
452                         mLoop.conditionVariable, mLoop.condition->deepCopy()));
453                 }
454                 // Create "while(s0) { whileLoopBody }"
455                 whileLoop = new TIntermLoop(ELoopWhile, nullptr,
456                                             CreateTempSymbolNode(mLoop.conditionVariable), nullptr,
457                                             whileLoopBody);
458             }
459 
460             loopScope->getSequence()->push_back(whileLoop);
461             queueReplacement(loopScope, OriginalNode::IS_DROPPED);
462 
463             // After this the old body node will be traversed and loops inside it may be
464             // transformed. This is fine, since the old body node will still be in the AST after
465             // the transformation that's queued here, and transforming loops inside it doesn't
466             // need to know the exact post-transform path to it.
467         }
468     }
469 
470     mFoundLoopToChange = false;
471 
472     // We traverse the body of the loop even if the loop is transformed.
473     if (node->getBody())
474         node->getBody()->traverse(this);
475 
476     mLoop = prevLoop;
477 }
478 
479 }  // namespace
480 
SimplifyLoopConditions(TCompiler * compiler,TIntermNode * root,TSymbolTable * symbolTable)481 bool SimplifyLoopConditions(TCompiler *compiler, TIntermNode *root, TSymbolTable *symbolTable)
482 {
483     SimplifyLoopConditionsTraverser traverser(nullptr, symbolTable);
484     root->traverse(&traverser);
485     return traverser.updateTree(compiler, root);
486 }
487 
SimplifyLoopConditions(TCompiler * compiler,TIntermNode * root,unsigned int conditionsToSimplifyMask,TSymbolTable * symbolTable)488 bool SimplifyLoopConditions(TCompiler *compiler,
489                             TIntermNode *root,
490                             unsigned int conditionsToSimplifyMask,
491                             TSymbolTable *symbolTable)
492 {
493     IntermNodePatternMatcher conditionsToSimplify(conditionsToSimplifyMask);
494     SimplifyLoopConditionsTraverser traverser(&conditionsToSimplify, symbolTable);
495     root->traverse(&traverser);
496     return traverser.updateTree(compiler, root);
497 }
498 
499 }  // namespace sh
500