• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/ir/SkSLSwitchStatement.h"
9 
10 #include "include/core/SkTypes.h"
11 #include "include/private/SkSLString.h"
12 #include "include/private/base/SkTArray.h"
13 #include "include/sksl/SkSLErrorReporter.h"
14 #include "src/core/SkTHash.h"
15 #include "src/sksl/SkSLAnalysis.h"
16 #include "src/sksl/SkSLBuiltinTypes.h"
17 #include "src/sksl/SkSLConstantFolder.h"
18 #include "src/sksl/SkSLContext.h"
19 #include "src/sksl/SkSLProgramSettings.h"
20 #include "src/sksl/ir/SkSLBlock.h"
21 #include "src/sksl/ir/SkSLNop.h"
22 #include "src/sksl/ir/SkSLSwitchCase.h"
23 #include "src/sksl/ir/SkSLSymbolTable.h"
24 #include "src/sksl/ir/SkSLType.h"
25 
26 #include <algorithm>
27 #include <forward_list>
28 #include <iterator>
29 
30 namespace SkSL {
31 
clone() const32 std::unique_ptr<Statement> SwitchStatement::clone() const {
33     StatementArray cases;
34     cases.reserve_back(this->cases().size());
35     for (const std::unique_ptr<Statement>& stmt : this->cases()) {
36         cases.push_back(stmt->clone());
37     }
38     return std::make_unique<SwitchStatement>(fPosition,
39                                              this->value()->clone(),
40                                              std::move(cases),
41                                              SymbolTable::WrapIfBuiltin(this->symbols()));
42 }
43 
description() const44 std::string SwitchStatement::description() const {
45     std::string result;
46     result += String::printf("switch (%s) {\n", this->value()->description().c_str());
47     for (const auto& c : this->cases()) {
48         result += c->description();
49     }
50     result += "}";
51     return result;
52 }
53 
find_duplicate_case_values(const StatementArray & cases)54 static std::forward_list<const SwitchCase*> find_duplicate_case_values(
55         const StatementArray& cases) {
56     std::forward_list<const SwitchCase*> duplicateCases;
57     SkTHashSet<SKSL_INT> intValues;
58     bool foundDefault = false;
59 
60     for (const std::unique_ptr<Statement>& stmt : cases) {
61         const SwitchCase* sc = &stmt->as<SwitchCase>();
62         if (sc->isDefault()) {
63             if (foundDefault) {
64                 duplicateCases.push_front(sc);
65                 continue;
66             }
67             foundDefault = true;
68         } else {
69             SKSL_INT value = sc->value();
70             if (intValues.contains(value)) {
71                 duplicateCases.push_front(sc);
72                 continue;
73             }
74             intValues.add(value);
75         }
76     }
77 
78     return duplicateCases;
79 }
80 
move_all_but_break(std::unique_ptr<Statement> & stmt,StatementArray * target)81 static void move_all_but_break(std::unique_ptr<Statement>& stmt, StatementArray* target) {
82     switch (stmt->kind()) {
83         case Statement::Kind::kBlock: {
84             // Recurse into the block.
85             Block& block = stmt->as<Block>();
86 
87             StatementArray blockStmts;
88             blockStmts.reserve_back(block.children().size());
89             for (std::unique_ptr<Statement>& blockStmt : block.children()) {
90                 move_all_but_break(blockStmt, &blockStmts);
91             }
92 
93             target->push_back(Block::Make(block.fPosition, std::move(blockStmts), block.blockKind(),
94                                           block.symbolTable()));
95             break;
96         }
97 
98         case Statement::Kind::kBreak:
99             // Do not append a break to the target.
100             break;
101 
102         default:
103             // Append normal statements to the target.
104             target->push_back(std::move(stmt));
105             break;
106     }
107 }
108 
BlockForCase(StatementArray * cases,SwitchCase * caseToCapture,std::shared_ptr<SymbolTable> symbolTable)109 std::unique_ptr<Statement> SwitchStatement::BlockForCase(StatementArray* cases,
110                                                          SwitchCase* caseToCapture,
111                                                          std::shared_ptr<SymbolTable> symbolTable) {
112     // We have to be careful to not move any of the pointers until after we're sure we're going to
113     // succeed, so before we make any changes at all, we check the switch-cases to decide on a plan
114     // of action. First, find the switch-case we are interested in.
115     auto iter = cases->begin();
116     for (; iter != cases->end(); ++iter) {
117         const SwitchCase& sc = (*iter)->as<SwitchCase>();
118         if (&sc == caseToCapture) {
119             break;
120         }
121     }
122 
123     // Next, walk forward through the rest of the switch. If we find a conditional break, we're
124     // stuck and can't simplify at all. If we find an unconditional break, we have a range of
125     // statements that we can use for simplification.
126     auto startIter = iter;
127     Statement* stripBreakStmt = nullptr;
128     for (; iter != cases->end(); ++iter) {
129         std::unique_ptr<Statement>& stmt = (*iter)->as<SwitchCase>().statement();
130         if (Analysis::SwitchCaseContainsConditionalExit(*stmt)) {
131             // We can't reduce switch-cases to a block when they have conditional exits.
132             return nullptr;
133         }
134         if (Analysis::SwitchCaseContainsUnconditionalExit(*stmt)) {
135             // We found an unconditional exit. We can use this block, but we'll need to strip
136             // out the break statement if there is one.
137             stripBreakStmt = stmt.get();
138             break;
139         }
140     }
141 
142     // We fell off the bottom of the switch or encountered a break. We know the range of statements
143     // that we need to move over, and we know it's safe to do so.
144     StatementArray caseStmts;
145     caseStmts.reserve_back(std::distance(startIter, iter) + 1);
146 
147     // We can move over most of the statements as-is.
148     while (startIter != iter) {
149         caseStmts.push_back(std::move((*startIter)->as<SwitchCase>().statement()));
150         ++startIter;
151     }
152 
153     // If we found an unconditional break at the end, we need to move what we can while avoiding
154     // that break.
155     if (stripBreakStmt != nullptr) {
156         SkASSERT((*startIter)->as<SwitchCase>().statement().get() == stripBreakStmt);
157         move_all_but_break((*startIter)->as<SwitchCase>().statement(), &caseStmts);
158     }
159 
160     // Return our newly-synthesized block.
161     return Block::Make(caseToCapture->fPosition, std::move(caseStmts), Block::Kind::kBracedScope,
162                        std::move(symbolTable));
163 }
164 
Convert(const Context & context,Position pos,std::unique_ptr<Expression> value,ExpressionArray caseValues,StatementArray caseStatements,std::shared_ptr<SymbolTable> symbolTable)165 std::unique_ptr<Statement> SwitchStatement::Convert(const Context& context,
166                                                     Position pos,
167                                                     std::unique_ptr<Expression> value,
168                                                     ExpressionArray caseValues,
169                                                     StatementArray caseStatements,
170                                                     std::shared_ptr<SymbolTable> symbolTable) {
171     SkASSERT(caseValues.size() == caseStatements.size());
172 
173     value = context.fTypes.fInt->coerceExpression(std::move(value), context);
174     if (!value) {
175         return nullptr;
176     }
177 
178     StatementArray cases;
179     for (int i = 0; i < caseValues.size(); ++i) {
180         if (caseValues[i]) {
181             Position casePos = caseValues[i]->fPosition;
182             // Case values must be constant integers of the same type as the switch value
183             std::unique_ptr<Expression> caseValue = value->type().coerceExpression(
184                     std::move(caseValues[i]), context);
185             if (!caseValue) {
186                 return nullptr;
187             }
188             SKSL_INT intValue;
189             if (!ConstantFolder::GetConstantInt(*caseValue, &intValue)) {
190                 context.fErrors->error(casePos, "case value must be a constant integer");
191                 return nullptr;
192             }
193             cases.push_back(SwitchCase::Make(casePos, intValue, std::move(caseStatements[i])));
194         } else {
195             cases.push_back(SwitchCase::MakeDefault(pos, std::move(caseStatements[i])));
196         }
197     }
198 
199     // Detect duplicate `case` labels and report an error.
200     // (Using forward_list here to optimize for the common case of no results.)
201     std::forward_list<const SwitchCase*> duplicateCases = find_duplicate_case_values(cases);
202     if (!duplicateCases.empty()) {
203         duplicateCases.reverse();
204         for (const SwitchCase* sc : duplicateCases) {
205             if (sc->isDefault()) {
206                 context.fErrors->error(sc->fPosition, "duplicate default case");
207             } else {
208                 context.fErrors->error(sc->fPosition, "duplicate case value '" +
209                                                   std::to_string(sc->value()) + "'");
210             }
211         }
212         return nullptr;
213     }
214 
215     return SwitchStatement::Make(
216             context, pos, std::move(value), std::move(cases), std::move(symbolTable));
217 }
218 
Make(const Context & context,Position pos,std::unique_ptr<Expression> value,StatementArray cases,std::shared_ptr<SymbolTable> symbolTable)219 std::unique_ptr<Statement> SwitchStatement::Make(const Context& context,
220                                                  Position pos,
221                                                  std::unique_ptr<Expression> value,
222                                                  StatementArray cases,
223                                                  std::shared_ptr<SymbolTable> symbolTable) {
224     // Confirm that every statement in `cases` is a SwitchCase.
225     SkASSERT(std::all_of(cases.begin(), cases.end(), [&](const std::unique_ptr<Statement>& stmt) {
226         return stmt->is<SwitchCase>();
227     }));
228 
229     // Confirm that every switch-case value is unique.
230     SkASSERT(find_duplicate_case_values(cases).empty());
231 
232     // Flatten switch statements if we're optimizing, and the value is known
233     if (context.fConfig->fSettings.fOptimize) {
234         SKSL_INT switchValue;
235         if (ConstantFolder::GetConstantInt(*value, &switchValue)) {
236             SwitchCase* defaultCase = nullptr;
237             SwitchCase* matchingCase = nullptr;
238             for (const std::unique_ptr<Statement>& stmt : cases) {
239                 SwitchCase& sc = stmt->as<SwitchCase>();
240                 if (sc.isDefault()) {
241                     defaultCase = &sc;
242                     continue;
243                 }
244 
245                 if (sc.value() == switchValue) {
246                     matchingCase = &sc;
247                     break;
248                 }
249             }
250 
251             if (!matchingCase) {
252                 // No case value matches the switch value.
253                 if (!defaultCase) {
254                     // No default switch-case exists; the switch had no effect.
255                     // We can eliminate the entire switch!
256                     return Nop::Make();
257                 }
258                 // We had a default case; that's what we matched with.
259                 matchingCase = defaultCase;
260             }
261 
262             // Convert the switch-case that we matched with into a block.
263             std::unique_ptr<Statement> newBlock = BlockForCase(&cases, matchingCase, symbolTable);
264             if (newBlock) {
265                 return newBlock;
266             }
267         }
268     }
269 
270     // The switch couldn't be optimized away; emit it normally.
271     return std::make_unique<SwitchStatement>(
272             pos, std::move(value), std::move(cases), std::move(symbolTable));
273 }
274 
275 }  // namespace SkSL
276