1 /*
2 * Copyright 2021 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "include/core/SkTypes.h"
9 #include "include/private/SkSLIRNode.h"
10 #include "include/private/SkSLStatement.h"
11 #include "include/private/base/SkFloatingPoint.h"
12 #include "include/sksl/SkSLErrorReporter.h"
13 #include "include/sksl/SkSLOperator.h"
14 #include "include/sksl/SkSLPosition.h"
15 #include "src/sksl/SkSLAnalysis.h"
16 #include "src/sksl/SkSLConstantFolder.h"
17 #include "src/sksl/analysis/SkSLNoOpErrorReporter.h"
18 #include "src/sksl/ir/SkSLBinaryExpression.h"
19 #include "src/sksl/ir/SkSLExpression.h"
20 #include "src/sksl/ir/SkSLForStatement.h"
21 #include "src/sksl/ir/SkSLPostfixExpression.h"
22 #include "src/sksl/ir/SkSLPrefixExpression.h"
23 #include "src/sksl/ir/SkSLType.h"
24 #include "src/sksl/ir/SkSLVarDeclarations.h"
25 #include "src/sksl/ir/SkSLVariableReference.h"
26
27 #include <cmath>
28 #include <memory>
29
30 namespace SkSL {
31
// Loops that run for 100000+ iterations will exceed our program size limit.
static constexpr int kLoopTerminationLimit = 100000;

/**
 * Computes the number of iterations executed by a loop of the form
 *
 *     for (i = start; i OP end; i += delta)
 *
 * where OP is `<` / `<=` (forwards) or `>` / `>=` (backwards).
 *
 * @param start      initial value of the loop index
 * @param end        constant that the loop index is compared against
 * @param delta      signed amount added to the loop index after every iteration
 * @param forwards   true for `<` and `<=`; false for `>` and `>=`
 * @param inclusive  true for `<=` and `>=`; false for `<` and `>`
 * @return the iteration count, or kLoopTerminationLimit when the loop never terminates or would
 *         run for more than kLoopTerminationLimit iterations.
 */
static int calculate_count(double start, double end, double delta, bool forwards, bool inclusive) {
    // Evaluate the loop condition against the starting value. Equality only satisfies the
    // inclusive operators, so `for (i = 5; i <= 5; ...)` runs once while `for (i = 5; i < 5; ...)`
    // does not run at all. A NaN bound makes every comparison false, which also yields zero.
    const bool conditionHoldsAtStart = forwards ? (inclusive ? (start <= end) : (start < end))
                                                : (inclusive ? (start >= end) : (start > end));
    if (!conditionHoldsAtStart) {
        // The loop starts in a completed state (the start has already advanced past the end).
        return 0;
    }
    if ((delta == 0.0) || forwards != (delta > 0.0)) {
        // The loop does not progress toward a completed state, and will never terminate.
        return kLoopTerminationLimit;
    }
    // `delta` is known to be non-zero at this point, so ordinary division is well-defined.
    const double iterations = (end - start) / delta;
    double count = std::ceil(iterations);
    if (inclusive && (count == iterations)) {
        // The index lands exactly on `end`, and an inclusive test lets that iteration run too.
        count += 1.0;
    }
    if (!std::isfinite(count) || count > kLoopTerminationLimit) {
        // The loop runs for more iterations than we can safely unroll.
        return kLoopTerminationLimit;
    }
    // `count` is finite and within [0, kLoopTerminationLimit], so the conversion is exact.
    return static_cast<int>(count);
}
55
GetLoopUnrollInfo(Position loopPos,const ForLoopPositions & positions,const Statement * loopInitializer,const Expression * loopTest,const Expression * loopNext,const Statement * loopStatement,ErrorReporter * errorPtr)56 std::unique_ptr<LoopUnrollInfo> Analysis::GetLoopUnrollInfo(Position loopPos,
57 const ForLoopPositions& positions,
58 const Statement* loopInitializer,
59 const Expression* loopTest,
60 const Expression* loopNext,
61 const Statement* loopStatement,
62 ErrorReporter* errorPtr) {
63 NoOpErrorReporter unused;
64 ErrorReporter& errors = errorPtr ? *errorPtr : unused;
65 auto loopInfo = std::make_unique<LoopUnrollInfo>();
66
67 //
68 // init_declaration has the form: type_specifier identifier = constant_expression
69 //
70 if (!loopInitializer) {
71 Position pos = positions.initPosition.valid() ? positions.initPosition : loopPos;
72 errors.error(pos, "missing init declaration");
73 return nullptr;
74 }
75 if (!loopInitializer->is<VarDeclaration>()) {
76 errors.error(loopInitializer->fPosition, "invalid init declaration");
77 return nullptr;
78 }
79 const VarDeclaration& initDecl = loopInitializer->as<VarDeclaration>();
80 if (!initDecl.baseType().isNumber()) {
81 errors.error(loopInitializer->fPosition, "invalid type for loop index");
82 return nullptr;
83 }
84 if (initDecl.arraySize() != 0) {
85 errors.error(loopInitializer->fPosition, "invalid type for loop index");
86 return nullptr;
87 }
88 if (!initDecl.value()) {
89 errors.error(loopInitializer->fPosition, "missing loop index initializer");
90 return nullptr;
91 }
92 if (!ConstantFolder::GetConstantValue(*initDecl.value(), &loopInfo->fStart)) {
93 errors.error(loopInitializer->fPosition,
94 "loop index initializer must be a constant expression");
95 return nullptr;
96 }
97
98 loopInfo->fIndex = initDecl.var();
99
100 auto is_loop_index = [&](const std::unique_ptr<Expression>& expr) {
101 return expr->is<VariableReference>() &&
102 expr->as<VariableReference>().variable() == loopInfo->fIndex;
103 };
104
105 //
106 // condition has the form: loop_index relational_operator constant_expression
107 //
108 if (!loopTest) {
109 Position pos = positions.conditionPosition.valid() ? positions.conditionPosition : loopPos;
110 errors.error(pos, "missing condition");
111 return nullptr;
112 }
113 if (!loopTest->is<BinaryExpression>()) {
114 errors.error(loopTest->fPosition, "invalid condition");
115 return nullptr;
116 }
117 const BinaryExpression& cond = loopTest->as<BinaryExpression>();
118 if (!is_loop_index(cond.left())) {
119 errors.error(loopTest->fPosition, "expected loop index on left hand side of condition");
120 return nullptr;
121 }
122 // relational_operator is one of: > >= < <= == or !=
123 switch (cond.getOperator().kind()) {
124 case Operator::Kind::GT:
125 case Operator::Kind::GTEQ:
126 case Operator::Kind::LT:
127 case Operator::Kind::LTEQ:
128 case Operator::Kind::EQEQ:
129 case Operator::Kind::NEQ:
130 break;
131 default:
132 errors.error(loopTest->fPosition, "invalid relational operator");
133 return nullptr;
134 }
135 double loopEnd = 0;
136 if (!ConstantFolder::GetConstantValue(*cond.right(), &loopEnd)) {
137 errors.error(loopTest->fPosition, "loop index must be compared with a constant expression");
138 return nullptr;
139 }
140
141 //
142 // expression has one of the following forms:
143 // loop_index++
144 // loop_index--
145 // loop_index += constant_expression
146 // loop_index -= constant_expression
147 // The spec doesn't mention prefix increment and decrement, but there is some consensus that
148 // it's an oversight, so we allow those as well.
149 //
150 if (!loopNext) {
151 Position pos = positions.nextPosition.valid() ? positions.nextPosition : loopPos;
152 errors.error(pos, "missing loop expression");
153 return nullptr;
154 }
155 switch (loopNext->kind()) {
156 case Expression::Kind::kBinary: {
157 const BinaryExpression& next = loopNext->as<BinaryExpression>();
158 if (!is_loop_index(next.left())) {
159 errors.error(loopNext->fPosition, "expected loop index in loop expression");
160 return nullptr;
161 }
162 if (!ConstantFolder::GetConstantValue(*next.right(), &loopInfo->fDelta)) {
163 errors.error(loopNext->fPosition,
164 "loop index must be modified by a constant expression");
165 return nullptr;
166 }
167 switch (next.getOperator().kind()) {
168 case Operator::Kind::PLUSEQ: break;
169 case Operator::Kind::MINUSEQ: loopInfo->fDelta = -loopInfo->fDelta; break;
170 default:
171 errors.error(loopNext->fPosition, "invalid operator in loop expression");
172 return nullptr;
173 }
174 } break;
175 case Expression::Kind::kPrefix: {
176 const PrefixExpression& next = loopNext->as<PrefixExpression>();
177 if (!is_loop_index(next.operand())) {
178 errors.error(loopNext->fPosition, "expected loop index in loop expression");
179 return nullptr;
180 }
181 switch (next.getOperator().kind()) {
182 case Operator::Kind::PLUSPLUS: loopInfo->fDelta = 1; break;
183 case Operator::Kind::MINUSMINUS: loopInfo->fDelta = -1; break;
184 default:
185 errors.error(loopNext->fPosition, "invalid operator in loop expression");
186 return nullptr;
187 }
188 } break;
189 case Expression::Kind::kPostfix: {
190 const PostfixExpression& next = loopNext->as<PostfixExpression>();
191 if (!is_loop_index(next.operand())) {
192 errors.error(loopNext->fPosition, "expected loop index in loop expression");
193 return nullptr;
194 }
195 switch (next.getOperator().kind()) {
196 case Operator::Kind::PLUSPLUS: loopInfo->fDelta = 1; break;
197 case Operator::Kind::MINUSMINUS: loopInfo->fDelta = -1; break;
198 default:
199 errors.error(loopNext->fPosition, "invalid operator in loop expression");
200 return nullptr;
201 }
202 } break;
203 default:
204 errors.error(loopNext->fPosition, "invalid loop expression");
205 return nullptr;
206 }
207
208 //
209 // Within the body of the loop, the loop index is not statically assigned to, nor is it used as
210 // argument to a function 'out' or 'inout' parameter.
211 //
212 if (Analysis::StatementWritesToVariable(*loopStatement, *initDecl.var())) {
213 errors.error(loopStatement->fPosition,
214 "loop index must not be modified within body of the loop");
215 return nullptr;
216 }
217
218 // Finally, compute the iteration count, based on the bounds, and the termination operator.
219 loopInfo->fCount = 0;
220
221 switch (cond.getOperator().kind()) {
222 case Operator::Kind::LT:
223 loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta,
224 /*forwards=*/true, /*inclusive=*/false);
225 break;
226
227 case Operator::Kind::GT:
228 loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta,
229 /*forwards=*/false, /*inclusive=*/false);
230 break;
231
232 case Operator::Kind::LTEQ:
233 loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta,
234 /*forwards=*/true, /*inclusive=*/true);
235 break;
236
237 case Operator::Kind::GTEQ:
238 loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta,
239 /*forwards=*/false, /*inclusive=*/true);
240 break;
241
242 case Operator::Kind::NEQ: {
243 float iterations = sk_ieee_double_divide(loopEnd - loopInfo->fStart, loopInfo->fDelta);
244 loopInfo->fCount = std::ceil(iterations);
245 if (loopInfo->fCount < 0 || loopInfo->fCount != iterations ||
246 !std::isfinite(iterations)) {
247 // The loop doesn't reach the exact endpoint and so will never terminate.
248 loopInfo->fCount = kLoopTerminationLimit;
249 }
250 break;
251 }
252 case Operator::Kind::EQEQ: {
253 if (loopInfo->fStart == loopEnd) {
254 // Start and end begin in the same place, so we can run one iteration...
255 if (loopInfo->fDelta) {
256 // ... and then they diverge, so the loop terminates.
257 loopInfo->fCount = 1;
258 } else {
259 // ... but they never diverge, so the loop runs forever.
260 loopInfo->fCount = kLoopTerminationLimit;
261 }
262 } else {
263 // Start never equals end, so the loop will not run a single iteration.
264 loopInfo->fCount = 0;
265 }
266 break;
267 }
268 default: SkUNREACHABLE;
269 }
270
271 SkASSERT(loopInfo->fCount >= 0);
272 if (loopInfo->fCount >= kLoopTerminationLimit) {
273 errors.error(loopPos, "loop must guarantee termination in fewer iterations");
274 return nullptr;
275 }
276
277 return loopInfo;
278 }
279
280 } // namespace SkSL
281