1 /*
2 * Copyright 2021 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/ir/SkSLSwitchStatement.h"
9
10 #include "include/core/SkTypes.h"
11 #include "include/private/SkSLString.h"
12 #include "include/private/base/SkTArray.h"
13 #include "include/sksl/SkSLErrorReporter.h"
14 #include "src/core/SkTHash.h"
15 #include "src/sksl/SkSLAnalysis.h"
16 #include "src/sksl/SkSLBuiltinTypes.h"
17 #include "src/sksl/SkSLConstantFolder.h"
18 #include "src/sksl/SkSLContext.h"
19 #include "src/sksl/SkSLProgramSettings.h"
20 #include "src/sksl/ir/SkSLBlock.h"
21 #include "src/sksl/ir/SkSLNop.h"
22 #include "src/sksl/ir/SkSLSwitchCase.h"
23 #include "src/sksl/ir/SkSLSymbolTable.h"
24 #include "src/sksl/ir/SkSLType.h"
25
26 #include <algorithm>
27 #include <forward_list>
28 #include <iterator>
29
30 namespace SkSL {
31
clone() const32 std::unique_ptr<Statement> SwitchStatement::clone() const {
33 StatementArray cases;
34 cases.reserve_back(this->cases().size());
35 for (const std::unique_ptr<Statement>& stmt : this->cases()) {
36 cases.push_back(stmt->clone());
37 }
38 return std::make_unique<SwitchStatement>(fPosition,
39 this->value()->clone(),
40 std::move(cases),
41 SymbolTable::WrapIfBuiltin(this->symbols()));
42 }
43
description() const44 std::string SwitchStatement::description() const {
45 std::string result;
46 result += String::printf("switch (%s) {\n", this->value()->description().c_str());
47 for (const auto& c : this->cases()) {
48 result += c->description();
49 }
50 result += "}";
51 return result;
52 }
53
find_duplicate_case_values(const StatementArray & cases)54 static std::forward_list<const SwitchCase*> find_duplicate_case_values(
55 const StatementArray& cases) {
56 std::forward_list<const SwitchCase*> duplicateCases;
57 SkTHashSet<SKSL_INT> intValues;
58 bool foundDefault = false;
59
60 for (const std::unique_ptr<Statement>& stmt : cases) {
61 const SwitchCase* sc = &stmt->as<SwitchCase>();
62 if (sc->isDefault()) {
63 if (foundDefault) {
64 duplicateCases.push_front(sc);
65 continue;
66 }
67 foundDefault = true;
68 } else {
69 SKSL_INT value = sc->value();
70 if (intValues.contains(value)) {
71 duplicateCases.push_front(sc);
72 continue;
73 }
74 intValues.add(value);
75 }
76 }
77
78 return duplicateCases;
79 }
80
move_all_but_break(std::unique_ptr<Statement> & stmt,StatementArray * target)81 static void move_all_but_break(std::unique_ptr<Statement>& stmt, StatementArray* target) {
82 switch (stmt->kind()) {
83 case Statement::Kind::kBlock: {
84 // Recurse into the block.
85 Block& block = stmt->as<Block>();
86
87 StatementArray blockStmts;
88 blockStmts.reserve_back(block.children().size());
89 for (std::unique_ptr<Statement>& blockStmt : block.children()) {
90 move_all_but_break(blockStmt, &blockStmts);
91 }
92
93 target->push_back(Block::Make(block.fPosition, std::move(blockStmts), block.blockKind(),
94 block.symbolTable()));
95 break;
96 }
97
98 case Statement::Kind::kBreak:
99 // Do not append a break to the target.
100 break;
101
102 default:
103 // Append normal statements to the target.
104 target->push_back(std::move(stmt));
105 break;
106 }
107 }
108
BlockForCase(StatementArray * cases,SwitchCase * caseToCapture,std::shared_ptr<SymbolTable> symbolTable)109 std::unique_ptr<Statement> SwitchStatement::BlockForCase(StatementArray* cases,
110 SwitchCase* caseToCapture,
111 std::shared_ptr<SymbolTable> symbolTable) {
112 // We have to be careful to not move any of the pointers until after we're sure we're going to
113 // succeed, so before we make any changes at all, we check the switch-cases to decide on a plan
114 // of action. First, find the switch-case we are interested in.
115 auto iter = cases->begin();
116 for (; iter != cases->end(); ++iter) {
117 const SwitchCase& sc = (*iter)->as<SwitchCase>();
118 if (&sc == caseToCapture) {
119 break;
120 }
121 }
122
123 // Next, walk forward through the rest of the switch. If we find a conditional break, we're
124 // stuck and can't simplify at all. If we find an unconditional break, we have a range of
125 // statements that we can use for simplification.
126 auto startIter = iter;
127 Statement* stripBreakStmt = nullptr;
128 for (; iter != cases->end(); ++iter) {
129 std::unique_ptr<Statement>& stmt = (*iter)->as<SwitchCase>().statement();
130 if (Analysis::SwitchCaseContainsConditionalExit(*stmt)) {
131 // We can't reduce switch-cases to a block when they have conditional exits.
132 return nullptr;
133 }
134 if (Analysis::SwitchCaseContainsUnconditionalExit(*stmt)) {
135 // We found an unconditional exit. We can use this block, but we'll need to strip
136 // out the break statement if there is one.
137 stripBreakStmt = stmt.get();
138 break;
139 }
140 }
141
142 // We fell off the bottom of the switch or encountered a break. We know the range of statements
143 // that we need to move over, and we know it's safe to do so.
144 StatementArray caseStmts;
145 caseStmts.reserve_back(std::distance(startIter, iter) + 1);
146
147 // We can move over most of the statements as-is.
148 while (startIter != iter) {
149 caseStmts.push_back(std::move((*startIter)->as<SwitchCase>().statement()));
150 ++startIter;
151 }
152
153 // If we found an unconditional break at the end, we need to move what we can while avoiding
154 // that break.
155 if (stripBreakStmt != nullptr) {
156 SkASSERT((*startIter)->as<SwitchCase>().statement().get() == stripBreakStmt);
157 move_all_but_break((*startIter)->as<SwitchCase>().statement(), &caseStmts);
158 }
159
160 // Return our newly-synthesized block.
161 return Block::Make(caseToCapture->fPosition, std::move(caseStmts), Block::Kind::kBracedScope,
162 std::move(symbolTable));
163 }
164
Convert(const Context & context,Position pos,std::unique_ptr<Expression> value,ExpressionArray caseValues,StatementArray caseStatements,std::shared_ptr<SymbolTable> symbolTable)165 std::unique_ptr<Statement> SwitchStatement::Convert(const Context& context,
166 Position pos,
167 std::unique_ptr<Expression> value,
168 ExpressionArray caseValues,
169 StatementArray caseStatements,
170 std::shared_ptr<SymbolTable> symbolTable) {
171 SkASSERT(caseValues.size() == caseStatements.size());
172
173 value = context.fTypes.fInt->coerceExpression(std::move(value), context);
174 if (!value) {
175 return nullptr;
176 }
177
178 StatementArray cases;
179 for (int i = 0; i < caseValues.size(); ++i) {
180 if (caseValues[i]) {
181 Position casePos = caseValues[i]->fPosition;
182 // Case values must be constant integers of the same type as the switch value
183 std::unique_ptr<Expression> caseValue = value->type().coerceExpression(
184 std::move(caseValues[i]), context);
185 if (!caseValue) {
186 return nullptr;
187 }
188 SKSL_INT intValue;
189 if (!ConstantFolder::GetConstantInt(*caseValue, &intValue)) {
190 context.fErrors->error(casePos, "case value must be a constant integer");
191 return nullptr;
192 }
193 cases.push_back(SwitchCase::Make(casePos, intValue, std::move(caseStatements[i])));
194 } else {
195 cases.push_back(SwitchCase::MakeDefault(pos, std::move(caseStatements[i])));
196 }
197 }
198
199 // Detect duplicate `case` labels and report an error.
200 // (Using forward_list here to optimize for the common case of no results.)
201 std::forward_list<const SwitchCase*> duplicateCases = find_duplicate_case_values(cases);
202 if (!duplicateCases.empty()) {
203 duplicateCases.reverse();
204 for (const SwitchCase* sc : duplicateCases) {
205 if (sc->isDefault()) {
206 context.fErrors->error(sc->fPosition, "duplicate default case");
207 } else {
208 context.fErrors->error(sc->fPosition, "duplicate case value '" +
209 std::to_string(sc->value()) + "'");
210 }
211 }
212 return nullptr;
213 }
214
215 return SwitchStatement::Make(
216 context, pos, std::move(value), std::move(cases), std::move(symbolTable));
217 }
218
Make(const Context & context,Position pos,std::unique_ptr<Expression> value,StatementArray cases,std::shared_ptr<SymbolTable> symbolTable)219 std::unique_ptr<Statement> SwitchStatement::Make(const Context& context,
220 Position pos,
221 std::unique_ptr<Expression> value,
222 StatementArray cases,
223 std::shared_ptr<SymbolTable> symbolTable) {
224 // Confirm that every statement in `cases` is a SwitchCase.
225 SkASSERT(std::all_of(cases.begin(), cases.end(), [&](const std::unique_ptr<Statement>& stmt) {
226 return stmt->is<SwitchCase>();
227 }));
228
229 // Confirm that every switch-case value is unique.
230 SkASSERT(find_duplicate_case_values(cases).empty());
231
232 // Flatten switch statements if we're optimizing, and the value is known
233 if (context.fConfig->fSettings.fOptimize) {
234 SKSL_INT switchValue;
235 if (ConstantFolder::GetConstantInt(*value, &switchValue)) {
236 SwitchCase* defaultCase = nullptr;
237 SwitchCase* matchingCase = nullptr;
238 for (const std::unique_ptr<Statement>& stmt : cases) {
239 SwitchCase& sc = stmt->as<SwitchCase>();
240 if (sc.isDefault()) {
241 defaultCase = ≻
242 continue;
243 }
244
245 if (sc.value() == switchValue) {
246 matchingCase = ≻
247 break;
248 }
249 }
250
251 if (!matchingCase) {
252 // No case value matches the switch value.
253 if (!defaultCase) {
254 // No default switch-case exists; the switch had no effect.
255 // We can eliminate the entire switch!
256 return Nop::Make();
257 }
258 // We had a default case; that's what we matched with.
259 matchingCase = defaultCase;
260 }
261
262 // Convert the switch-case that we matched with into a block.
263 std::unique_ptr<Statement> newBlock = BlockForCase(&cases, matchingCase, symbolTable);
264 if (newBlock) {
265 return newBlock;
266 }
267 }
268 }
269
270 // The switch couldn't be optimized away; emit it normally.
271 return std::make_unique<SwitchStatement>(
272 pos, std::move(value), std::move(cases), std::move(symbolTable));
273 }
274
275 } // namespace SkSL
276