1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
// RemoveSwitchFallThrough.cpp: Removes fall-through from switch statements.
// Note that it is unsafe to run further AST transformations on the AST produced by this
// function: it leaves duplicate nodes in the AST, which makes node replacement unreliable.
10
11 #include "compiler/translator/tree_ops/d3d/RemoveSwitchFallThrough.h"
12
13 #include "compiler/translator/Diagnostics.h"
14 #include "compiler/translator/tree_util/IntermTraverse.h"
15
16 namespace sh
17 {
18
19 namespace
20 {
21
22 class RemoveSwitchFallThroughTraverser : public TIntermTraverser
23 {
24 public:
25 static TIntermBlock *removeFallThrough(TIntermBlock *statementList,
26 PerformanceDiagnostics *perfDiagnostics);
27
28 private:
29 RemoveSwitchFallThroughTraverser(TIntermBlock *statementList,
30 PerformanceDiagnostics *perfDiagnostics);
31
32 void visitSymbol(TIntermSymbol *node) override;
33 void visitConstantUnion(TIntermConstantUnion *node) override;
34 bool visitDeclaration(Visit, TIntermDeclaration *node) override;
35 bool visitBinary(Visit, TIntermBinary *node) override;
36 bool visitUnary(Visit, TIntermUnary *node) override;
37 bool visitTernary(Visit visit, TIntermTernary *node) override;
38 bool visitSwizzle(Visit, TIntermSwizzle *node) override;
39 bool visitIfElse(Visit visit, TIntermIfElse *node) override;
40 bool visitSwitch(Visit, TIntermSwitch *node) override;
41 bool visitCase(Visit, TIntermCase *node) override;
42 bool visitAggregate(Visit, TIntermAggregate *node) override;
43 bool visitBlock(Visit, TIntermBlock *node) override;
44 bool visitLoop(Visit, TIntermLoop *node) override;
45 bool visitBranch(Visit, TIntermBranch *node) override;
46
47 void outputSequence(TIntermSequence *sequence, size_t startIndex);
48 void handlePreviousCase();
49
50 TIntermBlock *mStatementList;
51 TIntermBlock *mStatementListOut;
52 bool mLastStatementWasBreak;
53 TIntermBlock *mPreviousCase;
54 std::vector<TIntermBlock *> mCasesSharingBreak;
55 PerformanceDiagnostics *mPerfDiagnostics;
56 };
57
removeFallThrough(TIntermBlock * statementList,PerformanceDiagnostics * perfDiagnostics)58 TIntermBlock *RemoveSwitchFallThroughTraverser::removeFallThrough(
59 TIntermBlock *statementList,
60 PerformanceDiagnostics *perfDiagnostics)
61 {
62 RemoveSwitchFallThroughTraverser rm(statementList, perfDiagnostics);
63 ASSERT(statementList);
64 statementList->traverse(&rm);
65 ASSERT(rm.mPreviousCase || statementList->getSequence()->empty());
66 if (!rm.mLastStatementWasBreak && rm.mPreviousCase)
67 {
68 // Make sure that there's a branch at the end of the final case inside the switch statement.
69 // This also ensures that any cases that fall through to the final case will get the break.
70 TIntermBranch *finalBreak = new TIntermBranch(EOpBreak, nullptr);
71 rm.mPreviousCase->getSequence()->push_back(finalBreak);
72 rm.mLastStatementWasBreak = true;
73 }
74 rm.handlePreviousCase();
75 return rm.mStatementListOut;
76 }
77
RemoveSwitchFallThroughTraverser(TIntermBlock * statementList,PerformanceDiagnostics * perfDiagnostics)78 RemoveSwitchFallThroughTraverser::RemoveSwitchFallThroughTraverser(
79 TIntermBlock *statementList,
80 PerformanceDiagnostics *perfDiagnostics)
81 : TIntermTraverser(true, false, false),
82 mStatementList(statementList),
83 mLastStatementWasBreak(false),
84 mPreviousCase(nullptr),
85 mPerfDiagnostics(perfDiagnostics)
86 {
87 mStatementListOut = new TIntermBlock();
88 }
89
visitSymbol(TIntermSymbol * node)90 void RemoveSwitchFallThroughTraverser::visitSymbol(TIntermSymbol *node)
91 {
92 // Note that this assumes that switch statements which don't begin by a case statement
93 // have already been weeded out in validation.
94 mPreviousCase->getSequence()->push_back(node);
95 mLastStatementWasBreak = false;
96 }
97
visitConstantUnion(TIntermConstantUnion * node)98 void RemoveSwitchFallThroughTraverser::visitConstantUnion(TIntermConstantUnion *node)
99 {
100 // Conditions of case labels are not traversed, so this is a constant statement like "0;".
101 // These are no-ops so there's no need to add them back to the statement list. Should have
102 // already been pruned out of the AST, in fact.
103 UNREACHABLE();
104 }
105
visitDeclaration(Visit,TIntermDeclaration * node)106 bool RemoveSwitchFallThroughTraverser::visitDeclaration(Visit, TIntermDeclaration *node)
107 {
108 mPreviousCase->getSequence()->push_back(node);
109 mLastStatementWasBreak = false;
110 return false;
111 }
112
visitBinary(Visit,TIntermBinary * node)113 bool RemoveSwitchFallThroughTraverser::visitBinary(Visit, TIntermBinary *node)
114 {
115 mPreviousCase->getSequence()->push_back(node);
116 mLastStatementWasBreak = false;
117 return false;
118 }
119
visitUnary(Visit,TIntermUnary * node)120 bool RemoveSwitchFallThroughTraverser::visitUnary(Visit, TIntermUnary *node)
121 {
122 mPreviousCase->getSequence()->push_back(node);
123 mLastStatementWasBreak = false;
124 return false;
125 }
126
visitTernary(Visit,TIntermTernary * node)127 bool RemoveSwitchFallThroughTraverser::visitTernary(Visit, TIntermTernary *node)
128 {
129 mPreviousCase->getSequence()->push_back(node);
130 mLastStatementWasBreak = false;
131 return false;
132 }
133
visitSwizzle(Visit,TIntermSwizzle * node)134 bool RemoveSwitchFallThroughTraverser::visitSwizzle(Visit, TIntermSwizzle *node)
135 {
136 mPreviousCase->getSequence()->push_back(node);
137 mLastStatementWasBreak = false;
138 return false;
139 }
140
visitIfElse(Visit,TIntermIfElse * node)141 bool RemoveSwitchFallThroughTraverser::visitIfElse(Visit, TIntermIfElse *node)
142 {
143 mPreviousCase->getSequence()->push_back(node);
144 mLastStatementWasBreak = false;
145 return false;
146 }
147
visitSwitch(Visit,TIntermSwitch * node)148 bool RemoveSwitchFallThroughTraverser::visitSwitch(Visit, TIntermSwitch *node)
149 {
150 mPreviousCase->getSequence()->push_back(node);
151 mLastStatementWasBreak = false;
152 // Don't go into nested switch statements
153 return false;
154 }
155
outputSequence(TIntermSequence * sequence,size_t startIndex)156 void RemoveSwitchFallThroughTraverser::outputSequence(TIntermSequence *sequence, size_t startIndex)
157 {
158 for (size_t i = startIndex; i < sequence->size(); ++i)
159 {
160 mStatementListOut->getSequence()->push_back(sequence->at(i));
161 }
162 }
163
handlePreviousCase()164 void RemoveSwitchFallThroughTraverser::handlePreviousCase()
165 {
166 if (mPreviousCase)
167 mCasesSharingBreak.push_back(mPreviousCase);
168 if (mLastStatementWasBreak)
169 {
170 for (size_t i = 0; i < mCasesSharingBreak.size(); ++i)
171 {
172 ASSERT(!mCasesSharingBreak.at(i)->getSequence()->empty());
173 if (mCasesSharingBreak.at(i)->getSequence()->size() == 1)
174 {
175 // Fall-through is allowed in case the label has no statements.
176 outputSequence(mCasesSharingBreak.at(i)->getSequence(), 0);
177 }
178 else
179 {
180 // Include all the statements that this case can fall through under the same label.
181 if (mCasesSharingBreak.size() > i + 1u)
182 {
183 mPerfDiagnostics->warning(mCasesSharingBreak.at(i)->getLine(),
184 "Performance: non-empty fall-through cases in "
185 "switch statements generate extra code.",
186 "switch");
187 }
188 for (size_t j = i; j < mCasesSharingBreak.size(); ++j)
189 {
190 size_t startIndex =
191 j > i ? 1 : 0; // Add the label only from the first sequence.
192 outputSequence(mCasesSharingBreak.at(j)->getSequence(), startIndex);
193 }
194 }
195 }
196 mCasesSharingBreak.clear();
197 }
198 mLastStatementWasBreak = false;
199 mPreviousCase = nullptr;
200 }
201
visitCase(Visit,TIntermCase * node)202 bool RemoveSwitchFallThroughTraverser::visitCase(Visit, TIntermCase *node)
203 {
204 handlePreviousCase();
205 mPreviousCase = new TIntermBlock();
206 mPreviousCase->getSequence()->push_back(node);
207 mPreviousCase->setLine(node->getLine());
208 // Don't traverse the condition of the case statement
209 return false;
210 }
211
visitAggregate(Visit,TIntermAggregate * node)212 bool RemoveSwitchFallThroughTraverser::visitAggregate(Visit, TIntermAggregate *node)
213 {
214 mPreviousCase->getSequence()->push_back(node);
215 mLastStatementWasBreak = false;
216 return false;
217 }
218
DoesBlockAlwaysBreak(TIntermBlock * node)219 bool DoesBlockAlwaysBreak(TIntermBlock *node)
220 {
221 if (node->getSequence()->empty())
222 {
223 return false;
224 }
225
226 TIntermBlock *lastStatementAsBlock = node->getSequence()->back()->getAsBlock();
227 if (lastStatementAsBlock)
228 {
229 return DoesBlockAlwaysBreak(lastStatementAsBlock);
230 }
231
232 TIntermBranch *lastStatementAsBranch = node->getSequence()->back()->getAsBranchNode();
233 return lastStatementAsBranch != nullptr;
234 }
235
visitBlock(Visit,TIntermBlock * node)236 bool RemoveSwitchFallThroughTraverser::visitBlock(Visit, TIntermBlock *node)
237 {
238 if (node != mStatementList)
239 {
240 mPreviousCase->getSequence()->push_back(node);
241 mLastStatementWasBreak = DoesBlockAlwaysBreak(node);
242 return false;
243 }
244 return true;
245 }
246
visitLoop(Visit,TIntermLoop * node)247 bool RemoveSwitchFallThroughTraverser::visitLoop(Visit, TIntermLoop *node)
248 {
249 mPreviousCase->getSequence()->push_back(node);
250 mLastStatementWasBreak = false;
251 return false;
252 }
253
visitBranch(Visit,TIntermBranch * node)254 bool RemoveSwitchFallThroughTraverser::visitBranch(Visit, TIntermBranch *node)
255 {
256 mPreviousCase->getSequence()->push_back(node);
257 // TODO: Verify that accepting return or continue statements here doesn't cause problems.
258 mLastStatementWasBreak = true;
259 return false;
260 }
261
262 } // anonymous namespace
263
RemoveSwitchFallThrough(TIntermBlock * statementList,PerformanceDiagnostics * perfDiagnostics)264 TIntermBlock *RemoveSwitchFallThrough(TIntermBlock *statementList,
265 PerformanceDiagnostics *perfDiagnostics)
266 {
267 return RemoveSwitchFallThroughTraverser::removeFallThrough(statementList, perfDiagnostics);
268 }
269
270 } // namespace sh
271