1 /*
2 * Copyright 2020 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/SkSLConstantFolder.h"
9
10 #include <limits>
11
12 #include "src/sksl/SkSLContext.h"
13 #include "src/sksl/SkSLErrorReporter.h"
14 #include "src/sksl/ir/SkSLBinaryExpression.h"
15 #include "src/sksl/ir/SkSLBoolLiteral.h"
16 #include "src/sksl/ir/SkSLConstructor.h"
17 #include "src/sksl/ir/SkSLConstructorCompound.h"
18 #include "src/sksl/ir/SkSLConstructorSplat.h"
19 #include "src/sksl/ir/SkSLExpression.h"
20 #include "src/sksl/ir/SkSLFloatLiteral.h"
21 #include "src/sksl/ir/SkSLIntLiteral.h"
22 #include "src/sksl/ir/SkSLPrefixExpression.h"
23 #include "src/sksl/ir/SkSLType.h"
24 #include "src/sksl/ir/SkSLVariable.h"
25 #include "src/sksl/ir/SkSLVariableReference.h"
26
27 namespace SkSL {
28
eliminate_no_op_boolean(const Expression & left,Operator op,const Expression & right)29 static std::unique_ptr<Expression> eliminate_no_op_boolean(const Expression& left,
30 Operator op,
31 const Expression& right) {
32 SkASSERT(right.is<BoolLiteral>());
33 bool rightVal = right.as<BoolLiteral>().value();
34
35 // Detect no-op Boolean expressions and optimize them away.
36 if ((op.kind() == Token::Kind::TK_LOGICALAND && rightVal) || // (expr && true) -> (expr)
37 (op.kind() == Token::Kind::TK_LOGICALOR && !rightVal) || // (expr || false) -> (expr)
38 (op.kind() == Token::Kind::TK_LOGICALXOR && !rightVal) || // (expr ^^ false) -> (expr)
39 (op.kind() == Token::Kind::TK_EQEQ && rightVal) || // (expr == true) -> (expr)
40 (op.kind() == Token::Kind::TK_NEQ && !rightVal)) { // (expr != false) -> (expr)
41
42 return left.clone();
43 }
44
45 return nullptr;
46 }
47
short_circuit_boolean(const Expression & left,Operator op,const Expression & right)48 static std::unique_ptr<Expression> short_circuit_boolean(const Expression& left,
49 Operator op,
50 const Expression& right) {
51 SkASSERT(left.is<BoolLiteral>());
52 bool leftVal = left.as<BoolLiteral>().value();
53
54 // When the literal is on the left, we can sometimes eliminate the other expression entirely.
55 if ((op.kind() == Token::Kind::TK_LOGICALAND && !leftVal) || // (false && expr) -> (false)
56 (op.kind() == Token::Kind::TK_LOGICALOR && leftVal)) { // (true || expr) -> (true)
57
58 return left.clone();
59 }
60
61 // We can't eliminate the right-side expression via short-circuit, but we might still be able to
62 // simplify away a no-op expression.
63 return eliminate_no_op_boolean(right, op, left);
64 }
65
66 // 'T' is the actual stored type of the literal data (SKSL_FLOAT or SKSL_INT).
67 // 'U' is an unsigned version of that, used to perform addition, subtraction, and multiplication,
68 // to avoid signed-integer overflow errors. This mimics the use of URESULT vs. RESULT when doing
69 // scalar folding in Simplify, later in this file.
70 template <typename T, typename U = T>
simplify_vector(const Context & context,const Expression & left,Operator op,const Expression & right)71 static std::unique_ptr<Expression> simplify_vector(const Context& context,
72 const Expression& left,
73 Operator op,
74 const Expression& right) {
75 SkASSERT(left.type().isVector());
76 SkASSERT(left.type() == right.type());
77 const Type& type = left.type();
78
79 // Handle boolean operations: == !=
80 if (op.kind() == Token::Kind::TK_EQEQ || op.kind() == Token::Kind::TK_NEQ) {
81 bool equality = (op.kind() == Token::Kind::TK_EQEQ);
82
83 switch (left.compareConstant(right)) {
84 case Expression::ComparisonResult::kNotEqual:
85 equality = !equality;
86 [[fallthrough]];
87
88 case Expression::ComparisonResult::kEqual:
89 return BoolLiteral::Make(context, left.fOffset, equality);
90
91 case Expression::ComparisonResult::kUnknown:
92 return nullptr;
93 }
94 }
95
96 // Handle floating-point arithmetic: + - * /
97 const auto vectorComponentwiseFold = [&](auto foldFn) -> std::unique_ptr<Expression> {
98 const Type& componentType = type.componentType();
99 ExpressionArray args;
100 args.reserve_back(type.columns());
101 for (int i = 0; i < type.columns(); i++) {
102 U value = foldFn(left.getConstantSubexpression(i)->as<Literal<T>>().value(),
103 right.getConstantSubexpression(i)->as<Literal<T>>().value());
104 args.push_back(Literal<T>::Make(left.fOffset, value, &componentType));
105 }
106 return ConstructorCompound::Make(context, left.fOffset, type, std::move(args));
107 };
108
109 switch (op.kind()) {
110 case Token::Kind::TK_PLUS: return vectorComponentwiseFold([](U a, U b) { return a + b; });
111 case Token::Kind::TK_MINUS: return vectorComponentwiseFold([](U a, U b) { return a - b; });
112 case Token::Kind::TK_STAR: return vectorComponentwiseFold([](U a, U b) { return a * b; });
113 case Token::Kind::TK_SLASH: return vectorComponentwiseFold([](T a, T b) { return a / b; });
114 default:
115 return nullptr;
116 }
117 }
118
cast_expression(const Context & context,const Expression & expr,const Type & type)119 static std::unique_ptr<Expression> cast_expression(const Context& context,
120 const Expression& expr,
121 const Type& type) {
122 ExpressionArray ctorArgs;
123 ctorArgs.push_back(expr.clone());
124 std::unique_ptr<Expression> ctor = Constructor::Convert(context, expr.fOffset, type,
125 std::move(ctorArgs));
126 SkASSERT(ctor);
127 return ctor;
128 }
129
splat_scalar(const Expression & scalar,const Type & type)130 static ConstructorSplat splat_scalar(const Expression& scalar, const Type& type) {
131 SkASSERT(type.isVector());
132 SkASSERT(type.componentType() == scalar.type());
133
134 // Use a constructor to splat the scalar expression across a vector.
135 return ConstructorSplat{scalar.fOffset, type, scalar.clone()};
136 }
137
GetConstantInt(const Expression & value,SKSL_INT * out)138 bool ConstantFolder::GetConstantInt(const Expression& value, SKSL_INT* out) {
139 const Expression* expr = GetConstantValueForVariable(value);
140 if (!expr->is<IntLiteral>()) {
141 return false;
142 }
143 *out = expr->as<IntLiteral>().value();
144 return true;
145 }
146
GetConstantFloat(const Expression & value,SKSL_FLOAT * out)147 bool ConstantFolder::GetConstantFloat(const Expression& value, SKSL_FLOAT* out) {
148 const Expression* expr = GetConstantValueForVariable(value);
149 if (!expr->is<FloatLiteral>()) {
150 return false;
151 }
152 *out = expr->as<FloatLiteral>().value();
153 return true;
154 }
155
is_constant_scalar_value(const Expression & inExpr,float match)156 static bool is_constant_scalar_value(const Expression& inExpr, float match) {
157 const Expression* expr = ConstantFolder::GetConstantValueForVariable(inExpr);
158 return (expr->is<IntLiteral>() && expr->as<IntLiteral>().value() == match) ||
159 (expr->is<FloatLiteral>() && expr->as<FloatLiteral>().value() == match);
160 }
161
contains_constant_zero(const Expression & expr)162 static bool contains_constant_zero(const Expression& expr) {
163 if (expr.isAnyConstructor()) {
164 for (const auto& arg : expr.asAnyConstructor().argumentSpan()) {
165 if (contains_constant_zero(*arg)) {
166 return true;
167 }
168 }
169 return false;
170 }
171 return is_constant_scalar_value(expr, 0.0);
172 }
173
is_constant_value(const Expression & expr,float value)174 static bool is_constant_value(const Expression& expr, float value) {
175 // This check only supports scalars and vectors (and in particular, not matrices).
176 SkASSERT(expr.type().isScalar() || expr.type().isVector());
177
178 if (expr.isAnyConstructor()) {
179 for (const auto& arg : expr.asAnyConstructor().argumentSpan()) {
180 if (!is_constant_value(*arg, value)) {
181 return false;
182 }
183 }
184 return true;
185 }
186 return is_constant_scalar_value(expr, value);
187 }
188
ErrorOnDivideByZero(const Context & context,int offset,Operator op,const Expression & right)189 bool ConstantFolder::ErrorOnDivideByZero(const Context& context, int offset, Operator op,
190 const Expression& right) {
191 switch (op.kind()) {
192 case Token::Kind::TK_SLASH:
193 case Token::Kind::TK_SLASHEQ:
194 case Token::Kind::TK_PERCENT:
195 case Token::Kind::TK_PERCENTEQ:
196 if (contains_constant_zero(right)) {
197 context.fErrors.error(offset, "division by zero");
198 return true;
199 }
200 return false;
201 default:
202 return false;
203 }
204 }
205
// Follows a chain of read-only references to `const` variables through their initial values.
// Returns the first initial value that is a compile-time constant; if the chain ends anywhere else
// (not a read reference, not const, missing initializer, or a non-constant non-reference value),
// returns `&inExpr` unchanged.
const Expression* ConstantFolder::GetConstantValueForVariable(const Expression& inExpr) {
    const Expression* current = &inExpr;
    while (current->is<VariableReference>()) {
        const VariableReference& reference = current->as<VariableReference>();
        if (reference.refKind() != VariableRefKind::kRead) {
            break;
        }
        const Variable& variable = *reference.variable();
        if (!(variable.modifiers().fFlags & Modifiers::kConst_Flag)) {
            break;
        }

        current = variable.initialValue();
        if (!current) {
            SkDEBUGFAILF("found a const variable without an initial value (%s)",
                         variable.description().c_str());
            break;
        }
        if (current->isCompileTimeConstant()) {
            return current;
        }
        // Otherwise keep walking only if the initializer is itself another variable reference;
        // the loop condition performs that check.
    }
    // We didn't find a compile-time constant at the end. Return the expression as-is.
    return &inExpr;
}
235
MakeConstantValueForVariable(std::unique_ptr<Expression> expr)236 std::unique_ptr<Expression> ConstantFolder::MakeConstantValueForVariable(
237 std::unique_ptr<Expression> expr) {
238 const Expression* constantExpr = GetConstantValueForVariable(*expr);
239 if (constantExpr != expr.get()) {
240 expr = constantExpr->clone();
241 }
242 return expr;
243 }
244
simplify_no_op_arithmetic(const Context & context,const Expression & left,Operator op,const Expression & right,const Type & resultType)245 static std::unique_ptr<Expression> simplify_no_op_arithmetic(const Context& context,
246 const Expression& left,
247 Operator op,
248 const Expression& right,
249 const Type& resultType) {
250 switch (op.kind()) {
251 case Token::Kind::TK_PLUS:
252 if (is_constant_value(right, 0.0)) { // x + 0
253 return cast_expression(context, left, resultType);
254 }
255 if (is_constant_value(left, 0.0)) { // 0 + x
256 return cast_expression(context, right, resultType);
257 }
258 break;
259
260 case Token::Kind::TK_STAR:
261 if (is_constant_value(right, 1.0)) { // x * 1
262 return cast_expression(context, left, resultType);
263 }
264 if (is_constant_value(left, 1.0)) { // 1 * x
265 return cast_expression(context, right, resultType);
266 }
267 if (is_constant_value(right, 0.0) && !left.hasSideEffects()) { // x * 0
268 return cast_expression(context, right, resultType);
269 }
270 if (is_constant_value(left, 0.0) && !right.hasSideEffects()) { // 0 * x
271 return cast_expression(context, left, resultType);
272 }
273 break;
274
275 case Token::Kind::TK_MINUS:
276 if (is_constant_value(right, 0.0)) { // x - 0
277 return cast_expression(context, left, resultType);
278 }
279 if (is_constant_value(left, 0.0)) { // 0 - x (to `-x`)
280 return PrefixExpression::Make(context, Token::Kind::TK_MINUS,
281 cast_expression(context, right, resultType));
282 }
283 break;
284
285 case Token::Kind::TK_SLASH:
286 if (is_constant_value(right, 1.0)) { // x / 1
287 return cast_expression(context, left, resultType);
288 }
289 if (is_constant_value(left, 0.0) &&
290 !is_constant_value(right, 0.0) &&
291 !right.hasSideEffects()) { // 0 / x (where x is not 0)
292 return cast_expression(context, left, resultType);
293 }
294 break;
295
296 case Token::Kind::TK_PLUSEQ:
297 case Token::Kind::TK_MINUSEQ:
298 if (is_constant_value(right, 0.0)) { // x += 0, x -= 0
299 std::unique_ptr<Expression> result = cast_expression(context, left, resultType);
300 Analysis::UpdateRefKind(result.get(), VariableRefKind::kRead);
301 return result;
302 }
303 break;
304
305 case Token::Kind::TK_STAREQ:
306 case Token::Kind::TK_SLASHEQ:
307 if (is_constant_value(right, 1.0)) { // x *= 1, x /= 1
308 std::unique_ptr<Expression> result = cast_expression(context, left, resultType);
309 Analysis::UpdateRefKind(result.get(), VariableRefKind::kRead);
310 return result;
311 }
312 break;
313
314 default:
315 break;
316 }
317
318 return nullptr;
319 }
320
// Attempts to simplify the binary expression `leftExpr op rightExpr` (whose type is `resultType`)
// at compile time. Returns a replacement expression, or nullptr when nothing could be simplified.
// Division by a constant zero, integer `min / -1` or `min % -1`, and shift amounts outside [0, 31]
// additionally report an error through `context.fErrors` before returning nullptr.
// The checks below run in a fixed order and most of them return early (even with nullptr), so the
// order of the sections is significant.
std::unique_ptr<Expression> ConstantFolder::Simplify(const Context& context,
                                                     int offset,
                                                     const Expression& leftExpr,
                                                     Operator op,
                                                     const Expression& rightExpr,
                                                     const Type& resultType) {
    // When optimization is enabled, replace constant variables with trivial initial-values.
    const Expression* left;
    const Expression* right;
    if (context.fConfig->fSettings.fOptimize) {
        left = GetConstantValueForVariable(leftExpr);
        right = GetConstantValueForVariable(rightExpr);
    } else {
        left = &leftExpr;
        right = &rightExpr;
    }

    // If this is the comma operator, the left side is evaluated but not otherwise used in any way.
    // So if the left side has no side effects, it can just be eliminated entirely.
    if (op.kind() == Token::Kind::TK_COMMA && !left->hasSideEffects()) {
        return right->clone();
    }

    // If this is the assignment operator, and both sides are the same trivial expression, this is
    // self-assignment (i.e., `var = var`) and can be reduced to just a variable reference (`var`).
    // This can happen when other parts of the assignment are optimized away.
    if (op.kind() == Token::Kind::TK_EQ && Analysis::IsSameExpressionTree(*left, *right)) {
        return right->clone();
    }

    // Simplify the expression when both sides are constant Boolean literals.
    if (left->is<BoolLiteral>() && right->is<BoolLiteral>()) {
        bool leftVal = left->as<BoolLiteral>().value();
        bool rightVal = right->as<BoolLiteral>().value();
        bool result;
        switch (op.kind()) {
            case Token::Kind::TK_LOGICALAND: result = leftVal && rightVal; break;
            case Token::Kind::TK_LOGICALOR:  result = leftVal || rightVal; break;
            case Token::Kind::TK_LOGICALXOR: result = leftVal ^  rightVal; break;
            case Token::Kind::TK_EQEQ:       result = leftVal == rightVal; break;
            case Token::Kind::TK_NEQ:        result = leftVal != rightVal; break;
            default: return nullptr;
        }
        return BoolLiteral::Make(context, offset, result);
    }

    // If the left side is a Boolean literal, apply short-circuit optimizations.
    // This returns unconditionally, so nothing below runs when the left side is a Boolean literal.
    if (left->is<BoolLiteral>()) {
        return short_circuit_boolean(*left, op, *right);
    }

    // If the right side is a Boolean literal...
    // (This section also returns unconditionally.)
    if (right->is<BoolLiteral>()) {
        // ... and the left side has no side effects...
        if (!left->hasSideEffects()) {
            // We can reverse the expressions and short-circuit optimizations are still valid.
            return short_circuit_boolean(*right, op, *left);
        }

        // We can't use short-circuiting, but we can still optimize away no-op Boolean expressions.
        return eliminate_no_op_boolean(*left, op, *right);
    }

    // NOTE(review): the two self-comparison folds below use `leftExpr.fOffset` while the other
    // folds in this function use `offset`; confirm which source position is intended.
    if (op.kind() == Token::Kind::TK_EQEQ && Analysis::IsSameExpressionTree(*left, *right)) {
        // With == comparison, if both sides are the same trivial expression, this is self-
        // comparison and is always true. (We are not concerned with NaN.)
        return BoolLiteral::Make(context, leftExpr.fOffset, /*value=*/true);
    }

    if (op.kind() == Token::Kind::TK_NEQ && Analysis::IsSameExpressionTree(*left, *right)) {
        // With != comparison, if both sides are the same trivial expression, this is self-
        // comparison and is always false. (We are not concerned with NaN.)
        return BoolLiteral::Make(context, leftExpr.fOffset, /*value=*/false);
    }

    // Reject (with an error) division or modulus by a constant zero before any folding below
    // attempts to evaluate it.
    if (ErrorOnDivideByZero(context, offset, op, *right)) {
        return nullptr;
    }

    // Optimize away no-op arithmetic like `x * 1`, `x *= 1`, `x + 0`, `x * 0`, `0 / x`, etc.
    // simplify_no_op_arithmetic only supports scalar and vector operands, hence the type guard.
    const Type& leftType = left->type();
    const Type& rightType = right->type();
    if ((leftType.isScalar() || leftType.isVector()) &&
        (rightType.isScalar() || rightType.isVector())) {
        std::unique_ptr<Expression> expr = simplify_no_op_arithmetic(context, *left, op, *right,
                                                                     resultType);
        if (expr) {
            return expr;
        }
    }

    // Other than the cases above, constant folding requires both sides to be constant.
    if (!left->isCompileTimeConstant() || !right->isCompileTimeConstant()) {
        return nullptr;
    }

    // Note that we expressly do not worry about precision and overflow here -- we use the maximum
    // precision to calculate the results and hope the result makes sense.
    // TODO(skia:10932): detect and handle integer overflow properly.
    using SKSL_UINT = uint64_t;
    // RESULT folds using the operands' own types. URESULT casts both operands to SKSL_UINT first,
    // so that +, -, * and << on signed values wrap instead of hitting signed-overflow UB in C++.
    #define RESULT(t, op) t ## Literal::Make(offset, leftVal op rightVal, &resultType)
    #define URESULT(t, op) t ## Literal::Make(offset, (SKSL_UINT)(leftVal) op \
                                              (SKSL_UINT)(rightVal), &resultType)
    // Perform constant folding on pairs of integer literals.
    if (left->is<IntLiteral>() && right->is<IntLiteral>()) {
        SKSL_INT leftVal = left->as<IntLiteral>().value();
        SKSL_INT rightVal = right->as<IntLiteral>().value();
        switch (op.kind()) {
            case Token::Kind::TK_PLUS:       return URESULT(Int, +);
            case Token::Kind::TK_MINUS:      return URESULT(Int, -);
            case Token::Kind::TK_STAR:       return URESULT(Int, *);
            case Token::Kind::TK_SLASH:
                // The quotient of the most-negative value and -1 is not representable.
                if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
                    context.fErrors.error(offset, "arithmetic overflow");
                    return nullptr;
                }
                return RESULT(Int, /);
            case Token::Kind::TK_PERCENT:
                // Same unrepresentable case as division above.
                if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
                    context.fErrors.error(offset, "arithmetic overflow");
                    return nullptr;
                }
                return RESULT(Int, %);
            case Token::Kind::TK_BITWISEAND: return RESULT(Int,  &);
            case Token::Kind::TK_BITWISEOR:  return RESULT(Int,  |);
            case Token::Kind::TK_BITWISEXOR: return RESULT(Int,  ^);
            case Token::Kind::TK_EQEQ:       return RESULT(Bool, ==);
            case Token::Kind::TK_NEQ:        return RESULT(Bool, !=);
            case Token::Kind::TK_GT:         return RESULT(Bool, >);
            case Token::Kind::TK_GTEQ:       return RESULT(Bool, >=);
            case Token::Kind::TK_LT:         return RESULT(Bool, <);
            case Token::Kind::TK_LTEQ:       return RESULT(Bool, <=);
            case Token::Kind::TK_SHL:
                if (rightVal >= 0 && rightVal <= 31) {
                    // Left-shifting a negative (or really, any signed) value is undefined behavior
                    // in C++, but not GLSL. Do the shift on unsigned values, to avoid UBSAN.
                    return URESULT(Int,  <<);
                }
                context.fErrors.error(offset, "shift value out of range");
                return nullptr;
            case Token::Kind::TK_SHR:
                if (rightVal >= 0 && rightVal <= 31) {
                    return RESULT(Int,  >>);
                }
                context.fErrors.error(offset, "shift value out of range");
                return nullptr;

            default:
                return nullptr;
        }
    }

    // Perform constant folding on pairs of floating-point literals.
    if (left->is<FloatLiteral>() && right->is<FloatLiteral>()) {
        SKSL_FLOAT leftVal = left->as<FloatLiteral>().value();
        SKSL_FLOAT rightVal = right->as<FloatLiteral>().value();
        switch (op.kind()) {
            case Token::Kind::TK_PLUS:  return RESULT(Float, +);
            case Token::Kind::TK_MINUS: return RESULT(Float, -);
            case Token::Kind::TK_STAR:  return RESULT(Float, *);
            case Token::Kind::TK_SLASH: return RESULT(Float, /);
            case Token::Kind::TK_EQEQ:  return RESULT(Bool, ==);
            case Token::Kind::TK_NEQ:   return RESULT(Bool, !=);
            case Token::Kind::TK_GT:    return RESULT(Bool, >);
            case Token::Kind::TK_GTEQ:  return RESULT(Bool, >=);
            case Token::Kind::TK_LT:    return RESULT(Bool, <);
            case Token::Kind::TK_LTEQ:  return RESULT(Bool, <=);
            default:                    return nullptr;
        }
    }

    // Perform constant folding on pairs of vectors.
    // Only float and integer component types are folded; other vectors return nullptr here.
    if (leftType.isVector() && leftType == rightType) {
        if (leftType.componentType().isFloat()) {
            return simplify_vector<SKSL_FLOAT>(context, *left, op, *right);
        }
        if (leftType.componentType().isInteger()) {
            return simplify_vector<SKSL_INT, SKSL_UINT>(context, *left, op, *right);
        }
        return nullptr;
    }

    // Perform constant folding on vectors against scalars, e.g.: half4(2) + 2
    // The scalar is splatted into a temporary vector so the vector-vector fold can be reused.
    if (leftType.isVector() && leftType.componentType() == rightType) {
        if (rightType.isFloat()) {
            return simplify_vector<SKSL_FLOAT>(context, *left, op,
                                               splat_scalar(*right, left->type()));
        }
        if (rightType.isInteger()) {
            return simplify_vector<SKSL_INT, SKSL_UINT>(context, *left, op,
                                                        splat_scalar(*right, left->type()));
        }
        return nullptr;
    }

    // Perform constant folding on scalars against vectors, e.g.: 2 + half4(2)
    if (rightType.isVector() && rightType.componentType() == leftType) {
        if (leftType.isFloat()) {
            return simplify_vector<SKSL_FLOAT>(context, splat_scalar(*left, right->type()), op,
                                               *right);
        }
        if (leftType.isInteger()) {
            return simplify_vector<SKSL_INT, SKSL_UINT>(context, splat_scalar(*left, right->type()),
                                                        op, *right);
        }
        return nullptr;
    }

    // Perform constant folding on pairs of matrices or arrays.
    // Only == and != are folded, using compareConstant; an unknown comparison returns nullptr.
    if ((leftType.isMatrix() && rightType.isMatrix()) ||
        (leftType.isArray() && rightType.isArray())) {
        bool equality;
        switch (op.kind()) {
            case Token::Kind::TK_EQEQ:
                equality = true;
                break;
            case Token::Kind::TK_NEQ:
                equality = false;
                break;
            default:
                return nullptr;
        }

        switch (left->compareConstant(*right)) {
            case Expression::ComparisonResult::kNotEqual:
                equality = !equality;
                [[fallthrough]];

            case Expression::ComparisonResult::kEqual:
                return BoolLiteral::Make(context, offset, equality);

            case Expression::ComparisonResult::kUnknown:
                return nullptr;
        }
    }

    // We aren't able to constant-fold.
    #undef RESULT
    #undef URESULT
    return nullptr;
}
561
562 } // namespace SkSL
563