1 /*
2 * Copyright 2021 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/ir/SkSLSwitchStatement.h"
9
10 #include <forward_list>
11
12 #include "include/private/SkTHash.h"
13 #include "src/sksl/SkSLAnalysis.h"
14 #include "src/sksl/SkSLConstantFolder.h"
15 #include "src/sksl/SkSLContext.h"
16 #include "src/sksl/SkSLProgramSettings.h"
17 #include "src/sksl/ir/SkSLBlock.h"
18 #include "src/sksl/ir/SkSLNop.h"
19 #include "src/sksl/ir/SkSLSymbolTable.h"
20 #include "src/sksl/ir/SkSLType.h"
21
22 namespace SkSL {
23
clone() const24 std::unique_ptr<Statement> SwitchStatement::clone() const {
25 StatementArray cases;
26 cases.reserve_back(this->cases().size());
27 for (const std::unique_ptr<Statement>& stmt : this->cases()) {
28 cases.push_back(stmt->clone());
29 }
30 return std::make_unique<SwitchStatement>(fLine,
31 this->isStatic(),
32 this->value()->clone(),
33 std::move(cases),
34 SymbolTable::WrapIfBuiltin(this->symbols()));
35 }
36
description() const37 std::string SwitchStatement::description() const {
38 std::string result;
39 if (this->isStatic()) {
40 result += "@";
41 }
42 result += String::printf("switch (%s) {\n", this->value()->description().c_str());
43 for (const auto& c : this->cases()) {
44 result += c->description();
45 }
46 result += "}";
47 return result;
48 }
49
find_duplicate_case_values(const StatementArray & cases)50 static std::forward_list<const SwitchCase*> find_duplicate_case_values(
51 const StatementArray& cases) {
52 std::forward_list<const SwitchCase*> duplicateCases;
53 SkTHashSet<SKSL_INT> intValues;
54 bool foundDefault = false;
55
56 for (const std::unique_ptr<Statement>& stmt : cases) {
57 const SwitchCase* sc = &stmt->as<SwitchCase>();
58 if (sc->isDefault()) {
59 if (foundDefault) {
60 duplicateCases.push_front(sc);
61 continue;
62 }
63 foundDefault = true;
64 } else {
65 SKSL_INT value = sc->value();
66 if (intValues.contains(value)) {
67 duplicateCases.push_front(sc);
68 continue;
69 }
70 intValues.add(value);
71 }
72 }
73
74 return duplicateCases;
75 }
76
move_all_but_break(std::unique_ptr<Statement> & stmt,StatementArray * target)77 static void move_all_but_break(std::unique_ptr<Statement>& stmt, StatementArray* target) {
78 switch (stmt->kind()) {
79 case Statement::Kind::kBlock: {
80 // Recurse into the block.
81 Block& block = stmt->as<Block>();
82
83 StatementArray blockStmts;
84 blockStmts.reserve_back(block.children().size());
85 for (std::unique_ptr<Statement>& blockStmt : block.children()) {
86 move_all_but_break(blockStmt, &blockStmts);
87 }
88
89 target->push_back(Block::Make(block.fLine, std::move(blockStmts),
90 block.symbolTable(), block.isScope()));
91 break;
92 }
93
94 case Statement::Kind::kBreak:
95 // Do not append a break to the target.
96 break;
97
98 default:
99 // Append normal statements to the target.
100 target->push_back(std::move(stmt));
101 break;
102 }
103 }
104
BlockForCase(StatementArray * cases,SwitchCase * caseToCapture,std::shared_ptr<SymbolTable> symbolTable)105 std::unique_ptr<Statement> SwitchStatement::BlockForCase(StatementArray* cases,
106 SwitchCase* caseToCapture,
107 std::shared_ptr<SymbolTable> symbolTable) {
108 // We have to be careful to not move any of the pointers until after we're sure we're going to
109 // succeed, so before we make any changes at all, we check the switch-cases to decide on a plan
110 // of action. First, find the switch-case we are interested in.
111 auto iter = cases->begin();
112 for (; iter != cases->end(); ++iter) {
113 const SwitchCase& sc = (*iter)->as<SwitchCase>();
114 if (&sc == caseToCapture) {
115 break;
116 }
117 }
118
119 // Next, walk forward through the rest of the switch. If we find a conditional break, we're
120 // stuck and can't simplify at all. If we find an unconditional break, we have a range of
121 // statements that we can use for simplification.
122 auto startIter = iter;
123 Statement* stripBreakStmt = nullptr;
124 for (; iter != cases->end(); ++iter) {
125 std::unique_ptr<Statement>& stmt = (*iter)->as<SwitchCase>().statement();
126 if (Analysis::SwitchCaseContainsConditionalExit(*stmt)) {
127 // We can't reduce switch-cases to a block when they have conditional exits.
128 return nullptr;
129 }
130 if (Analysis::SwitchCaseContainsUnconditionalExit(*stmt)) {
131 // We found an unconditional exit. We can use this block, but we'll need to strip
132 // out the break statement if there is one.
133 stripBreakStmt = stmt.get();
134 break;
135 }
136 }
137
138 // We fell off the bottom of the switch or encountered a break. We know the range of statements
139 // that we need to move over, and we know it's safe to do so.
140 StatementArray caseStmts;
141 caseStmts.reserve_back(std::distance(startIter, iter) + 1);
142
143 // We can move over most of the statements as-is.
144 while (startIter != iter) {
145 caseStmts.push_back(std::move((*startIter)->as<SwitchCase>().statement()));
146 ++startIter;
147 }
148
149 // If we found an unconditional break at the end, we need to move what we can while avoiding
150 // that break.
151 if (stripBreakStmt != nullptr) {
152 SkASSERT((*startIter)->as<SwitchCase>().statement().get() == stripBreakStmt);
153 move_all_but_break((*startIter)->as<SwitchCase>().statement(), &caseStmts);
154 }
155
156 // Return our newly-synthesized block.
157 return Block::Make(caseToCapture->fLine, std::move(caseStmts), std::move(symbolTable));
158 }
159
Convert(const Context & context,int line,bool isStatic,std::unique_ptr<Expression> value,ExpressionArray caseValues,StatementArray caseStatements,std::shared_ptr<SymbolTable> symbolTable)160 std::unique_ptr<Statement> SwitchStatement::Convert(const Context& context,
161 int line,
162 bool isStatic,
163 std::unique_ptr<Expression> value,
164 ExpressionArray caseValues,
165 StatementArray caseStatements,
166 std::shared_ptr<SymbolTable> symbolTable) {
167 SkASSERT(caseValues.size() == caseStatements.size());
168
169 value = context.fTypes.fInt->coerceExpression(std::move(value), context);
170 if (!value) {
171 return nullptr;
172 }
173
174 StatementArray cases;
175 for (int i = 0; i < caseValues.count(); ++i) {
176 if (caseValues[i]) {
177 int caseLine = caseValues[i]->fLine;
178 // Case values must be constant integers of the same type as the switch value
179 std::unique_ptr<Expression> caseValue = value->type().coerceExpression(
180 std::move(caseValues[i]), context);
181 if (!caseValue) {
182 return nullptr;
183 }
184 SKSL_INT intValue;
185 if (!ConstantFolder::GetConstantInt(*caseValue, &intValue)) {
186 context.fErrors->error(caseValue->fLine, "case value must be a constant integer");
187 return nullptr;
188 }
189 cases.push_back(SwitchCase::Make(caseLine, intValue, std::move(caseStatements[i])));
190 } else {
191 cases.push_back(SwitchCase::MakeDefault(line, std::move(caseStatements[i])));
192 }
193 }
194
195 // Detect duplicate `case` labels and report an error.
196 // (Using forward_list here to optimize for the common case of no results.)
197 std::forward_list<const SwitchCase*> duplicateCases = find_duplicate_case_values(cases);
198 if (!duplicateCases.empty()) {
199 duplicateCases.reverse();
200 for (const SwitchCase* sc : duplicateCases) {
201 if (sc->isDefault()) {
202 context.fErrors->error(sc->fLine, "duplicate default case");
203 } else {
204 context.fErrors->error(sc->fLine, "duplicate case value '" +
205 std::to_string(sc->value()) + "'");
206 }
207 }
208 return nullptr;
209 }
210
211 return SwitchStatement::Make(context, line, isStatic, std::move(value), std::move(cases),
212 std::move(symbolTable));
213 }
214
Make(const Context & context,int line,bool isStatic,std::unique_ptr<Expression> value,StatementArray cases,std::shared_ptr<SymbolTable> symbolTable)215 std::unique_ptr<Statement> SwitchStatement::Make(const Context& context,
216 int line,
217 bool isStatic,
218 std::unique_ptr<Expression> value,
219 StatementArray cases,
220 std::shared_ptr<SymbolTable> symbolTable) {
221 // Confirm that every statement in `cases` is a SwitchCase.
222 SkASSERT(std::all_of(cases.begin(), cases.end(), [&](const std::unique_ptr<Statement>& stmt) {
223 return stmt->is<SwitchCase>();
224 }));
225
226 // Confirm that every switch-case value is unique.
227 SkASSERT(find_duplicate_case_values(cases).empty());
228
229 // Flatten @switch statements.
230 if (isStatic || context.fConfig->fSettings.fOptimize) {
231 SKSL_INT switchValue;
232 if (ConstantFolder::GetConstantInt(*value, &switchValue)) {
233 SwitchCase* defaultCase = nullptr;
234 SwitchCase* matchingCase = nullptr;
235 for (const std::unique_ptr<Statement>& stmt : cases) {
236 SwitchCase& sc = stmt->as<SwitchCase>();
237 if (sc.isDefault()) {
238 defaultCase = ≻
239 continue;
240 }
241
242 if (sc.value() == switchValue) {
243 matchingCase = ≻
244 break;
245 }
246 }
247
248 if (!matchingCase) {
249 // No case value matches the switch value.
250 if (!defaultCase) {
251 // No default switch-case exists; the switch had no effect.
252 // We can eliminate the entire switch!
253 return Nop::Make();
254 }
255 // We had a default case; that's what we matched with.
256 matchingCase = defaultCase;
257 }
258
259 // Convert the switch-case that we matched with into a block.
260 std::unique_ptr<Statement> newBlock = BlockForCase(&cases, matchingCase, symbolTable);
261 if (newBlock) {
262 return newBlock;
263 }
264
265 // Report an error if this was a static switch and BlockForCase failed us.
266 if (isStatic && !context.fConfig->fSettings.fPermitInvalidStaticTests) {
267 context.fErrors->error(value->fLine,
268 "static switch contains non-static conditional exit");
269 return nullptr;
270 }
271 }
272 }
273
274 // The switch couldn't be optimized away; emit it normally.
275 return std::make_unique<SwitchStatement>(line, isStatic, std::move(value), std::move(cases),
276 std::move(symbolTable));
277 }
278
279 } // namespace SkSL
280