• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
#include "src/sksl/ir/SkSLFunctionDefinition.h"

#include "include/core/SkSpan.h"
#include "include/core/SkTypes.h"
#include "src/base/SkSafeMath.h"
#include "src/sksl/SkSLAnalysis.h"
#include "src/sksl/SkSLCompiler.h"
#include "src/sksl/SkSLContext.h"
#include "src/sksl/SkSLDefines.h"
#include "src/sksl/SkSLErrorReporter.h"
#include "src/sksl/SkSLOperator.h"
#include "src/sksl/SkSLProgramSettings.h"
#include "src/sksl/ir/SkSLBinaryExpression.h"
#include "src/sksl/ir/SkSLBlock.h"
#include "src/sksl/ir/SkSLExpression.h"
#include "src/sksl/ir/SkSLExpressionStatement.h"
#include "src/sksl/ir/SkSLFieldSymbol.h"
#include "src/sksl/ir/SkSLForStatement.h"
#include "src/sksl/ir/SkSLIRHelpers.h"
#include "src/sksl/ir/SkSLNop.h"
#include "src/sksl/ir/SkSLReturnStatement.h"
#include "src/sksl/ir/SkSLSwizzle.h"
#include "src/sksl/ir/SkSLSymbol.h"
#include "src/sksl/ir/SkSLSymbolTable.h"  // IWYU pragma: keep
#include "src/sksl/ir/SkSLType.h"
#include "src/sksl/ir/SkSLVarDeclarations.h"
#include "src/sksl/ir/SkSLVariable.h"
#include "src/sksl/ir/SkSLVariableReference.h"
#include "src/sksl/transform/SkSLProgramWriter.h"

#include <algorithm>
#include <cstddef>
#include <forward_list>
40 
41 namespace SkSL {
42 
append_rtadjust_fixup_to_vertex_main(const Context & context,const FunctionDeclaration & decl,Block & body)43 static void append_rtadjust_fixup_to_vertex_main(const Context& context,
44                                                  const FunctionDeclaration& decl,
45                                                  Block& body) {
46     // If this program uses RTAdjust...
47     if (const SkSL::Symbol* rtAdjust = context.fSymbolTable->find(Compiler::RTADJUST_NAME)) {
48         // ...append a line to the end of the function body which fixes up sk_Position.
49         struct AppendRTAdjustFixupHelper : public IRHelpers {
50             AppendRTAdjustFixupHelper(const Context& ctx, const SkSL::Symbol* rtAdjust)
51                     : IRHelpers(ctx)
52                     , fRTAdjust(rtAdjust) {
53                 fSkPositionField = &fContext.fSymbolTable->find(Compiler::POSITION_NAME)
54                                                          ->as<FieldSymbol>();
55             }
56 
57             std::unique_ptr<Expression> Pos() const {
58                 return Field(&fSkPositionField->owner(), fSkPositionField->fieldIndex());
59             }
60 
61             std::unique_ptr<Expression> Adjust() const {
62                 return fRTAdjust->instantiate(fContext, Position());
63             }
64 
65             std::unique_ptr<Statement> makeFixupStmt() const {
66                 // sk_Position = float4(sk_Position.xy * rtAdjust.xz + sk_Position.ww * rtAdjust.yw,
67                 //                      0,
68                 //                      sk_Position.w);
69                 return Assign(
70                    Pos(),
71                    CtorXYZW(Add(Mul(Swizzle(Pos(),    {SwizzleComponent::X, SwizzleComponent::Y}),
72                                     Swizzle(Adjust(), {SwizzleComponent::X, SwizzleComponent::Z})),
73                                 Mul(Swizzle(Pos(),    {SwizzleComponent::W, SwizzleComponent::W}),
74                                     Swizzle(Adjust(), {SwizzleComponent::Y, SwizzleComponent::W}))),
75                             Float(0.0),
76                             Swizzle(Pos(), {SwizzleComponent::W})));
77             }
78 
79             const FieldSymbol* fSkPositionField;
80             const SkSL::Symbol* fRTAdjust;
81         };
82 
83         AppendRTAdjustFixupHelper helper(context, rtAdjust);
84         body.children().push_back(helper.makeFixupStmt());
85     }
86 }
87 
Convert(const Context & context,Position pos,const FunctionDeclaration & function,std::unique_ptr<Statement> body,bool builtin)88 std::unique_ptr<FunctionDefinition> FunctionDefinition::Convert(const Context& context,
89                                                                 Position pos,
90                                                                 const FunctionDeclaration& function,
91                                                                 std::unique_ptr<Statement> body,
92                                                                 bool builtin) {
93     class Finalizer : public ProgramWriter {
94     public:
95         Finalizer(const Context& context, const FunctionDeclaration& function, Position pos)
96             : fContext(context)
97             , fFunction(function) {
98             // Function parameters count as local variables.
99             for (const Variable* var : function.parameters()) {
100                 this->addLocalVariable(var, pos);
101             }
102         }
103 
104         ~Finalizer() override {
105             SkASSERT(fBreakableLevel == 0);
106             SkASSERT(fContinuableLevel == std::forward_list<int>{0});
107         }
108 
109         void addLocalVariable(const Variable* var, Position pos) {
110             if (var->type().isOrContainsUnsizedArray()) {
111                 fContext.fErrors->error(pos, "unsized arrays are not permitted here");
112                 return;
113             }
114             // We count the number of slots used, but don't consider the precision of the base type.
115             // In practice, this reflects what GPUs actually do pretty well. (i.e., RelaxedPrecision
116             // math doesn't mean your variable takes less space.) We also don't attempt to reclaim
117             // slots at the end of a Block.
118             size_t prevSlotsUsed = fSlotsUsed;
119             fSlotsUsed = SkSafeMath::Add(fSlotsUsed, var->type().slotCount());
120             // To avoid overzealous error reporting, only trigger the error at the first
121             // place where the stack limit is exceeded.
122             if (prevSlotsUsed < kVariableSlotLimit && fSlotsUsed >= kVariableSlotLimit) {
123                 fContext.fErrors->error(pos, "variable '" + std::string(var->name()) +
124                                              "' exceeds the stack size limit");
125             }
126         }
127 
128         void fuseVariableDeclarationsWithInitialization(std::unique_ptr<Statement>& stmt) {
129             switch (stmt->kind()) {
130                 case Statement::Kind::kNop:
131                 case Statement::Kind::kBlock:
132                     // Blocks and no-ops are inert; it is safe to fuse a variable declaration with
133                     // its initialization across a nop or an open-brace, so we don't null out
134                     // `fUninitializedVarDecl` here.
135                     break;
136 
137                 case Statement::Kind::kVarDeclaration:
138                     // Look for variable declarations without an initializer.
139                     if (VarDeclaration& decl = stmt->as<VarDeclaration>(); !decl.value()) {
140                         fUninitializedVarDecl = &decl;
141                         break;
142                     }
143                     [[fallthrough]];
144 
145                 default:
146                     // We found an intervening statement; it's not safe to fuse a declaration
147                     // with an initializer if we encounter any other code.
148                     fUninitializedVarDecl = nullptr;
149                     break;
150 
151                 case Statement::Kind::kExpression: {
152                     // We found an expression-statement. If there was a variable declaration
153                     // immediately above it, it might be possible to fuse them.
154                     if (fUninitializedVarDecl) {
155                         VarDeclaration* vardecl = fUninitializedVarDecl;
156                         fUninitializedVarDecl = nullptr;
157 
158                         std::unique_ptr<Expression>& nextExpr = stmt->as<ExpressionStatement>()
159                                                                      .expression();
160                         // This statement must be a binary-expression...
161                         if (!nextExpr->is<BinaryExpression>()) {
162                             break;
163                         }
164                         // ... performing simple `var = expr` assignment...
165                         BinaryExpression& binaryExpr = nextExpr->as<BinaryExpression>();
166                         if (binaryExpr.getOperator().kind() != OperatorKind::EQ) {
167                             break;
168                         }
169                         // ... directly into the variable (not a field/swizzle)...
170                         Expression& leftExpr = *binaryExpr.left();
171                         if (!leftExpr.is<VariableReference>()) {
172                             break;
173                         }
174                         // ... and it must be the same variable as our vardecl.
175                         VariableReference& varRef = leftExpr.as<VariableReference>();
176                         if (varRef.variable() != vardecl->var()) {
177                             break;
178                         }
179                         // The init-expression must not reference the variable.
180                         // `int x; x = x = 0;` is legal SkSL, but `int x = x = 0;` is not.
181                         if (Analysis::ContainsVariable(*binaryExpr.right(), *varRef.variable())) {
182                             break;
183                         }
184                         // We found a match! Move the init-expression directly onto the vardecl, and
185                         // turn the assignment into a no-op.
186                         vardecl->value() = std::move(binaryExpr.right());
187 
188                         // Turn the expression-statement into a no-op.
189                         stmt = Nop::Make();
190                     }
191                     break;
192                 }
193             }
194         }
195 
196         bool functionReturnsValue() const {
197             return !fFunction.returnType().isVoid();
198         }
199 
200         bool visitExpressionPtr(std::unique_ptr<Expression>& expr) override {
201             // We don't need to scan expressions.
202             return false;
203         }
204 
205         bool visitStatementPtr(std::unique_ptr<Statement>& stmt) override {
206             // When the optimizer is on, we look for variable declarations that are immediately
207             // followed by an initialization expression, and fuse them into one statement.
208             // (e.g.: `int i; i = 1;` can become `int i = 1;`)
209             if (fContext.fConfig->fSettings.fOptimize) {
210                 this->fuseVariableDeclarationsWithInitialization(stmt);
211             }
212 
213             // Perform error checking.
214             switch (stmt->kind()) {
215                 case Statement::Kind::kVarDeclaration:
216                     this->addLocalVariable(stmt->as<VarDeclaration>().var(), stmt->fPosition);
217                     break;
218 
219                 case Statement::Kind::kReturn: {
220                     // Early returns from a vertex main() function will bypass sk_Position
221                     // normalization, so SkASSERT that we aren't doing that. If this becomes an
222                     // issue, we can add normalization before each return statement.
223                     if (ProgramConfig::IsVertex(fContext.fConfig->fKind) && fFunction.isMain()) {
224                         fContext.fErrors->error(
225                                 stmt->fPosition,
226                                 "early returns from vertex programs are not supported");
227                     }
228 
229                     // Verify that the return statement matches the function's return type.
230                     ReturnStatement& returnStmt = stmt->as<ReturnStatement>();
231                     if (returnStmt.expression()) {
232                         if (this->functionReturnsValue()) {
233                             // Coerce return expression to the function's return type.
234                             returnStmt.setExpression(fFunction.returnType().coerceExpression(
235                                     std::move(returnStmt.expression()), fContext));
236                         } else {
237                             // Returning something from a function with a void return type.
238                             fContext.fErrors->error(returnStmt.expression()->fPosition,
239                                                     "may not return a value from a void function");
240                             returnStmt.setExpression(nullptr);
241                         }
242                     } else {
243                         if (this->functionReturnsValue()) {
244                             // Returning nothing from a function with a non-void return type.
245                             fContext.fErrors->error(returnStmt.fPosition,
246                                                     "expected function to return '" +
247                                                     fFunction.returnType().displayName() + "'");
248                         }
249                     }
250                     break;
251                 }
252                 case Statement::Kind::kDo:
253                 case Statement::Kind::kFor: {
254                     ++fBreakableLevel;
255                     ++fContinuableLevel.front();
256                     bool result = INHERITED::visitStatementPtr(stmt);
257                     --fContinuableLevel.front();
258                     --fBreakableLevel;
259                     return result;
260                 }
261                 case Statement::Kind::kSwitch: {
262                     ++fBreakableLevel;
263                     fContinuableLevel.push_front(0);
264                     bool result = INHERITED::visitStatementPtr(stmt);
265                     fContinuableLevel.pop_front();
266                     --fBreakableLevel;
267                     return result;
268                 }
269                 case Statement::Kind::kBreak:
270                     if (fBreakableLevel == 0) {
271                         fContext.fErrors->error(stmt->fPosition,
272                                                 "break statement must be inside a loop or switch");
273                     }
274                     break;
275 
276                 case Statement::Kind::kContinue:
277                     if (fContinuableLevel.front() == 0) {
278                         if (std::any_of(fContinuableLevel.begin(),
279                                         fContinuableLevel.end(),
280                                         [](int level) { return level > 0; })) {
281                             fContext.fErrors->error(stmt->fPosition,
282                                                    "continue statement cannot be used in a switch");
283                         } else {
284                             fContext.fErrors->error(stmt->fPosition,
285                                                     "continue statement must be inside a loop");
286                         }
287                     }
288                     break;
289 
290                 default:
291                     break;
292             }
293             return INHERITED::visitStatementPtr(stmt);
294         }
295 
296     private:
297         const Context& fContext;
298         const FunctionDeclaration& fFunction;
299         // how deeply nested we are in breakable constructs (for, do, switch).
300         int fBreakableLevel = 0;
301         // number of slots consumed by all variables declared in the function
302         size_t fSlotsUsed = 0;
303         // how deeply nested we are in continuable constructs (for, do).
304         // We keep a stack (via a forward_list) in order to disallow continue inside of switch.
305         std::forward_list<int> fContinuableLevel{0};
306         // We track uninitialized variable declarations, and if they are immediately assigned-to,
307         // we can move the assignment directly into the decl.
308         VarDeclaration* fUninitializedVarDecl = nullptr;
309 
310         using INHERITED = ProgramWriter;
311     };
312 
313     // We don't allow modules to define actual functions with intrinsic names. (Those should be
314     // reserved for actual intrinsics.)
315     if (function.isIntrinsic()) {
316         context.fErrors->error(function.fPosition, "Intrinsic function '" +
317                                                    std::string(function.name()) +
318                                                    "' should not have a definition");
319         return nullptr;
320     }
321 
322     // A function body must always be a braced block. (The parser should enforce this already, but
323     // we rely on it, so it's best to be certain.)
324     if (!body || !body->is<Block>() || !body->as<Block>().isScope()) {
325         context.fErrors->error(function.fPosition, "function body '" + function.description() +
326                                                    "' must be a braced block");
327         return nullptr;
328     }
329 
330     // A function can't have more than one definition.
331     if (function.definition()) {
332         context.fErrors->error(function.fPosition, "function '" + function.description() +
333                                                    "' was already defined");
334         return nullptr;
335     }
336 
337     // Run the function finalizer. This checks for illegal constructs and missing return statements,
338     // and also performs some simple code cleanup.
339     Finalizer(context, function, pos).visitStatementPtr(body);
340     if (function.isMain() && ProgramConfig::IsVertex(context.fConfig->fKind)) {
341         append_rtadjust_fixup_to_vertex_main(context, function, body->as<Block>());
342     }
343 
344     if (Analysis::CanExitWithoutReturningValue(function, *body)) {
345         context.fErrors->error(body->fPosition, "function '" + std::string(function.name()) +
346                                                 "' can exit without returning a value");
347     }
348 
349     return FunctionDefinition::Make(context, pos, function, std::move(body), builtin);
350 }
351 
Make(const Context &,Position pos,const FunctionDeclaration & function,std::unique_ptr<Statement> body,bool builtin)352 std::unique_ptr<FunctionDefinition> FunctionDefinition::Make(const Context&,
353                                                              Position pos,
354                                                              const FunctionDeclaration& function,
355                                                              std::unique_ptr<Statement> body,
356                                                              bool builtin) {
357     SkASSERT(!function.isIntrinsic());
358     SkASSERT(body && body->as<Block>().isScope());
359     SkASSERT(!function.definition());
360 
361     return std::make_unique<FunctionDefinition>(pos, &function, builtin, std::move(body));
362 }
363 
364 }  // namespace SkSL
365