• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/ir/SkSLFunctionDefinition.h"
9 
10 #include "include/core/SkTypes.h"
11 #include "include/private/SkSLDefines.h"
12 #include "include/private/SkSLSymbol.h"
13 #include "include/sksl/DSLCore.h"
14 #include "include/sksl/DSLExpression.h"
15 #include "include/sksl/DSLStatement.h"
16 #include "include/sksl/DSLType.h"
17 #include "include/sksl/SkSLErrorReporter.h"
18 #include "src/base/SkSafeMath.h"
19 #include "src/sksl/SkSLAnalysis.h"
20 #include "src/sksl/SkSLCompiler.h"
21 #include "src/sksl/SkSLContext.h"
22 #include "src/sksl/SkSLProgramSettings.h"
23 #include "src/sksl/SkSLThreadContext.h"
24 #include "src/sksl/ir/SkSLBlock.h"
25 #include "src/sksl/ir/SkSLExpression.h"
26 #include "src/sksl/ir/SkSLField.h"
27 #include "src/sksl/ir/SkSLFieldAccess.h"
28 #include "src/sksl/ir/SkSLReturnStatement.h"
29 #include "src/sksl/ir/SkSLSymbolTable.h"
30 #include "src/sksl/ir/SkSLType.h"
31 #include "src/sksl/ir/SkSLVarDeclarations.h"
32 #include "src/sksl/ir/SkSLVariable.h"
33 #include "src/sksl/ir/SkSLVariableReference.h"
34 #include "src/sksl/transform/SkSLProgramWriter.h"
35 
36 #include <algorithm>
37 #include <cstddef>
38 #include <forward_list>
39 #include <string_view>
40 #include <vector>
41 
42 namespace SkSL {
43 
append_rtadjust_fixup_to_vertex_main(const Context & context,const FunctionDeclaration & decl,Block & body)44 static void append_rtadjust_fixup_to_vertex_main(const Context& context,
45                                                  const FunctionDeclaration& decl,
46                                                  Block& body) {
47     using namespace SkSL::dsl;
48     using SkSL::dsl::Swizzle;  // disambiguate from SkSL::Swizzle
49     using OwnerKind = SkSL::FieldAccess::OwnerKind;
50 
51     // If this program uses RTAdjust...
52     ThreadContext::RTAdjustData& rtAdjust = ThreadContext::RTAdjustState();
53     if (rtAdjust.fVar || rtAdjust.fInterfaceBlock) {
54         // ...append a line to the end of the function body which fixes up sk_Position.
55         const SymbolTable* symbolTable = ThreadContext::SymbolTable().get();
56         const Field& skPositionField = symbolTable->find(Compiler::POSITION_NAME)->as<Field>();
57 
58         auto Ref = [](const Variable* var) -> std::unique_ptr<Expression> {
59             return VariableReference::Make(Position(), var);
60         };
61         auto Field = [&](const Variable* var, int idx) -> std::unique_ptr<Expression> {
62             return FieldAccess::Make(context, Position(), Ref(var), idx,
63                                      OwnerKind::kAnonymousInterfaceBlock);
64         };
65         auto Pos = [&]() -> DSLExpression {
66             return DSLExpression(Field(&skPositionField.owner(), skPositionField.fieldIndex()));
67         };
68         auto Adjust = [&]() -> DSLExpression {
69             return DSLExpression(rtAdjust.fInterfaceBlock
70                                          ? Field(rtAdjust.fInterfaceBlock, rtAdjust.fFieldIndex)
71                                          : Ref(rtAdjust.fVar));
72         };
73 
74         auto fixupStmt = DSLStatement(
75             Pos().assign(Float4(Swizzle(Pos(), X, Y) * Swizzle(Adjust(), X, Z) +
76                                 Swizzle(Pos(), W, W) * Swizzle(Adjust(), Y, W),
77                                 0,
78                                 Pos().w()))
79         );
80 
81         body.children().push_back(fixupStmt.release());
82     }
83 }
84 
Convert(const Context & context,Position pos,const FunctionDeclaration & function,std::unique_ptr<Statement> body,bool builtin)85 std::unique_ptr<FunctionDefinition> FunctionDefinition::Convert(const Context& context,
86                                                                 Position pos,
87                                                                 const FunctionDeclaration& function,
88                                                                 std::unique_ptr<Statement> body,
89                                                                 bool builtin) {
90     class Finalizer : public ProgramWriter {
91     public:
92         Finalizer(const Context& context, const FunctionDeclaration& function, Position pos)
93             : fContext(context)
94             , fFunction(function) {
95             // Function parameters count as local variables.
96             for (const Variable* var : function.parameters()) {
97                 this->addLocalVariable(var, pos);
98             }
99         }
100 
101         void addLocalVariable(const Variable* var, Position pos) {
102             // We count the number of slots used, but don't consider the precision of the base type.
103             // In practice, this reflects what GPUs actually do pretty well. (i.e., RelaxedPrecision
104             // math doesn't mean your variable takes less space.) We also don't attempt to reclaim
105             // slots at the end of a Block.
106             size_t prevSlotsUsed = fSlotsUsed;
107             fSlotsUsed = SkSafeMath::Add(fSlotsUsed, var->type().slotCount());
108             // To avoid overzealous error reporting, only trigger the error at the first
109             // place where the stack limit is exceeded.
110             if (prevSlotsUsed < kVariableSlotLimit && fSlotsUsed >= kVariableSlotLimit) {
111                 fContext.fErrors->error(pos, "variable '" + std::string(var->name()) +
112                                              "' exceeds the stack size limit");
113             }
114         }
115 
116         ~Finalizer() override {
117             SkASSERT(fBreakableLevel == 0);
118             SkASSERT(fContinuableLevel == std::forward_list<int>{0});
119         }
120 
121         bool functionReturnsValue() const {
122             return !fFunction.returnType().isVoid();
123         }
124 
125         bool visitExpression(Expression& expr) override {
126             // We don't need to scan expressions.
127             return false;
128         }
129 
130         bool visitStatement(Statement& stmt) override {
131             switch (stmt.kind()) {
132                 case Statement::Kind::kVarDeclaration: {
133                     const Variable* var = stmt.as<VarDeclaration>().var();
134                     if (var->type().isOrContainsUnsizedArray()) {
135                         fContext.fErrors->error(stmt.fPosition,
136                                                 "unsized arrays are not permitted here");
137                     } else {
138                         this->addLocalVariable(var, stmt.fPosition);
139                     }
140                     break;
141                 }
142                 case Statement::Kind::kReturn: {
143                     // Early returns from a vertex main() function will bypass sk_Position
144                     // normalization, so SkASSERT that we aren't doing that. If this becomes an
145                     // issue, we can add normalization before each return statement.
146                     if (ProgramConfig::IsVertex(fContext.fConfig->fKind) && fFunction.isMain()) {
147                         fContext.fErrors->error(
148                                 stmt.fPosition,
149                                 "early returns from vertex programs are not supported");
150                     }
151 
152                     // Verify that the return statement matches the function's return type.
153                     ReturnStatement& returnStmt = stmt.as<ReturnStatement>();
154                     if (returnStmt.expression()) {
155                         if (this->functionReturnsValue()) {
156                             // Coerce return expression to the function's return type.
157                             returnStmt.setExpression(fFunction.returnType().coerceExpression(
158                                     std::move(returnStmt.expression()), fContext));
159                         } else {
160                             // Returning something from a function with a void return type.
161                             fContext.fErrors->error(returnStmt.expression()->fPosition,
162                                                     "may not return a value from a void function");
163                             returnStmt.setExpression(nullptr);
164                         }
165                     } else {
166                         if (this->functionReturnsValue()) {
167                             // Returning nothing from a function with a non-void return type.
168                             fContext.fErrors->error(returnStmt.fPosition,
169                                                     "expected function to return '" +
170                                                     fFunction.returnType().displayName() + "'");
171                         }
172                     }
173                     break;
174                 }
175                 case Statement::Kind::kDo:
176                 case Statement::Kind::kFor: {
177                     ++fBreakableLevel;
178                     ++fContinuableLevel.front();
179                     bool result = INHERITED::visitStatement(stmt);
180                     --fContinuableLevel.front();
181                     --fBreakableLevel;
182                     return result;
183                 }
184                 case Statement::Kind::kSwitch: {
185                     ++fBreakableLevel;
186                     fContinuableLevel.push_front(0);
187                     bool result = INHERITED::visitStatement(stmt);
188                     fContinuableLevel.pop_front();
189                     --fBreakableLevel;
190                     return result;
191                 }
192                 case Statement::Kind::kBreak:
193                     if (fBreakableLevel == 0) {
194                         fContext.fErrors->error(stmt.fPosition,
195                                                 "break statement must be inside a loop or switch");
196                     }
197                     break;
198                 case Statement::Kind::kContinue:
199                     if (fContinuableLevel.front() == 0) {
200                         if (std::any_of(fContinuableLevel.begin(),
201                                         fContinuableLevel.end(),
202                                         [](int level) { return level > 0; })) {
203                             fContext.fErrors->error(stmt.fPosition,
204                                                    "continue statement cannot be used in a switch");
205                         } else {
206                             fContext.fErrors->error(stmt.fPosition,
207                                                     "continue statement must be inside a loop");
208                         }
209                     }
210                     break;
211                 default:
212                     break;
213             }
214             return INHERITED::visitStatement(stmt);
215         }
216 
217     private:
218         const Context& fContext;
219         const FunctionDeclaration& fFunction;
220         // how deeply nested we are in breakable constructs (for, do, switch).
221         int fBreakableLevel = 0;
222         // number of slots consumed by all variables declared in the function
223         size_t fSlotsUsed = 0;
224         // how deeply nested we are in continuable constructs (for, do).
225         // We keep a stack (via a forward_list) in order to disallow continue inside of switch.
226         std::forward_list<int> fContinuableLevel{0};
227 
228         using INHERITED = ProgramWriter;
229     };
230 
231     Finalizer(context, function, pos).visitStatement(*body);
232     if (function.isMain() && ProgramConfig::IsVertex(context.fConfig->fKind)) {
233         append_rtadjust_fixup_to_vertex_main(context, function, body->as<Block>());
234     }
235 
236     if (Analysis::CanExitWithoutReturningValue(function, *body)) {
237         context.fErrors->error(body->fPosition, "function '" + std::string(function.name()) +
238                                                 "' can exit without returning a value");
239     }
240 
241     SkASSERTF(!function.isIntrinsic(), "Intrinsic function '%.*s' should not have a definition",
242               (int)function.name().size(), function.name().data());
243     return std::make_unique<FunctionDefinition>(pos, &function, builtin, std::move(body));
244 }
245 
246 }  // namespace SkSL
247