• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/ir/SkSLSwitchStatement.h"
9 
10 #include <forward_list>
11 
12 #include "include/private/SkTHash.h"
13 #include "src/sksl/SkSLAnalysis.h"
14 #include "src/sksl/SkSLConstantFolder.h"
15 #include "src/sksl/SkSLContext.h"
16 #include "src/sksl/SkSLProgramSettings.h"
17 #include "src/sksl/ir/SkSLBlock.h"
18 #include "src/sksl/ir/SkSLNop.h"
19 #include "src/sksl/ir/SkSLSymbolTable.h"
20 #include "src/sksl/ir/SkSLType.h"
21 
22 namespace SkSL {
23 
clone() const24 std::unique_ptr<Statement> SwitchStatement::clone() const {
25     StatementArray cases;
26     cases.reserve_back(this->cases().size());
27     for (const std::unique_ptr<Statement>& stmt : this->cases()) {
28         cases.push_back(stmt->clone());
29     }
30     return std::make_unique<SwitchStatement>(fLine,
31                                              this->isStatic(),
32                                              this->value()->clone(),
33                                              std::move(cases),
34                                              SymbolTable::WrapIfBuiltin(this->symbols()));
35 }
36 
description() const37 std::string SwitchStatement::description() const {
38     std::string result;
39     if (this->isStatic()) {
40         result += "@";
41     }
42     result += String::printf("switch (%s) {\n", this->value()->description().c_str());
43     for (const auto& c : this->cases()) {
44         result += c->description();
45     }
46     result += "}";
47     return result;
48 }
49 
find_duplicate_case_values(const StatementArray & cases)50 static std::forward_list<const SwitchCase*> find_duplicate_case_values(
51         const StatementArray& cases) {
52     std::forward_list<const SwitchCase*> duplicateCases;
53     SkTHashSet<SKSL_INT> intValues;
54     bool foundDefault = false;
55 
56     for (const std::unique_ptr<Statement>& stmt : cases) {
57         const SwitchCase* sc = &stmt->as<SwitchCase>();
58         if (sc->isDefault()) {
59             if (foundDefault) {
60                 duplicateCases.push_front(sc);
61                 continue;
62             }
63             foundDefault = true;
64         } else {
65             SKSL_INT value = sc->value();
66             if (intValues.contains(value)) {
67                 duplicateCases.push_front(sc);
68                 continue;
69             }
70             intValues.add(value);
71         }
72     }
73 
74     return duplicateCases;
75 }
76 
move_all_but_break(std::unique_ptr<Statement> & stmt,StatementArray * target)77 static void move_all_but_break(std::unique_ptr<Statement>& stmt, StatementArray* target) {
78     switch (stmt->kind()) {
79         case Statement::Kind::kBlock: {
80             // Recurse into the block.
81             Block& block = stmt->as<Block>();
82 
83             StatementArray blockStmts;
84             blockStmts.reserve_back(block.children().size());
85             for (std::unique_ptr<Statement>& blockStmt : block.children()) {
86                 move_all_but_break(blockStmt, &blockStmts);
87             }
88 
89             target->push_back(Block::Make(block.fLine, std::move(blockStmts),
90                                           block.symbolTable(), block.isScope()));
91             break;
92         }
93 
94         case Statement::Kind::kBreak:
95             // Do not append a break to the target.
96             break;
97 
98         default:
99             // Append normal statements to the target.
100             target->push_back(std::move(stmt));
101             break;
102     }
103 }
104 
BlockForCase(StatementArray * cases,SwitchCase * caseToCapture,std::shared_ptr<SymbolTable> symbolTable)105 std::unique_ptr<Statement> SwitchStatement::BlockForCase(StatementArray* cases,
106                                                          SwitchCase* caseToCapture,
107                                                          std::shared_ptr<SymbolTable> symbolTable) {
108     // We have to be careful to not move any of the pointers until after we're sure we're going to
109     // succeed, so before we make any changes at all, we check the switch-cases to decide on a plan
110     // of action. First, find the switch-case we are interested in.
111     auto iter = cases->begin();
112     for (; iter != cases->end(); ++iter) {
113         const SwitchCase& sc = (*iter)->as<SwitchCase>();
114         if (&sc == caseToCapture) {
115             break;
116         }
117     }
118 
119     // Next, walk forward through the rest of the switch. If we find a conditional break, we're
120     // stuck and can't simplify at all. If we find an unconditional break, we have a range of
121     // statements that we can use for simplification.
122     auto startIter = iter;
123     Statement* stripBreakStmt = nullptr;
124     for (; iter != cases->end(); ++iter) {
125         std::unique_ptr<Statement>& stmt = (*iter)->as<SwitchCase>().statement();
126         if (Analysis::SwitchCaseContainsConditionalExit(*stmt)) {
127             // We can't reduce switch-cases to a block when they have conditional exits.
128             return nullptr;
129         }
130         if (Analysis::SwitchCaseContainsUnconditionalExit(*stmt)) {
131             // We found an unconditional exit. We can use this block, but we'll need to strip
132             // out the break statement if there is one.
133             stripBreakStmt = stmt.get();
134             break;
135         }
136     }
137 
138     // We fell off the bottom of the switch or encountered a break. We know the range of statements
139     // that we need to move over, and we know it's safe to do so.
140     StatementArray caseStmts;
141     caseStmts.reserve_back(std::distance(startIter, iter) + 1);
142 
143     // We can move over most of the statements as-is.
144     while (startIter != iter) {
145         caseStmts.push_back(std::move((*startIter)->as<SwitchCase>().statement()));
146         ++startIter;
147     }
148 
149     // If we found an unconditional break at the end, we need to move what we can while avoiding
150     // that break.
151     if (stripBreakStmt != nullptr) {
152         SkASSERT((*startIter)->as<SwitchCase>().statement().get() == stripBreakStmt);
153         move_all_but_break((*startIter)->as<SwitchCase>().statement(), &caseStmts);
154     }
155 
156     // Return our newly-synthesized block.
157     return Block::Make(caseToCapture->fLine, std::move(caseStmts), std::move(symbolTable));
158 }
159 
Convert(const Context & context,int line,bool isStatic,std::unique_ptr<Expression> value,ExpressionArray caseValues,StatementArray caseStatements,std::shared_ptr<SymbolTable> symbolTable)160 std::unique_ptr<Statement> SwitchStatement::Convert(const Context& context,
161                                                     int line,
162                                                     bool isStatic,
163                                                     std::unique_ptr<Expression> value,
164                                                     ExpressionArray caseValues,
165                                                     StatementArray caseStatements,
166                                                     std::shared_ptr<SymbolTable> symbolTable) {
167     SkASSERT(caseValues.size() == caseStatements.size());
168 
169     value = context.fTypes.fInt->coerceExpression(std::move(value), context);
170     if (!value) {
171         return nullptr;
172     }
173 
174     StatementArray cases;
175     for (int i = 0; i < caseValues.count(); ++i) {
176         if (caseValues[i]) {
177             int caseLine = caseValues[i]->fLine;
178             // Case values must be constant integers of the same type as the switch value
179             std::unique_ptr<Expression> caseValue = value->type().coerceExpression(
180                     std::move(caseValues[i]), context);
181             if (!caseValue) {
182                 return nullptr;
183             }
184             SKSL_INT intValue;
185             if (!ConstantFolder::GetConstantInt(*caseValue, &intValue)) {
186                 context.fErrors->error(caseValue->fLine, "case value must be a constant integer");
187                 return nullptr;
188             }
189             cases.push_back(SwitchCase::Make(caseLine, intValue, std::move(caseStatements[i])));
190         } else {
191             cases.push_back(SwitchCase::MakeDefault(line, std::move(caseStatements[i])));
192         }
193     }
194 
195     // Detect duplicate `case` labels and report an error.
196     // (Using forward_list here to optimize for the common case of no results.)
197     std::forward_list<const SwitchCase*> duplicateCases = find_duplicate_case_values(cases);
198     if (!duplicateCases.empty()) {
199         duplicateCases.reverse();
200         for (const SwitchCase* sc : duplicateCases) {
201             if (sc->isDefault()) {
202                 context.fErrors->error(sc->fLine, "duplicate default case");
203             } else {
204                 context.fErrors->error(sc->fLine, "duplicate case value '" +
205                                                   std::to_string(sc->value()) + "'");
206             }
207         }
208         return nullptr;
209     }
210 
211     return SwitchStatement::Make(context, line, isStatic, std::move(value), std::move(cases),
212                                  std::move(symbolTable));
213 }
214 
Make(const Context & context,int line,bool isStatic,std::unique_ptr<Expression> value,StatementArray cases,std::shared_ptr<SymbolTable> symbolTable)215 std::unique_ptr<Statement> SwitchStatement::Make(const Context& context,
216                                                  int line,
217                                                  bool isStatic,
218                                                  std::unique_ptr<Expression> value,
219                                                  StatementArray cases,
220                                                  std::shared_ptr<SymbolTable> symbolTable) {
221     // Confirm that every statement in `cases` is a SwitchCase.
222     SkASSERT(std::all_of(cases.begin(), cases.end(), [&](const std::unique_ptr<Statement>& stmt) {
223         return stmt->is<SwitchCase>();
224     }));
225 
226     // Confirm that every switch-case value is unique.
227     SkASSERT(find_duplicate_case_values(cases).empty());
228 
229     // Flatten @switch statements.
230     if (isStatic || context.fConfig->fSettings.fOptimize) {
231         SKSL_INT switchValue;
232         if (ConstantFolder::GetConstantInt(*value, &switchValue)) {
233             SwitchCase* defaultCase = nullptr;
234             SwitchCase* matchingCase = nullptr;
235             for (const std::unique_ptr<Statement>& stmt : cases) {
236                 SwitchCase& sc = stmt->as<SwitchCase>();
237                 if (sc.isDefault()) {
238                     defaultCase = &sc;
239                     continue;
240                 }
241 
242                 if (sc.value() == switchValue) {
243                     matchingCase = &sc;
244                     break;
245                 }
246             }
247 
248             if (!matchingCase) {
249                 // No case value matches the switch value.
250                 if (!defaultCase) {
251                     // No default switch-case exists; the switch had no effect.
252                     // We can eliminate the entire switch!
253                     return Nop::Make();
254                 }
255                 // We had a default case; that's what we matched with.
256                 matchingCase = defaultCase;
257             }
258 
259             // Convert the switch-case that we matched with into a block.
260             std::unique_ptr<Statement> newBlock = BlockForCase(&cases, matchingCase, symbolTable);
261             if (newBlock) {
262                 return newBlock;
263             }
264 
265             // Report an error if this was a static switch and BlockForCase failed us.
266             if (isStatic && !context.fConfig->fSettings.fPermitInvalidStaticTests) {
267                 context.fErrors->error(value->fLine,
268                                        "static switch contains non-static conditional exit");
269                 return nullptr;
270             }
271         }
272     }
273 
274     // The switch couldn't be optimized away; emit it normally.
275     return std::make_unique<SwitchStatement>(line, isStatic, std::move(value), std::move(cases),
276                                              std::move(symbolTable));
277 }
278 
279 }  // namespace SkSL
280