1 /*
2 * Copyright 2021 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "include/sksl/DSLCore.h"
9 #include "src/core/SkSafeMath.h"
10 #include "src/sksl/SkSLAnalysis.h"
11 #include "src/sksl/SkSLBuiltinMap.h"
12 #include "src/sksl/SkSLCompiler.h"
13 #include "src/sksl/SkSLContext.h"
14 #include "src/sksl/SkSLProgramSettings.h"
15 #include "src/sksl/SkSLThreadContext.h"
16 #include "src/sksl/ir/SkSLFieldAccess.h"
17 #include "src/sksl/ir/SkSLFunctionCall.h"
18 #include "src/sksl/ir/SkSLFunctionDefinition.h"
19 #include "src/sksl/ir/SkSLInterfaceBlock.h"
20 #include "src/sksl/ir/SkSLReturnStatement.h"
21 #include "src/sksl/transform/SkSLProgramWriter.h"
22
23 #include <forward_list>
24 #include <string_view>
25 #include <vector>
26
27 namespace SkSL {
28
append_rtadjust_fixup_to_vertex_main(const Context & context,const FunctionDeclaration & decl,Block & body)29 static void append_rtadjust_fixup_to_vertex_main(const Context& context,
30 const FunctionDeclaration& decl, Block& body) {
31 using namespace SkSL::dsl;
32 using SkSL::dsl::Swizzle; // disambiguate from SkSL::Swizzle
33 using OwnerKind = SkSL::FieldAccess::OwnerKind;
34
35 // If this program uses RTAdjust...
36 ThreadContext::RTAdjustData& rtAdjust = ThreadContext::RTAdjustState();
37 if (rtAdjust.fVar || rtAdjust.fInterfaceBlock) {
38 // ...append a line to the end of the function body which fixes up sk_Position.
39 const Variable* skPerVertex = nullptr;
40 if (const ProgramElement* perVertexDecl =
41 context.fBuiltins->find(Compiler::PERVERTEX_NAME)) {
42 SkASSERT(perVertexDecl->is<SkSL::InterfaceBlock>());
43 skPerVertex = &perVertexDecl->as<SkSL::InterfaceBlock>().variable();
44 }
45
46 SkASSERT(skPerVertex);
47 auto Ref = [](const Variable* var) -> std::unique_ptr<Expression> {
48 return VariableReference::Make(/*line=*/-1, var);
49 };
50 auto Field = [&](const Variable* var, int idx) -> std::unique_ptr<Expression> {
51 return FieldAccess::Make(context, Ref(var), idx, OwnerKind::kAnonymousInterfaceBlock);
52 };
53 auto Pos = [&]() -> DSLExpression {
54 return DSLExpression(FieldAccess::Make(context, Ref(skPerVertex), /*fieldIndex=*/0,
55 OwnerKind::kAnonymousInterfaceBlock));
56 };
57 auto Adjust = [&]() -> DSLExpression {
58 return DSLExpression(rtAdjust.fInterfaceBlock
59 ? Field(rtAdjust.fInterfaceBlock, rtAdjust.fFieldIndex)
60 : Ref(rtAdjust.fVar));
61 };
62
63 auto fixupStmt = DSLStatement(
64 Pos() = Float4(Swizzle(Pos(), X, Y) * Swizzle(Adjust(), X, Z) +
65 Swizzle(Pos(), W, W) * Swizzle(Adjust(), Y, W),
66 0,
67 Pos().w())
68 );
69
70 body.children().push_back(fixupStmt.release());
71 }
72 }
73
// Validates and finalizes a function body, then wraps it in a FunctionDefinition.
//
// Steps, in order:
//   1. A local ProgramWriter subclass (Finalizer) walks `body`. It counts the slots used by
//      parameters and local variables, records/copies referenced builtin functions, checks and
//      coerces return statements, and validates the placement of `break` and `continue`.
//   2. If `function` is main() and the program kind is kVertex, the RTAdjust fixup is appended to
//      `body` (which is cast to a Block).
//   3. If Analysis::CanExitWithoutReturningValue reports true, an error is emitted at the
//      function's line.
//
// All problems are reported through `context.fErrors`; a FunctionDefinition is constructed and
// returned whether or not errors were reported.
//
// `context`  - compiler context (error reporter, config, builtin map).
// `line`     - line number forwarded to the FunctionDefinition constructor.
// `function` - declaration of the function being defined.
// `body`     - the function body; ownership is transferred into the returned definition.
// `builtin`  - forwarded unchanged to the FunctionDefinition constructor.
std::unique_ptr<FunctionDefinition> FunctionDefinition::Convert(const Context& context,
                                                                int line,
                                                                const FunctionDeclaration& function,
                                                                std::unique_ptr<Statement> body,
                                                                bool builtin) {
    // Visitor which performs the per-statement and per-expression checks described above.
    class Finalizer : public ProgramWriter {
    public:
        // `referencedBuiltinFunctions` is an out-parameter: builtin functions (with definitions)
        // called from the visited body are inserted into it.
        Finalizer(const Context& context, const FunctionDeclaration& function,
                  FunctionSet* referencedBuiltinFunctions)
            : fContext(context)
            , fFunction(function)
            , fReferencedBuiltinFunctions(referencedBuiltinFunctions) {
            // Function parameters count as local variables.
            for (const Variable* var : function.parameters()) {
                this->addLocalVariable(var, function.fLine);
            }
        }

        // Adds the slot count of `var` to the running total and reports an error at `line` the
        // first time the total reaches kVariableSlotLimit.
        void addLocalVariable(const Variable* var, int line) {
            // We count the number of slots used, but don't consider the precision of the base type.
            // In practice, this reflects what GPUs actually do pretty well. (i.e., RelaxedPrecision
            // math doesn't mean your variable takes less space.) We also don't attempt to reclaim
            // slots at the end of a Block.
            size_t prevSlotsUsed = fSlotsUsed;
            // NOTE(review): SkSafeMath::Add presumably guards against size_t wraparound here;
            // confirm its overflow behavior in SkSafeMath.h.
            fSlotsUsed = SkSafeMath::Add(fSlotsUsed, var->type().slotCount());
            // To avoid overzealous error reporting, only trigger the error at the first
            // place where the stack limit is exceeded.
            if (prevSlotsUsed < kVariableSlotLimit && fSlotsUsed >= kVariableSlotLimit) {
                fContext.fErrors->error(line, "variable '" + std::string(var->name()) +
                                              "' exceeds the stack size limit");
            }
        }

        // Sanity check: every loop/switch that was entered during the walk must have been exited,
        // leaving the nesting trackers in their initial state.
        ~Finalizer() override {
            SkASSERT(fBreakableLevel == 0);
            SkASSERT(fContinuableLevel == std::forward_list<int>{0});
        }

        // If the builtin map hands back a definition for `function`, first recurses into the
        // builtin functions which that definition references, and then appends the definition to
        // the thread's shared elements. Dependencies are therefore appended before dependents.
        // NOTE(review): recursion termination relies on findAndInclude() returning null for a
        // function that has already been included; confirm in SkSLBuiltinMap.
        void copyBuiltinFunctionIfNeeded(const FunctionDeclaration& function) {
            if (const ProgramElement* found =
                        fContext.fBuiltins->findAndInclude(function.description())) {
                const FunctionDefinition& original = found->as<FunctionDefinition>();

                // Sort the referenced builtin functions into a consistent order; otherwise our
                // output will become non-deterministic.
                std::vector<const FunctionDeclaration*> builtinFunctions(
                        original.referencedBuiltinFunctions().begin(),
                        original.referencedBuiltinFunctions().end());
                // Ordering keys, most significant first: isBuiltin, line, name, description.
                std::sort(builtinFunctions.begin(), builtinFunctions.end(),
                          [](const FunctionDeclaration* a, const FunctionDeclaration* b) {
                              if (a->isBuiltin() != b->isBuiltin()) {
                                  return a->isBuiltin() < b->isBuiltin();
                              }
                              if (a->fLine != b->fLine) {
                                  return a->fLine < b->fLine;
                              }
                              if (a->name() != b->name()) {
                                  return a->name() < b->name();
                              }
                              return a->description() < b->description();
                          });
                for (const FunctionDeclaration* f : builtinFunctions) {
                    this->copyBuiltinFunctionIfNeeded(*f);
                }

                ThreadContext::SharedElements().push_back(found);
            }
        }

        // True when the function being finalized has a non-void return type.
        bool functionReturnsValue() const {
            return !fFunction.returnType().isVoid();
        }

        // Handles calls to builtin functions, then defers to the base class for the expression.
        bool visitExpression(Expression& expr) override {
            if (expr.is<FunctionCall>()) {
                const FunctionDeclaration& func = expr.as<FunctionCall>().function();
                if (func.isBuiltin()) {
                    // A call to the dFdy intrinsic turns on the flip-RT uniform input.
                    if (func.intrinsicKind() == k_dFdy_IntrinsicKind) {
                        ThreadContext::Inputs().fUseFlipRTUniform = true;
                    }
                    // Only builtins which have a definition are recorded as referenced.
                    if (func.definition()) {
                        fReferencedBuiltinFunctions->insert(&func);
                    }
                    // When compiling non-builtin code with a builtin map available, pull the
                    // builtin's definition (and its dependencies) into the program.
                    if (!fContext.fConfig->fIsBuiltinCode && fContext.fBuiltins) {
                        this->copyBuiltinFunctionIfNeeded(func);
                    }
                }

            }
            return INHERITED::visitExpression(expr);
        }

        // Performs statement-level checks. The loop and switch cases recurse (and return) from
        // inside the switch so that the nesting trackers can be restored afterwards; every other
        // case falls through to the base-class visit at the bottom.
        bool visitStatement(Statement& stmt) override {
            switch (stmt.kind()) {
                case Statement::Kind::kVarDeclaration: {
                    const Variable* var = &stmt.as<VarDeclaration>().var();
                    this->addLocalVariable(var, stmt.fLine);
                    break;
                }
                case Statement::Kind::kReturn: {
                    // Early returns from a vertex main() function will bypass sk_Position
                    // normalization, so report an error if we encounter one. (Every return
                    // statement visited in a vertex main() is flagged by this check.) If this
                    // becomes an issue, we can add normalization before each return statement.
                    if (fContext.fConfig->fKind == ProgramKind::kVertex && fFunction.isMain()) {
                        fContext.fErrors->error(
                                stmt.fLine,
                                "early returns from vertex programs are not supported");
                    }

                    // Verify that the return statement matches the function's return type.
                    ReturnStatement& returnStmt = stmt.as<ReturnStatement>();
                    if (returnStmt.expression()) {
                        if (this->functionReturnsValue()) {
                            // Coerce return expression to the function's return type.
                            returnStmt.setExpression(fFunction.returnType().coerceExpression(
                                    std::move(returnStmt.expression()), fContext));
                        } else {
                            // Returning something from a function with a void return type.
                            // The offending expression is discarded before the error is reported.
                            returnStmt.setExpression(nullptr);
                            fContext.fErrors->error(returnStmt.fLine,
                                                    "may not return a value from a void function");
                        }
                    } else {
                        if (this->functionReturnsValue()) {
                            // Returning nothing from a function with a non-void return type.
                            fContext.fErrors->error(returnStmt.fLine,
                                                    "expected function to return '" +
                                                    fFunction.returnType().displayName() + "'");
                        }
                    }
                    break;
                }
                case Statement::Kind::kDo:
                case Statement::Kind::kFor: {
                    // Loops permit both `break` and `continue` while their children are visited.
                    ++fBreakableLevel;
                    ++fContinuableLevel.front();
                    bool result = INHERITED::visitStatement(stmt);
                    --fContinuableLevel.front();
                    --fBreakableLevel;
                    return result;
                }
                case Statement::Kind::kSwitch: {
                    // A switch permits `break`, but pushes a fresh zero onto the continuable stack
                    // so that a `continue` directly inside it is rejected even if the switch is
                    // itself nested within a loop.
                    ++fBreakableLevel;
                    fContinuableLevel.push_front(0);
                    bool result = INHERITED::visitStatement(stmt);
                    fContinuableLevel.pop_front();
                    --fBreakableLevel;
                    return result;
                }
                case Statement::Kind::kBreak:
                    if (fBreakableLevel == 0) {
                        fContext.fErrors->error(stmt.fLine,
                                                "break statement must be inside a loop or switch");
                    }
                    break;
                case Statement::Kind::kContinue:
                    // A zero at the front means no loop has been entered since the innermost
                    // enclosing switch (or since the start of the function).
                    if (fContinuableLevel.front() == 0) {
                        // A nonzero entry further down the stack means a loop encloses the switch
                        // we are in; this selects the more specific error message.
                        if (std::any_of(fContinuableLevel.begin(),
                                        fContinuableLevel.end(),
                                        [](int level) { return level > 0; })) {
                            fContext.fErrors->error(stmt.fLine,
                                                   "continue statement cannot be used in a switch");
                        } else {
                            fContext.fErrors->error(stmt.fLine,
                                                   "continue statement must be inside a loop");
                        }
                    }
                    break;
                default:
                    break;
            }
            return INHERITED::visitStatement(stmt);
        }

    private:
        const Context& fContext;
        const FunctionDeclaration& fFunction;
        // which builtin functions have we encountered in this function
        FunctionSet* fReferencedBuiltinFunctions;
        // how deeply nested we are in breakable constructs (for, do, switch).
        int fBreakableLevel = 0;
        // number of slots consumed by all variables declared in the function
        size_t fSlotsUsed = 0;
        // how deeply nested we are in continuable constructs (for, do).
        // We keep a stack (via a forward_list) in order to disallow continue inside of switch.
        std::forward_list<int> fContinuableLevel{0};

        using INHERITED = ProgramWriter;
    };

    // Run the finalizer over the body as a temporary; its destructor (with the nesting sanity
    // checks) runs at the end of this statement.
    FunctionSet referencedBuiltinFunctions;
    Finalizer(context, function, &referencedBuiltinFunctions).visitStatement(*body);
    // The fixup is appended after the finalizer has run, so the appended statement is not itself
    // visited by the checks above.
    if (function.isMain() && context.fConfig->fKind == ProgramKind::kVertex) {
        append_rtadjust_fixup_to_vertex_main(context, function, body->as<Block>());
    }

    if (Analysis::CanExitWithoutReturningValue(function, *body)) {
        context.fErrors->error(function.fLine, "function '" + std::string(function.name()) +
                                               "' can exit without returning a value");
    }

    SkASSERTF(!function.isIntrinsic(),
              "Intrinsic %s should not have a definition",
              std::string(function.name()).c_str());
    return std::make_unique<FunctionDefinition>(line, &function, builtin, std::move(body),
                                                std::move(referencedBuiltinFunctions));
}
281
282 } // namespace SkSL
283