1 /*
2 * Copyright 2020 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/SkSLConstantFolder.h"
9
10 #include <limits>
11
12 #include "include/sksl/SkSLErrorReporter.h"
13 #include "src/sksl/SkSLAnalysis.h"
14 #include "src/sksl/SkSLContext.h"
15 #include "src/sksl/SkSLProgramSettings.h"
16 #include "src/sksl/ir/SkSLBinaryExpression.h"
17 #include "src/sksl/ir/SkSLConstructor.h"
18 #include "src/sksl/ir/SkSLConstructorCompound.h"
19 #include "src/sksl/ir/SkSLConstructorSplat.h"
20 #include "src/sksl/ir/SkSLExpression.h"
21 #include "src/sksl/ir/SkSLLiteral.h"
22 #include "src/sksl/ir/SkSLPrefixExpression.h"
23 #include "src/sksl/ir/SkSLType.h"
24 #include "src/sksl/ir/SkSLVariable.h"
25 #include "src/sksl/ir/SkSLVariableReference.h"
26
27 namespace SkSL {
28
eliminate_no_op_boolean(const Expression & left,Operator op,const Expression & right)29 static std::unique_ptr<Expression> eliminate_no_op_boolean(const Expression& left,
30 Operator op,
31 const Expression& right) {
32 bool rightVal = right.as<Literal>().boolValue();
33
34 // Detect no-op Boolean expressions and optimize them away.
35 if ((op.kind() == Token::Kind::TK_LOGICALAND && rightVal) || // (expr && true) -> (expr)
36 (op.kind() == Token::Kind::TK_LOGICALOR && !rightVal) || // (expr || false) -> (expr)
37 (op.kind() == Token::Kind::TK_LOGICALXOR && !rightVal) || // (expr ^^ false) -> (expr)
38 (op.kind() == Token::Kind::TK_EQEQ && rightVal) || // (expr == true) -> (expr)
39 (op.kind() == Token::Kind::TK_NEQ && !rightVal)) { // (expr != false) -> (expr)
40
41 return left.clone();
42 }
43
44 return nullptr;
45 }
46
short_circuit_boolean(const Expression & left,Operator op,const Expression & right)47 static std::unique_ptr<Expression> short_circuit_boolean(const Expression& left,
48 Operator op,
49 const Expression& right) {
50 bool leftVal = left.as<Literal>().boolValue();
51
52 // When the literal is on the left, we can sometimes eliminate the other expression entirely.
53 if ((op.kind() == Token::Kind::TK_LOGICALAND && !leftVal) || // (false && expr) -> (false)
54 (op.kind() == Token::Kind::TK_LOGICALOR && leftVal)) { // (true || expr) -> (true)
55
56 return left.clone();
57 }
58
59 // We can't eliminate the right-side expression via short-circuit, but we might still be able to
60 // simplify away a no-op expression.
61 return eliminate_no_op_boolean(right, op, left);
62 }
63
simplify_vector_equality(const Context & context,const Expression & left,Operator op,const Expression & right)64 static std::unique_ptr<Expression> simplify_vector_equality(const Context& context,
65 const Expression& left,
66 Operator op,
67 const Expression& right) {
68 if (op.kind() == Token::Kind::TK_EQEQ || op.kind() == Token::Kind::TK_NEQ) {
69 bool equality = (op.kind() == Token::Kind::TK_EQEQ);
70
71 switch (left.compareConstant(right)) {
72 case Expression::ComparisonResult::kNotEqual:
73 equality = !equality;
74 [[fallthrough]];
75
76 case Expression::ComparisonResult::kEqual:
77 return Literal::MakeBool(context, left.fLine, equality);
78
79 case Expression::ComparisonResult::kUnknown:
80 break;
81 }
82 }
83 return nullptr;
84 }
85
simplify_vector(const Context & context,const Expression & left,Operator op,const Expression & right)86 static std::unique_ptr<Expression> simplify_vector(const Context& context,
87 const Expression& left,
88 Operator op,
89 const Expression& right) {
90 SkASSERT(left.type().isVector());
91 SkASSERT(left.type() == right.type());
92 const Type& type = left.type();
93
94 // Handle equality operations: == !=
95 if (std::unique_ptr<Expression> result = simplify_vector_equality(context, left, op, right)) {
96 return result;
97 }
98
99 // Handle floating-point arithmetic: + - * /
100 using FoldFn = double (*)(double, double);
101 FoldFn foldFn;
102 switch (op.kind()) {
103 case Token::Kind::TK_PLUS: foldFn = +[](double a, double b) { return a + b; }; break;
104 case Token::Kind::TK_MINUS: foldFn = +[](double a, double b) { return a - b; }; break;
105 case Token::Kind::TK_STAR: foldFn = +[](double a, double b) { return a * b; }; break;
106 case Token::Kind::TK_SLASH: foldFn = +[](double a, double b) { return a / b; }; break;
107 default:
108 return nullptr;
109 }
110
111 const Type& componentType = type.componentType();
112 double minimumValue = -INFINITY, maximumValue = INFINITY;
113 if (componentType.isInteger()) {
114 minimumValue = componentType.minimumValue();
115 maximumValue = componentType.maximumValue();
116 }
117
118 ExpressionArray args;
119 args.reserve_back(type.columns());
120 for (int i = 0; i < type.columns(); i++) {
121 double value = foldFn(*left.getConstantValue(i), *right.getConstantValue(i));
122 if (value < minimumValue || value > maximumValue) {
123 return nullptr;
124 }
125
126 args.push_back(Literal::Make(left.fLine, value, &componentType));
127 }
128 return ConstructorCompound::Make(context, left.fLine, type, std::move(args));
129 }
130
cast_expression(const Context & context,const Expression & expr,const Type & type)131 static std::unique_ptr<Expression> cast_expression(const Context& context,
132 const Expression& expr,
133 const Type& type) {
134 ExpressionArray ctorArgs;
135 ctorArgs.push_back(expr.clone());
136 return Constructor::Convert(context, expr.fLine, type, std::move(ctorArgs));
137 }
138
GetConstantInt(const Expression & value,SKSL_INT * out)139 bool ConstantFolder::GetConstantInt(const Expression& value, SKSL_INT* out) {
140 const Expression* expr = GetConstantValueForVariable(value);
141 if (!expr->isIntLiteral()) {
142 return false;
143 }
144 *out = expr->as<Literal>().intValue();
145 return true;
146 }
147
GetConstantValue(const Expression & value,double * out)148 bool ConstantFolder::GetConstantValue(const Expression& value, double* out) {
149 const Expression* expr = GetConstantValueForVariable(value);
150 if (!expr->is<Literal>()) {
151 return false;
152 }
153 *out = expr->as<Literal>().value();
154 return true;
155 }
156
contains_constant_zero(const Expression & expr)157 static bool contains_constant_zero(const Expression& expr) {
158 int numSlots = expr.type().slotCount();
159 for (int index = 0; index < numSlots; ++index) {
160 skstd::optional<double> slotVal = expr.getConstantValue(index);
161 if (slotVal.has_value() && *slotVal == 0.0) {
162 return true;
163 }
164 }
165 return false;
166 }
167
is_constant_value(const Expression & expr,double value)168 static bool is_constant_value(const Expression& expr, double value) {
169 int numSlots = expr.type().slotCount();
170 for (int index = 0; index < numSlots; ++index) {
171 skstd::optional<double> slotVal = expr.getConstantValue(index);
172 if (!slotVal.has_value() || *slotVal != value) {
173 return false;
174 }
175 }
176 return true;
177 }
178
error_on_divide_by_zero(const Context & context,int line,Operator op,const Expression & right)179 static bool error_on_divide_by_zero(const Context& context, int line, Operator op,
180 const Expression& right) {
181 switch (op.kind()) {
182 case Token::Kind::TK_SLASH:
183 case Token::Kind::TK_SLASHEQ:
184 case Token::Kind::TK_PERCENT:
185 case Token::Kind::TK_PERCENTEQ:
186 if (contains_constant_zero(right)) {
187 context.fErrors->error(line, "division by zero");
188 return true;
189 }
190 return false;
191 default:
192 return false;
193 }
194 }
195
/**
 * Follows a chain of read-only references to `const` variables and returns the first
 * initializer along the chain that is a compile-time constant.
 *
 * The walk stops (and `&inExpr` is returned unchanged) at the first expression that:
 *  - is not a VariableReference,
 *  - is a reference that is not a plain read,
 *  - refers to a variable without the `const` modifier,
 *  - (SKSL_EXT builds only) refers to a variable with the constant_id layout flag, or
 *  - refers to a variable with no initial value (e.g. a const function parameter).
 */
const Expression* ConstantFolder::GetConstantValueForVariable(const Expression& inExpr) {
    const Expression* current = &inExpr;
    while (current->is<VariableReference>()) {
        const VariableReference& ref = current->as<VariableReference>();
        if (ref.refKind() != VariableRefKind::kRead) {
            break;
        }

        const Variable& variable = *ref.variable();
        if (!(variable.modifiers().fFlags & Modifiers::kConst_Flag)) {
            break;
        }
#ifdef SKSL_EXT
        // NOTE(review): presumably constant_id variables are specialization constants whose
        // value can be overridden later, so their initializer must not be inlined -- confirm.
        if (variable.modifiers().fLayout.fFlags & Layout::Flag::kConstantId_Flag) {
            break;
        }
#endif

        current = variable.initialValue();
        if (current == nullptr) {
            // Function parameters can be const but won't have an initial value.
            break;
        }
        if (current->isCompileTimeConstant()) {
            return current;
        }
        // Otherwise the initializer may itself be a reference to another const variable; keep
        // following the chain.
    }

    // No compile-time constant was found at the end of the chain.
    return &inExpr;
}
226
MakeConstantValueForVariable(std::unique_ptr<Expression> expr)227 std::unique_ptr<Expression> ConstantFolder::MakeConstantValueForVariable(
228 std::unique_ptr<Expression> expr) {
229 const Expression* constantExpr = GetConstantValueForVariable(*expr);
230 if (constantExpr != expr.get()) {
231 expr = constantExpr->clone();
232 }
233 return expr;
234 }
235
simplify_no_op_arithmetic(const Context & context,const Expression & left,Operator op,const Expression & right,const Type & resultType)236 static std::unique_ptr<Expression> simplify_no_op_arithmetic(const Context& context,
237 const Expression& left,
238 Operator op,
239 const Expression& right,
240 const Type& resultType) {
241 switch (op.kind()) {
242 case Token::Kind::TK_PLUS:
243 if (is_constant_value(right, 0.0)) { // x + 0
244 return cast_expression(context, left, resultType);
245 }
246 if (is_constant_value(left, 0.0)) { // 0 + x
247 return cast_expression(context, right, resultType);
248 }
249 break;
250
251 case Token::Kind::TK_STAR:
252 if (is_constant_value(right, 1.0)) { // x * 1
253 return cast_expression(context, left, resultType);
254 }
255 if (is_constant_value(left, 1.0)) { // 1 * x
256 return cast_expression(context, right, resultType);
257 }
258 if (is_constant_value(right, 0.0) && !left.hasSideEffects()) { // x * 0
259 return cast_expression(context, right, resultType);
260 }
261 if (is_constant_value(left, 0.0) && !right.hasSideEffects()) { // 0 * x
262 return cast_expression(context, left, resultType);
263 }
264 break;
265
266 case Token::Kind::TK_MINUS:
267 if (is_constant_value(right, 0.0)) { // x - 0
268 return cast_expression(context, left, resultType);
269 }
270 if (is_constant_value(left, 0.0)) { // 0 - x (to `-x`)
271 if (std::unique_ptr<Expression> val = cast_expression(context, right, resultType)) {
272 return PrefixExpression::Make(context, Token::Kind::TK_MINUS, std::move(val));
273 }
274 }
275 break;
276
277 case Token::Kind::TK_SLASH:
278 if (is_constant_value(right, 1.0)) { // x / 1
279 return cast_expression(context, left, resultType);
280 }
281 break;
282
283 case Token::Kind::TK_PLUSEQ:
284 case Token::Kind::TK_MINUSEQ:
285 if (is_constant_value(right, 0.0)) { // x += 0, x -= 0
286 if (std::unique_ptr<Expression> var = cast_expression(context, left, resultType)) {
287 Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
288 return var;
289 }
290 }
291 break;
292
293 case Token::Kind::TK_STAREQ:
294 case Token::Kind::TK_SLASHEQ:
295 if (is_constant_value(right, 1.0)) { // x *= 1, x /= 1
296 if (std::unique_ptr<Expression> var = cast_expression(context, left, resultType)) {
297 Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
298 return var;
299 }
300 }
301 break;
302
303 default:
304 break;
305 }
306
307 return nullptr;
308 }
309
310 template <typename T>
fold_float_expression(int line,T result,const Type * resultType)311 static std::unique_ptr<Expression> fold_float_expression(int line,
312 T result,
313 const Type* resultType) {
314 // If constant-folding this expression would generate a NaN/infinite result, leave it as-is.
315 if constexpr (!std::is_same<T, bool>::value) {
316 if (!std::isfinite(result)) {
317 return nullptr;
318 }
319 }
320
321 return Literal::Make(line, result, resultType);
322 }
323
324 template <typename T>
fold_int_expression(int line,T result,const Type * resultType)325 static std::unique_ptr<Expression> fold_int_expression(int line,
326 T result,
327 const Type* resultType) {
328 // If constant-folding this expression would overflow the result type, leave it as-is.
329 if constexpr (!std::is_same<T, bool>::value) {
330 if (result < resultType->minimumValue() || result > resultType->maximumValue()) {
331 return nullptr;
332 }
333 }
334
335 return Literal::Make(line, result, resultType);
336 }
337
/**
 * Attempts to simplify the binary expression `leftExpr op rightExpr` at compile time.
 *
 * References to `const` variables whose initializers are compile-time constants are first
 * replaced by those initializers (via GetConstantValueForVariable). `resultType` is the type of
 * the whole binary expression; it is used for operand casts and for most folded literals.
 *
 * The checks below run in order, and the first one that applies decides the result:
 *   comma / self-assignment, Boolean literals, self-comparison, division by zero,
 *   no-op arithmetic, then full constant folding (scalar ints, scalar floats, vectors,
 *   vector/scalar mixes, matrices and arrays).
 *
 * Returns the replacement expression, or nullptr when no simplification applies. nullptr is
 * also returned after an error has been reported to `context.fErrors` (division by a constant
 * zero, SKSL_INT-min `/` or `%` -1, or a shift amount outside [0, 31]).
 */
std::unique_ptr<Expression> ConstantFolder::Simplify(const Context& context,
                                                     int line,
                                                     const Expression& leftExpr,
                                                     Operator op,
                                                     const Expression& rightExpr,
                                                     const Type& resultType) {
    // Replace constant variables with their literal values.
    const Expression* left = GetConstantValueForVariable(leftExpr);
    const Expression* right = GetConstantValueForVariable(rightExpr);

    // If this is the comma operator, the left side is evaluated but not otherwise used in any way.
    // So if the left side has no side effects, it can just be eliminated entirely.
    if (op.kind() == Token::Kind::TK_COMMA && !left->hasSideEffects()) {
        return right->clone();
    }

    // If this is the assignment operator, and both sides are the same trivial expression, this is
    // self-assignment (i.e., `var = var`) and can be reduced to just a variable reference (`var`).
    // This can happen when other parts of the assignment are optimized away.
    if (op.kind() == Token::Kind::TK_EQ && Analysis::IsSameExpressionTree(*left, *right)) {
        return right->clone();
    }

    // Simplify the expression when both sides are constant Boolean literals.
    if (left->isBoolLiteral() && right->isBoolLiteral()) {
        bool leftVal = left->as<Literal>().boolValue();
        bool rightVal = right->as<Literal>().boolValue();
        bool result;
        switch (op.kind()) {
            case Token::Kind::TK_LOGICALAND: result = leftVal && rightVal; break;
            case Token::Kind::TK_LOGICALOR: result = leftVal || rightVal; break;
            case Token::Kind::TK_LOGICALXOR: result = leftVal ^ rightVal; break;
            case Token::Kind::TK_EQEQ: result = leftVal == rightVal; break;
            case Token::Kind::TK_NEQ: result = leftVal != rightVal; break;
            default: return nullptr;
        }
        return Literal::MakeBool(context, line, result);
    }

    // If the left side is a Boolean literal, apply short-circuit optimizations.
    // Note: from here on, whenever either operand is a Boolean literal this function returns
    // immediately, so the numeric folding paths further down are never reached for it.
    if (left->isBoolLiteral()) {
        return short_circuit_boolean(*left, op, *right);
    }

    // If the right side is a Boolean literal...
    if (right->isBoolLiteral()) {
        // ... and the left side has no side effects...
        if (!left->hasSideEffects()) {
            // We can reverse the expressions and short-circuit optimizations are still valid.
            return short_circuit_boolean(*right, op, *left);
        }

        // We can't use short-circuiting, but we can still optimize away no-op Boolean expressions.
        return eliminate_no_op_boolean(*left, op, *right);
    }

    if (op.kind() == Token::Kind::TK_EQEQ && Analysis::IsSameExpressionTree(*left, *right)) {
        // With == comparison, if both sides are the same trivial expression, this is self-
        // comparison and is always true. (We are not concerned with NaN.)
        return Literal::MakeBool(context, leftExpr.fLine, /*value=*/true);
    }

    if (op.kind() == Token::Kind::TK_NEQ && Analysis::IsSameExpressionTree(*left, *right)) {
        // With != comparison, if both sides are the same trivial expression, this is self-
        // comparison and is always false. (We are not concerned with NaN.)
        return Literal::MakeBool(context, leftExpr.fLine, /*value=*/false);
    }

    // Reject `/`, `/=`, `%`, `%=` by a constant zero (reported as an error). This also guarantees
    // that the integer `/` and `%` folds below never divide by zero.
    if (error_on_divide_by_zero(context, line, op, *right)) {
        return nullptr;
    }

    // Optimize away no-op arithmetic like `x + 0`, `0 - x`, `x * 1`, `x * 0`, `x / 1`, `x += 0`,
    // `x *= 1`, etc. (see simplify_no_op_arithmetic for the full list). Scalars and vectors only.
    const Type& leftType = left->type();
    const Type& rightType = right->type();
    if ((leftType.isScalar() || leftType.isVector()) &&
        (rightType.isScalar() || rightType.isVector())) {
        std::unique_ptr<Expression> expr = simplify_no_op_arithmetic(context, *left, op, *right,
                                                                     resultType);
        if (expr) {
            return expr;
        }
    }

    // Other than the cases above, constant folding requires both sides to be constant.
    if (!left->isCompileTimeConstant() || !right->isCompileTimeConstant()) {
        return nullptr;
    }

    // Integer folding is computed on SKSL_INT values. Operations that can wrap around (+ - * <<)
    // are performed on SKSL_UINT (uint64_t), where wraparound is well-defined in C++, and the
    // result is converted back to SKSL_INT. fold_int_expression then declines to fold any result
    // that does not fit in `resultType`.
    // TODO(skia:10932): detect and handle integer overflow properly.
    using SKSL_UINT = uint64_t;
    if (left->isIntLiteral() && right->isIntLiteral()) {
        SKSL_INT leftVal = left->as<Literal>().intValue();
        SKSL_INT rightVal = right->as<Literal>().intValue();

        // RESULT: apply `Op` to the signed values. URESULT: apply `Op` to the unsigned values.
        #define RESULT(Op) fold_int_expression(line, \
                                               (SKSL_INT)(leftVal) Op (SKSL_INT)(rightVal), &resultType)
        #define URESULT(Op) fold_int_expression(line, \
                                                (SKSL_INT)((SKSL_UINT)(leftVal) Op (SKSL_UINT)(rightVal)), &resultType)
        switch (op.kind()) {
            case Token::Kind::TK_PLUS: return URESULT(+);
            case Token::Kind::TK_MINUS: return URESULT(-);
            case Token::Kind::TK_STAR: return URESULT(*);
            case Token::Kind::TK_SLASH:
                // SKSL_INT-min / -1 is undefined behavior in C++.
                if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
                    context.fErrors->error(line, "arithmetic overflow");
                    return nullptr;
                }
                return RESULT(/);
            case Token::Kind::TK_PERCENT:
                // SKSL_INT-min % -1 is undefined behavior in C++.
                if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
                    context.fErrors->error(line, "arithmetic overflow");
                    return nullptr;
                }
                return RESULT(%);
            case Token::Kind::TK_BITWISEAND: return RESULT(&);
            case Token::Kind::TK_BITWISEOR: return RESULT(|);
            case Token::Kind::TK_BITWISEXOR: return RESULT(^);
            case Token::Kind::TK_EQEQ: return RESULT(==);
            case Token::Kind::TK_NEQ: return RESULT(!=);
            case Token::Kind::TK_GT: return RESULT(>);
            case Token::Kind::TK_GTEQ: return RESULT(>=);
            case Token::Kind::TK_LT: return RESULT(<);
            case Token::Kind::TK_LTEQ: return RESULT(<=);
            case Token::Kind::TK_SHL:
                if (rightVal >= 0 && rightVal <= 31) {
                    // Left-shifting a negative (or really, any signed) value is undefined behavior
                    // in C++, but not GLSL. Do the shift on unsigned values, to avoid UBSAN.
                    return URESULT(<<);
                }
                context.fErrors->error(line, "shift value out of range");
                return nullptr;
            case Token::Kind::TK_SHR:
                if (rightVal >= 0 && rightVal <= 31) {
                    return RESULT(>>);
                }
                context.fErrors->error(line, "shift value out of range");
                return nullptr;

            default:
                return nullptr;
        }
        #undef RESULT
        #undef URESULT
    }

    // Perform constant folding on pairs of floating-point literals. Arithmetic is done in
    // SKSL_FLOAT precision; fold_float_expression declines NaN/infinite results.
    if (left->isFloatLiteral() && right->isFloatLiteral()) {
        SKSL_FLOAT leftVal = left->as<Literal>().floatValue();
        SKSL_FLOAT rightVal = right->as<Literal>().floatValue();

        #define RESULT(Op) fold_float_expression(line, leftVal Op rightVal, &resultType)
        switch (op.kind()) {
            case Token::Kind::TK_PLUS: return RESULT(+);
            case Token::Kind::TK_MINUS: return RESULT(-);
            case Token::Kind::TK_STAR: return RESULT(*);
            case Token::Kind::TK_SLASH: return RESULT(/);
            case Token::Kind::TK_EQEQ: return RESULT(==);
            case Token::Kind::TK_NEQ: return RESULT(!=);
            case Token::Kind::TK_GT: return RESULT(>);
            case Token::Kind::TK_GTEQ: return RESULT(>=);
            case Token::Kind::TK_LT: return RESULT(<);
            case Token::Kind::TK_LTEQ: return RESULT(<=);
            default: return nullptr;
        }
        #undef RESULT
    }

    // Perform constant folding on pairs of vectors. Float and integer vectors support
    // + - * / and == !=; Boolean vectors support only == !=.
    if (leftType.isVector() && leftType == rightType) {
        if (leftType.componentType().isFloat()) {
            return simplify_vector(context, *left, op, *right);
        }
        if (leftType.componentType().isInteger()) {
            return simplify_vector(context, *left, op, *right);
        }
        if (leftType.componentType().isBoolean()) {
            return simplify_vector_equality(context, *left, op, *right);
        }
        return nullptr;
    }

    // Perform constant folding on vectors against scalars, e.g.: half4(2) + 2
    // The scalar is splatted to the vector type so both operands have the same shape.
    if (leftType.isVector() && leftType.componentType() == rightType) {
        if (rightType.isFloat()) {
            return simplify_vector(context, *left, op, ConstructorSplat(*right, left->type()));
        }
        if (rightType.isInteger()) {
            return simplify_vector(context, *left, op, ConstructorSplat(*right, left->type()));
        }
        if (rightType.isBoolean()) {
            return simplify_vector_equality(context, *left, op,
                                            ConstructorSplat(*right, left->type()));
        }
        return nullptr;
    }

    // Perform constant folding on scalars against vectors, e.g.: 2 + half4(2)
    // The scalar is splatted to the vector type so both operands have the same shape.
    if (rightType.isVector() && rightType.componentType() == leftType) {
        if (leftType.isFloat()) {
            return simplify_vector(context, ConstructorSplat(*left, right->type()), op, *right);
        }
        if (leftType.isInteger()) {
            return simplify_vector(context, ConstructorSplat(*left, right->type()), op, *right);
        }
        if (leftType.isBoolean()) {
            return simplify_vector_equality(context, ConstructorSplat(*left, right->type()),
                                            op, *right);
        }
        return nullptr;
    }

    // Perform constant folding on pairs of matrices or arrays. Only == and != are folded;
    // matrix/array arithmetic is left as-is.
    if ((leftType.isMatrix() && rightType.isMatrix()) ||
        (leftType.isArray() && rightType.isArray())) {
        bool equality;
        switch (op.kind()) {
            case Token::Kind::TK_EQEQ:
                equality = true;
                break;
            case Token::Kind::TK_NEQ:
                equality = false;
                break;
            default:
                return nullptr;
        }

        switch (left->compareConstant(*right)) {
            case Expression::ComparisonResult::kNotEqual:
                equality = !equality;
                [[fallthrough]];

            case Expression::ComparisonResult::kEqual:
                return Literal::MakeBool(context, line, equality);

            case Expression::ComparisonResult::kUnknown:
                return nullptr;
        }
    }

    // We aren't able to constant-fold.
    return nullptr;
}
583
584 } // namespace SkSL
585