• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "include/sksl/DSLCore.h"
9 #include "src/core/SkSafeMath.h"
10 #include "src/sksl/SkSLAnalysis.h"
11 #include "src/sksl/SkSLBuiltinMap.h"
12 #include "src/sksl/SkSLCompiler.h"
13 #include "src/sksl/SkSLContext.h"
14 #include "src/sksl/SkSLProgramSettings.h"
15 #include "src/sksl/SkSLThreadContext.h"
16 #include "src/sksl/ir/SkSLFieldAccess.h"
17 #include "src/sksl/ir/SkSLFunctionCall.h"
18 #include "src/sksl/ir/SkSLFunctionDefinition.h"
19 #include "src/sksl/ir/SkSLInterfaceBlock.h"
20 #include "src/sksl/ir/SkSLReturnStatement.h"
21 #include "src/sksl/transform/SkSLProgramWriter.h"
22 
23 #include <forward_list>
24 #include <string_view>
25 #include <vector>
26 
27 namespace SkSL {
28 
append_rtadjust_fixup_to_vertex_main(const Context & context,const FunctionDeclaration & decl,Block & body)29 static void append_rtadjust_fixup_to_vertex_main(const Context& context,
30         const FunctionDeclaration& decl, Block& body) {
31     using namespace SkSL::dsl;
32     using SkSL::dsl::Swizzle;  // disambiguate from SkSL::Swizzle
33     using OwnerKind = SkSL::FieldAccess::OwnerKind;
34 
35     // If this program uses RTAdjust...
36     ThreadContext::RTAdjustData& rtAdjust = ThreadContext::RTAdjustState();
37     if (rtAdjust.fVar || rtAdjust.fInterfaceBlock) {
38         // ...append a line to the end of the function body which fixes up sk_Position.
39         const Variable* skPerVertex = nullptr;
40         if (const ProgramElement* perVertexDecl =
41                 context.fBuiltins->find(Compiler::PERVERTEX_NAME)) {
42             SkASSERT(perVertexDecl->is<SkSL::InterfaceBlock>());
43             skPerVertex = &perVertexDecl->as<SkSL::InterfaceBlock>().variable();
44         }
45 
46         SkASSERT(skPerVertex);
47         auto Ref = [](const Variable* var) -> std::unique_ptr<Expression> {
48             return VariableReference::Make(/*line=*/-1, var);
49         };
50         auto Field = [&](const Variable* var, int idx) -> std::unique_ptr<Expression> {
51             return FieldAccess::Make(context, Ref(var), idx, OwnerKind::kAnonymousInterfaceBlock);
52         };
53         auto Pos = [&]() -> DSLExpression {
54             return DSLExpression(FieldAccess::Make(context, Ref(skPerVertex), /*fieldIndex=*/0,
55                                                    OwnerKind::kAnonymousInterfaceBlock));
56         };
57         auto Adjust = [&]() -> DSLExpression {
58             return DSLExpression(rtAdjust.fInterfaceBlock
59                                          ? Field(rtAdjust.fInterfaceBlock, rtAdjust.fFieldIndex)
60                                          : Ref(rtAdjust.fVar));
61         };
62 
63         auto fixupStmt = DSLStatement(
64             Pos() = Float4(Swizzle(Pos(), X, Y) * Swizzle(Adjust(), X, Z) +
65                            Swizzle(Pos(), W, W) * Swizzle(Adjust(), Y, W),
66                            0,
67                            Pos().w())
68         );
69 
70         body.children().push_back(fixupStmt.release());
71     }
72 }
73 
Convert(const Context & context,int line,const FunctionDeclaration & function,std::unique_ptr<Statement> body,bool builtin)74 std::unique_ptr<FunctionDefinition> FunctionDefinition::Convert(const Context& context,
75                                                                 int line,
76                                                                 const FunctionDeclaration& function,
77                                                                 std::unique_ptr<Statement> body,
78                                                                 bool builtin) {
79     class Finalizer : public ProgramWriter {
80     public:
81         Finalizer(const Context& context, const FunctionDeclaration& function,
82                   FunctionSet* referencedBuiltinFunctions)
83             : fContext(context)
84             , fFunction(function)
85             , fReferencedBuiltinFunctions(referencedBuiltinFunctions) {
86             // Function parameters count as local variables.
87             for (const Variable* var : function.parameters()) {
88                 this->addLocalVariable(var, function.fLine);
89             }
90         }
91 
92         void addLocalVariable(const Variable* var, int line) {
93             // We count the number of slots used, but don't consider the precision of the base type.
94             // In practice, this reflects what GPUs actually do pretty well. (i.e., RelaxedPrecision
95             // math doesn't mean your variable takes less space.) We also don't attempt to reclaim
96             // slots at the end of a Block.
97             size_t prevSlotsUsed = fSlotsUsed;
98             fSlotsUsed = SkSafeMath::Add(fSlotsUsed, var->type().slotCount());
99             // To avoid overzealous error reporting, only trigger the error at the first
100             // place where the stack limit is exceeded.
101             if (prevSlotsUsed < kVariableSlotLimit && fSlotsUsed >= kVariableSlotLimit) {
102                 fContext.fErrors->error(line, "variable '" + std::string(var->name()) +
103                                               "' exceeds the stack size limit");
104             }
105         }
106 
107         ~Finalizer() override {
108             SkASSERT(fBreakableLevel == 0);
109             SkASSERT(fContinuableLevel == std::forward_list<int>{0});
110         }
111 
112         void copyBuiltinFunctionIfNeeded(const FunctionDeclaration& function) {
113             if (const ProgramElement* found =
114                         fContext.fBuiltins->findAndInclude(function.description())) {
115                 const FunctionDefinition& original = found->as<FunctionDefinition>();
116 
117                 // Sort the referenced builtin functions into a consistent order; otherwise our
118                 // output will become non-deterministic.
119                 std::vector<const FunctionDeclaration*> builtinFunctions(
120                         original.referencedBuiltinFunctions().begin(),
121                         original.referencedBuiltinFunctions().end());
122                 std::sort(builtinFunctions.begin(), builtinFunctions.end(),
123                           [](const FunctionDeclaration* a, const FunctionDeclaration* b) {
124                               if (a->isBuiltin() != b->isBuiltin()) {
125                                   return a->isBuiltin() < b->isBuiltin();
126                               }
127                               if (a->fLine != b->fLine) {
128                                   return a->fLine < b->fLine;
129                               }
130                               if (a->name() != b->name()) {
131                                   return a->name() < b->name();
132                               }
133                               return a->description() < b->description();
134                           });
135                 for (const FunctionDeclaration* f : builtinFunctions) {
136                     this->copyBuiltinFunctionIfNeeded(*f);
137                 }
138 
139                 ThreadContext::SharedElements().push_back(found);
140             }
141         }
142 
143         bool functionReturnsValue() const {
144             return !fFunction.returnType().isVoid();
145         }
146 
147         bool visitExpression(Expression& expr) override {
148             if (expr.is<FunctionCall>()) {
149                 const FunctionDeclaration& func = expr.as<FunctionCall>().function();
150                 if (func.isBuiltin()) {
151                     if (func.intrinsicKind() == k_dFdy_IntrinsicKind) {
152                         ThreadContext::Inputs().fUseFlipRTUniform = true;
153                     }
154                     if (func.definition()) {
155                         fReferencedBuiltinFunctions->insert(&func);
156                     }
157                     if (!fContext.fConfig->fIsBuiltinCode && fContext.fBuiltins) {
158                         this->copyBuiltinFunctionIfNeeded(func);
159                     }
160                 }
161 
162             }
163             return INHERITED::visitExpression(expr);
164         }
165 
166         bool visitStatement(Statement& stmt) override {
167             switch (stmt.kind()) {
168                 case Statement::Kind::kVarDeclaration: {
169                     const Variable* var = &stmt.as<VarDeclaration>().var();
170                     this->addLocalVariable(var, stmt.fLine);
171                     break;
172                 }
173                 case Statement::Kind::kReturn: {
174                     // Early returns from a vertex main() function will bypass sk_Position
175                     // normalization, so SkASSERT that we aren't doing that. If this becomes an
176                     // issue, we can add normalization before each return statement.
177                     if (fContext.fConfig->fKind == ProgramKind::kVertex && fFunction.isMain()) {
178                         fContext.fErrors->error(
179                                 stmt.fLine,
180                                 "early returns from vertex programs are not supported");
181                     }
182 
183                     // Verify that the return statement matches the function's return type.
184                     ReturnStatement& returnStmt = stmt.as<ReturnStatement>();
185                     if (returnStmt.expression()) {
186                         if (this->functionReturnsValue()) {
187                             // Coerce return expression to the function's return type.
188                             returnStmt.setExpression(fFunction.returnType().coerceExpression(
189                                     std::move(returnStmt.expression()), fContext));
190                         } else {
191                             // Returning something from a function with a void return type.
192                             returnStmt.setExpression(nullptr);
193                             fContext.fErrors->error(returnStmt.fLine,
194                                                     "may not return a value from a void function");
195                         }
196                     } else {
197                         if (this->functionReturnsValue()) {
198                             // Returning nothing from a function with a non-void return type.
199                             fContext.fErrors->error(returnStmt.fLine,
200                                                     "expected function to return '" +
201                                                     fFunction.returnType().displayName() + "'");
202                         }
203                     }
204                     break;
205                 }
206                 case Statement::Kind::kDo:
207                 case Statement::Kind::kFor: {
208                     ++fBreakableLevel;
209                     ++fContinuableLevel.front();
210                     bool result = INHERITED::visitStatement(stmt);
211                     --fContinuableLevel.front();
212                     --fBreakableLevel;
213                     return result;
214                 }
215                 case Statement::Kind::kSwitch: {
216                     ++fBreakableLevel;
217                     fContinuableLevel.push_front(0);
218                     bool result = INHERITED::visitStatement(stmt);
219                     fContinuableLevel.pop_front();
220                     --fBreakableLevel;
221                     return result;
222                 }
223                 case Statement::Kind::kBreak:
224                     if (fBreakableLevel == 0) {
225                         fContext.fErrors->error(stmt.fLine,
226                                                 "break statement must be inside a loop or switch");
227                     }
228                     break;
229                 case Statement::Kind::kContinue:
230                     if (fContinuableLevel.front() == 0) {
231                         if (std::any_of(fContinuableLevel.begin(),
232                                         fContinuableLevel.end(),
233                                         [](int level) { return level > 0; })) {
234                             fContext.fErrors->error(stmt.fLine,
235                                                    "continue statement cannot be used in a switch");
236                         } else {
237                             fContext.fErrors->error(stmt.fLine,
238                                                     "continue statement must be inside a loop");
239                         }
240                     }
241                     break;
242                 default:
243                     break;
244             }
245             return INHERITED::visitStatement(stmt);
246         }
247 
248     private:
249         const Context& fContext;
250         const FunctionDeclaration& fFunction;
251         // which builtin functions have we encountered in this function
252         FunctionSet* fReferencedBuiltinFunctions;
253         // how deeply nested we are in breakable constructs (for, do, switch).
254         int fBreakableLevel = 0;
255         // number of slots consumed by all variables declared in the function
256         size_t fSlotsUsed = 0;
257         // how deeply nested we are in continuable constructs (for, do).
258         // We keep a stack (via a forward_list) in order to disallow continue inside of switch.
259         std::forward_list<int> fContinuableLevel{0};
260 
261         using INHERITED = ProgramWriter;
262     };
263 
264     FunctionSet referencedBuiltinFunctions;
265     Finalizer(context, function, &referencedBuiltinFunctions).visitStatement(*body);
266     if (function.isMain() && context.fConfig->fKind == ProgramKind::kVertex) {
267         append_rtadjust_fixup_to_vertex_main(context, function, body->as<Block>());
268     }
269 
270     if (Analysis::CanExitWithoutReturningValue(function, *body)) {
271         context.fErrors->error(function.fLine, "function '" + std::string(function.name()) +
272                                                "' can exit without returning a value");
273     }
274 
275     SkASSERTF(!function.isIntrinsic(),
276               "Intrinsic %s should not have a definition",
277               std::string(function.name()).c_str());
278     return std::make_unique<FunctionDefinition>(line, &function, builtin, std::move(body),
279                                                 std::move(referencedBuiltinFunctions));
280 }
281 
282 }  // namespace SkSL
283