• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/SkSLConstantFolder.h"
9 
#include <cmath>
#include <limits>
11 
12 #include "include/sksl/SkSLErrorReporter.h"
13 #include "src/sksl/SkSLAnalysis.h"
14 #include "src/sksl/SkSLContext.h"
15 #include "src/sksl/SkSLProgramSettings.h"
16 #include "src/sksl/ir/SkSLBinaryExpression.h"
17 #include "src/sksl/ir/SkSLConstructor.h"
18 #include "src/sksl/ir/SkSLConstructorCompound.h"
19 #include "src/sksl/ir/SkSLConstructorSplat.h"
20 #include "src/sksl/ir/SkSLExpression.h"
21 #include "src/sksl/ir/SkSLLiteral.h"
22 #include "src/sksl/ir/SkSLPrefixExpression.h"
23 #include "src/sksl/ir/SkSLType.h"
24 #include "src/sksl/ir/SkSLVariable.h"
25 #include "src/sksl/ir/SkSLVariableReference.h"
26 
27 namespace SkSL {
28 
eliminate_no_op_boolean(const Expression & left,Operator op,const Expression & right)29 static std::unique_ptr<Expression> eliminate_no_op_boolean(const Expression& left,
30                                                            Operator op,
31                                                            const Expression& right) {
32     bool rightVal = right.as<Literal>().boolValue();
33 
34     // Detect no-op Boolean expressions and optimize them away.
35     if ((op.kind() == Token::Kind::TK_LOGICALAND && rightVal)  ||  // (expr && true)  -> (expr)
36         (op.kind() == Token::Kind::TK_LOGICALOR  && !rightVal) ||  // (expr || false) -> (expr)
37         (op.kind() == Token::Kind::TK_LOGICALXOR && !rightVal) ||  // (expr ^^ false) -> (expr)
38         (op.kind() == Token::Kind::TK_EQEQ       && rightVal)  ||  // (expr == true)  -> (expr)
39         (op.kind() == Token::Kind::TK_NEQ        && !rightVal)) {  // (expr != false) -> (expr)
40 
41         return left.clone();
42     }
43 
44     return nullptr;
45 }
46 
short_circuit_boolean(const Expression & left,Operator op,const Expression & right)47 static std::unique_ptr<Expression> short_circuit_boolean(const Expression& left,
48                                                          Operator op,
49                                                          const Expression& right) {
50     bool leftVal = left.as<Literal>().boolValue();
51 
52     // When the literal is on the left, we can sometimes eliminate the other expression entirely.
53     if ((op.kind() == Token::Kind::TK_LOGICALAND && !leftVal) ||  // (false && expr) -> (false)
54         (op.kind() == Token::Kind::TK_LOGICALOR  && leftVal)) {   // (true  || expr) -> (true)
55 
56         return left.clone();
57     }
58 
59     // We can't eliminate the right-side expression via short-circuit, but we might still be able to
60     // simplify away a no-op expression.
61     return eliminate_no_op_boolean(right, op, left);
62 }
63 
simplify_vector_equality(const Context & context,const Expression & left,Operator op,const Expression & right)64 static std::unique_ptr<Expression> simplify_vector_equality(const Context& context,
65                                                             const Expression& left,
66                                                             Operator op,
67                                                             const Expression& right) {
68     if (op.kind() == Token::Kind::TK_EQEQ || op.kind() == Token::Kind::TK_NEQ) {
69         bool equality = (op.kind() == Token::Kind::TK_EQEQ);
70 
71         switch (left.compareConstant(right)) {
72             case Expression::ComparisonResult::kNotEqual:
73                 equality = !equality;
74                 [[fallthrough]];
75 
76             case Expression::ComparisonResult::kEqual:
77                 return Literal::MakeBool(context, left.fLine, equality);
78 
79             case Expression::ComparisonResult::kUnknown:
80                 break;
81         }
82     }
83     return nullptr;
84 }
85 
simplify_vector(const Context & context,const Expression & left,Operator op,const Expression & right)86 static std::unique_ptr<Expression> simplify_vector(const Context& context,
87                                                    const Expression& left,
88                                                    Operator op,
89                                                    const Expression& right) {
90     SkASSERT(left.type().isVector());
91     SkASSERT(left.type() == right.type());
92     const Type& type = left.type();
93 
94     // Handle equality operations: == !=
95     if (std::unique_ptr<Expression> result = simplify_vector_equality(context, left, op, right)) {
96         return result;
97     }
98 
99     // Handle floating-point arithmetic: + - * /
100     using FoldFn = double (*)(double, double);
101     FoldFn foldFn;
102     switch (op.kind()) {
103         case Token::Kind::TK_PLUS:  foldFn = +[](double a, double b) { return a + b; }; break;
104         case Token::Kind::TK_MINUS: foldFn = +[](double a, double b) { return a - b; }; break;
105         case Token::Kind::TK_STAR:  foldFn = +[](double a, double b) { return a * b; }; break;
106         case Token::Kind::TK_SLASH: foldFn = +[](double a, double b) { return a / b; }; break;
107         default:
108             return nullptr;
109     }
110 
111     const Type& componentType = type.componentType();
112     double minimumValue = -INFINITY, maximumValue = INFINITY;
113     if (componentType.isInteger()) {
114         minimumValue = componentType.minimumValue();
115         maximumValue = componentType.maximumValue();
116     }
117 
118     ExpressionArray args;
119     args.reserve_back(type.columns());
120     for (int i = 0; i < type.columns(); i++) {
121         double value = foldFn(*left.getConstantValue(i), *right.getConstantValue(i));
122         if (value < minimumValue || value > maximumValue) {
123             return nullptr;
124         }
125 
126         args.push_back(Literal::Make(left.fLine, value, &componentType));
127     }
128     return ConstructorCompound::Make(context, left.fLine, type, std::move(args));
129 }
130 
cast_expression(const Context & context,const Expression & expr,const Type & type)131 static std::unique_ptr<Expression> cast_expression(const Context& context,
132                                                    const Expression& expr,
133                                                    const Type& type) {
134     ExpressionArray ctorArgs;
135     ctorArgs.push_back(expr.clone());
136     return Constructor::Convert(context, expr.fLine, type, std::move(ctorArgs));
137 }
138 
GetConstantInt(const Expression & value,SKSL_INT * out)139 bool ConstantFolder::GetConstantInt(const Expression& value, SKSL_INT* out) {
140     const Expression* expr = GetConstantValueForVariable(value);
141     if (!expr->isIntLiteral()) {
142         return false;
143     }
144     *out = expr->as<Literal>().intValue();
145     return true;
146 }
147 
GetConstantValue(const Expression & value,double * out)148 bool ConstantFolder::GetConstantValue(const Expression& value, double* out) {
149     const Expression* expr = GetConstantValueForVariable(value);
150     if (!expr->is<Literal>()) {
151         return false;
152     }
153     *out = expr->as<Literal>().value();
154     return true;
155 }
156 
contains_constant_zero(const Expression & expr)157 static bool contains_constant_zero(const Expression& expr) {
158     int numSlots = expr.type().slotCount();
159     for (int index = 0; index < numSlots; ++index) {
160         skstd::optional<double> slotVal = expr.getConstantValue(index);
161         if (slotVal.has_value() && *slotVal == 0.0) {
162             return true;
163         }
164     }
165     return false;
166 }
167 
is_constant_value(const Expression & expr,double value)168 static bool is_constant_value(const Expression& expr, double value) {
169     int numSlots = expr.type().slotCount();
170     for (int index = 0; index < numSlots; ++index) {
171         skstd::optional<double> slotVal = expr.getConstantValue(index);
172         if (!slotVal.has_value() || *slotVal != value) {
173             return false;
174         }
175     }
176     return true;
177 }
178 
error_on_divide_by_zero(const Context & context,int line,Operator op,const Expression & right)179 static bool error_on_divide_by_zero(const Context& context, int line, Operator op,
180                                     const Expression& right) {
181     switch (op.kind()) {
182         case Token::Kind::TK_SLASH:
183         case Token::Kind::TK_SLASHEQ:
184         case Token::Kind::TK_PERCENT:
185         case Token::Kind::TK_PERCENTEQ:
186             if (contains_constant_zero(right)) {
187                 context.fErrors->error(line, "division by zero");
188                 return true;
189             }
190             return false;
191         default:
192             return false;
193     }
194 }
195 
// Follows a chain of reads of `const` variables through their initializers and returns the first
// initializer that is a compile-time constant.
// The chain stops, and `&inExpr` is returned unchanged, when any of these holds:
//   - the current expression is not a VariableReference;
//   - the reference is not a plain read;
//   - the variable is not `const`;
//   - (SKSL_EXT only) the variable has a constant_id layout flag;
//   - the variable has no initial value (e.g. a const function parameter).
const Expression* ConstantFolder::GetConstantValueForVariable(const Expression& inExpr) {
    const Expression* current = &inExpr;
    while (current->is<VariableReference>()) {
        const VariableReference& ref = current->as<VariableReference>();
        if (ref.refKind() != VariableRefKind::kRead) {
            break;
        }
        const Variable& variable = *ref.variable();
        if (!(variable.modifiers().fFlags & Modifiers::kConst_Flag)) {
            break;
        }
#ifdef SKSL_EXT
        if (variable.modifiers().fLayout.fFlags & Layout::Flag::kConstantId_Flag) {
            break;
        }
#endif
        current = variable.initialValue();
        if (current == nullptr) {
            // Function parameters can be const but won't have an initial value.
            break;
        }
        if (current->isCompileTimeConstant()) {
            return current;
        }
        // Otherwise keep following: the initializer may itself reference another const variable.
    }
    // We didn't find a compile-time constant at the end. Return the expression as-is.
    return &inExpr;
}
226 
MakeConstantValueForVariable(std::unique_ptr<Expression> expr)227 std::unique_ptr<Expression> ConstantFolder::MakeConstantValueForVariable(
228         std::unique_ptr<Expression> expr) {
229     const Expression* constantExpr = GetConstantValueForVariable(*expr);
230     if (constantExpr != expr.get()) {
231         expr = constantExpr->clone();
232     }
233     return expr;
234 }
235 
simplify_no_op_arithmetic(const Context & context,const Expression & left,Operator op,const Expression & right,const Type & resultType)236 static std::unique_ptr<Expression> simplify_no_op_arithmetic(const Context& context,
237                                                              const Expression& left,
238                                                              Operator op,
239                                                              const Expression& right,
240                                                              const Type& resultType) {
241     switch (op.kind()) {
242         case Token::Kind::TK_PLUS:
243             if (is_constant_value(right, 0.0)) {  // x + 0
244                 return cast_expression(context, left, resultType);
245             }
246             if (is_constant_value(left, 0.0)) {   // 0 + x
247                 return cast_expression(context, right, resultType);
248             }
249             break;
250 
251         case Token::Kind::TK_STAR:
252             if (is_constant_value(right, 1.0)) {  // x * 1
253                 return cast_expression(context, left, resultType);
254             }
255             if (is_constant_value(left, 1.0)) {   // 1 * x
256                 return cast_expression(context, right, resultType);
257             }
258             if (is_constant_value(right, 0.0) && !left.hasSideEffects()) {  // x * 0
259                 return cast_expression(context, right, resultType);
260             }
261             if (is_constant_value(left, 0.0) && !right.hasSideEffects()) {  // 0 * x
262                 return cast_expression(context, left, resultType);
263             }
264             break;
265 
266         case Token::Kind::TK_MINUS:
267             if (is_constant_value(right, 0.0)) {  // x - 0
268                 return cast_expression(context, left, resultType);
269             }
270             if (is_constant_value(left, 0.0)) {   // 0 - x (to `-x`)
271                 if (std::unique_ptr<Expression> val = cast_expression(context, right, resultType)) {
272                     return PrefixExpression::Make(context, Token::Kind::TK_MINUS, std::move(val));
273                 }
274             }
275             break;
276 
277         case Token::Kind::TK_SLASH:
278             if (is_constant_value(right, 1.0)) {  // x / 1
279                 return cast_expression(context, left, resultType);
280             }
281             break;
282 
283         case Token::Kind::TK_PLUSEQ:
284         case Token::Kind::TK_MINUSEQ:
285             if (is_constant_value(right, 0.0)) {  // x += 0, x -= 0
286                 if (std::unique_ptr<Expression> var = cast_expression(context, left, resultType)) {
287                     Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
288                     return var;
289                 }
290             }
291             break;
292 
293         case Token::Kind::TK_STAREQ:
294         case Token::Kind::TK_SLASHEQ:
295             if (is_constant_value(right, 1.0)) {  // x *= 1, x /= 1
296                 if (std::unique_ptr<Expression> var = cast_expression(context, left, resultType)) {
297                     Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
298                     return var;
299                 }
300             }
301             break;
302 
303         default:
304             break;
305     }
306 
307     return nullptr;
308 }
309 
310 template <typename T>
fold_float_expression(int line,T result,const Type * resultType)311 static std::unique_ptr<Expression> fold_float_expression(int line,
312                                                          T result,
313                                                          const Type* resultType) {
314     // If constant-folding this expression would generate a NaN/infinite result, leave it as-is.
315     if constexpr (!std::is_same<T, bool>::value) {
316         if (!std::isfinite(result)) {
317             return nullptr;
318         }
319     }
320 
321     return Literal::Make(line, result, resultType);
322 }
323 
324 template <typename T>
fold_int_expression(int line,T result,const Type * resultType)325 static std::unique_ptr<Expression> fold_int_expression(int line,
326                                                        T result,
327                                                        const Type* resultType) {
328     // If constant-folding this expression would overflow the result type, leave it as-is.
329     if constexpr (!std::is_same<T, bool>::value) {
330         if (result < resultType->minimumValue() || result > resultType->maximumValue()) {
331             return nullptr;
332         }
333     }
334 
335     return Literal::Make(line, result, resultType);
336 }
337 
Simplify(const Context & context,int line,const Expression & leftExpr,Operator op,const Expression & rightExpr,const Type & resultType)338 std::unique_ptr<Expression> ConstantFolder::Simplify(const Context& context,
339                                                      int line,
340                                                      const Expression& leftExpr,
341                                                      Operator op,
342                                                      const Expression& rightExpr,
343                                                      const Type& resultType) {
344     // Replace constant variables with their literal values.
345     const Expression* left = GetConstantValueForVariable(leftExpr);
346     const Expression* right = GetConstantValueForVariable(rightExpr);
347 
348     // If this is the comma operator, the left side is evaluated but not otherwise used in any way.
349     // So if the left side has no side effects, it can just be eliminated entirely.
350     if (op.kind() == Token::Kind::TK_COMMA && !left->hasSideEffects()) {
351         return right->clone();
352     }
353 
354     // If this is the assignment operator, and both sides are the same trivial expression, this is
355     // self-assignment (i.e., `var = var`) and can be reduced to just a variable reference (`var`).
356     // This can happen when other parts of the assignment are optimized away.
357     if (op.kind() == Token::Kind::TK_EQ && Analysis::IsSameExpressionTree(*left, *right)) {
358         return right->clone();
359     }
360 
361     // Simplify the expression when both sides are constant Boolean literals.
362     if (left->isBoolLiteral() && right->isBoolLiteral()) {
363         bool leftVal  = left->as<Literal>().boolValue();
364         bool rightVal = right->as<Literal>().boolValue();
365         bool result;
366         switch (op.kind()) {
367             case Token::Kind::TK_LOGICALAND: result = leftVal && rightVal; break;
368             case Token::Kind::TK_LOGICALOR:  result = leftVal || rightVal; break;
369             case Token::Kind::TK_LOGICALXOR: result = leftVal ^  rightVal; break;
370             case Token::Kind::TK_EQEQ:       result = leftVal == rightVal; break;
371             case Token::Kind::TK_NEQ:        result = leftVal != rightVal; break;
372             default: return nullptr;
373         }
374         return Literal::MakeBool(context, line, result);
375     }
376 
377     // If the left side is a Boolean literal, apply short-circuit optimizations.
378     if (left->isBoolLiteral()) {
379         return short_circuit_boolean(*left, op, *right);
380     }
381 
382     // If the right side is a Boolean literal...
383     if (right->isBoolLiteral()) {
384         // ... and the left side has no side effects...
385         if (!left->hasSideEffects()) {
386             // We can reverse the expressions and short-circuit optimizations are still valid.
387             return short_circuit_boolean(*right, op, *left);
388         }
389 
390         // We can't use short-circuiting, but we can still optimize away no-op Boolean expressions.
391         return eliminate_no_op_boolean(*left, op, *right);
392     }
393 
394     if (op.kind() == Token::Kind::TK_EQEQ && Analysis::IsSameExpressionTree(*left, *right)) {
395         // With == comparison, if both sides are the same trivial expression, this is self-
396         // comparison and is always true. (We are not concerned with NaN.)
397         return Literal::MakeBool(context, leftExpr.fLine, /*value=*/true);
398     }
399 
400     if (op.kind() == Token::Kind::TK_NEQ && Analysis::IsSameExpressionTree(*left, *right)) {
401         // With != comparison, if both sides are the same trivial expression, this is self-
402         // comparison and is always false. (We are not concerned with NaN.)
403         return Literal::MakeBool(context, leftExpr.fLine, /*value=*/false);
404     }
405 
406     if (error_on_divide_by_zero(context, line, op, *right)) {
407         return nullptr;
408     }
409 
410     // Optimize away no-op arithmetic like `x * 1`, `x *= 1`, `x + 0`, `x * 0`, `0 / x`, etc.
411     const Type& leftType = left->type();
412     const Type& rightType = right->type();
413     if ((leftType.isScalar() || leftType.isVector()) &&
414         (rightType.isScalar() || rightType.isVector())) {
415         std::unique_ptr<Expression> expr = simplify_no_op_arithmetic(context, *left, op, *right,
416                                                                      resultType);
417         if (expr) {
418             return expr;
419         }
420     }
421 
422     // Other than the cases above, constant folding requires both sides to be constant.
423     if (!left->isCompileTimeConstant() || !right->isCompileTimeConstant()) {
424         return nullptr;
425     }
426 
427     // Note that we expressly do not worry about precision and overflow here -- we use the maximum
428     // precision to calculate the results and hope the result makes sense.
429     // TODO(skia:10932): detect and handle integer overflow properly.
430     using SKSL_UINT = uint64_t;
431     if (left->isIntLiteral() && right->isIntLiteral()) {
432         SKSL_INT leftVal  = left->as<Literal>().intValue();
433         SKSL_INT rightVal = right->as<Literal>().intValue();
434 
435         #define RESULT(Op)   fold_int_expression(line, \
436                                         (SKSL_INT)(leftVal) Op (SKSL_INT)(rightVal), &resultType)
437         #define URESULT(Op)  fold_int_expression(line, \
438                              (SKSL_INT)((SKSL_UINT)(leftVal) Op (SKSL_UINT)(rightVal)), &resultType)
439         switch (op.kind()) {
440             case Token::Kind::TK_PLUS:       return URESULT(+);
441             case Token::Kind::TK_MINUS:      return URESULT(-);
442             case Token::Kind::TK_STAR:       return URESULT(*);
443             case Token::Kind::TK_SLASH:
444                 if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
445                     context.fErrors->error(line, "arithmetic overflow");
446                     return nullptr;
447                 }
448                 return RESULT(/);
449             case Token::Kind::TK_PERCENT:
450                 if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
451                     context.fErrors->error(line, "arithmetic overflow");
452                     return nullptr;
453                 }
454                 return RESULT(%);
455             case Token::Kind::TK_BITWISEAND: return RESULT(&);
456             case Token::Kind::TK_BITWISEOR:  return RESULT(|);
457             case Token::Kind::TK_BITWISEXOR: return RESULT(^);
458             case Token::Kind::TK_EQEQ:       return RESULT(==);
459             case Token::Kind::TK_NEQ:        return RESULT(!=);
460             case Token::Kind::TK_GT:         return RESULT(>);
461             case Token::Kind::TK_GTEQ:       return RESULT(>=);
462             case Token::Kind::TK_LT:         return RESULT(<);
463             case Token::Kind::TK_LTEQ:       return RESULT(<=);
464             case Token::Kind::TK_SHL:
465                 if (rightVal >= 0 && rightVal <= 31) {
466                     // Left-shifting a negative (or really, any signed) value is undefined behavior
467                     // in C++, but not GLSL. Do the shift on unsigned values, to avoid UBSAN.
468                     return URESULT(<<);
469                 }
470                 context.fErrors->error(line, "shift value out of range");
471                 return nullptr;
472             case Token::Kind::TK_SHR:
473                 if (rightVal >= 0 && rightVal <= 31) {
474                     return RESULT(>>);
475                 }
476                 context.fErrors->error(line, "shift value out of range");
477                 return nullptr;
478 
479             default:
480                 return nullptr;
481         }
482         #undef RESULT
483         #undef URESULT
484     }
485 
486     // Perform constant folding on pairs of floating-point literals.
487     if (left->isFloatLiteral() && right->isFloatLiteral()) {
488         SKSL_FLOAT leftVal  = left->as<Literal>().floatValue();
489         SKSL_FLOAT rightVal = right->as<Literal>().floatValue();
490 
491         #define RESULT(Op) fold_float_expression(line, leftVal Op rightVal, &resultType)
492         switch (op.kind()) {
493             case Token::Kind::TK_PLUS:  return RESULT(+);
494             case Token::Kind::TK_MINUS: return RESULT(-);
495             case Token::Kind::TK_STAR:  return RESULT(*);
496             case Token::Kind::TK_SLASH: return RESULT(/);
497             case Token::Kind::TK_EQEQ:  return RESULT(==);
498             case Token::Kind::TK_NEQ:   return RESULT(!=);
499             case Token::Kind::TK_GT:    return RESULT(>);
500             case Token::Kind::TK_GTEQ:  return RESULT(>=);
501             case Token::Kind::TK_LT:    return RESULT(<);
502             case Token::Kind::TK_LTEQ:  return RESULT(<=);
503             default:                    return nullptr;
504         }
505         #undef RESULT
506     }
507 
508     // Perform constant folding on pairs of vectors.
509     if (leftType.isVector() && leftType == rightType) {
510         if (leftType.componentType().isFloat()) {
511             return simplify_vector(context, *left, op, *right);
512         }
513         if (leftType.componentType().isInteger()) {
514             return simplify_vector(context, *left, op, *right);
515         }
516         if (leftType.componentType().isBoolean()) {
517             return simplify_vector_equality(context, *left, op, *right);
518         }
519         return nullptr;
520     }
521 
522     // Perform constant folding on vectors against scalars, e.g.: half4(2) + 2
523     if (leftType.isVector() && leftType.componentType() == rightType) {
524         if (rightType.isFloat()) {
525             return simplify_vector(context, *left, op, ConstructorSplat(*right, left->type()));
526         }
527         if (rightType.isInteger()) {
528             return simplify_vector(context, *left, op, ConstructorSplat(*right, left->type()));
529         }
530         if (rightType.isBoolean()) {
531             return simplify_vector_equality(context, *left, op,
532                                             ConstructorSplat(*right, left->type()));
533         }
534         return nullptr;
535     }
536 
537     // Perform constant folding on scalars against vectors, e.g.: 2 + half4(2)
538     if (rightType.isVector() && rightType.componentType() == leftType) {
539         if (leftType.isFloat()) {
540             return simplify_vector(context, ConstructorSplat(*left, right->type()), op, *right);
541         }
542         if (leftType.isInteger()) {
543             return simplify_vector(context, ConstructorSplat(*left, right->type()), op, *right);
544         }
545         if (leftType.isBoolean()) {
546             return simplify_vector_equality(context, ConstructorSplat(*left, right->type()),
547                                             op, *right);
548         }
549         return nullptr;
550     }
551 
552     // Perform constant folding on pairs of matrices or arrays.
553     if ((leftType.isMatrix() && rightType.isMatrix()) ||
554         (leftType.isArray() && rightType.isArray())) {
555         bool equality;
556         switch (op.kind()) {
557             case Token::Kind::TK_EQEQ:
558                 equality = true;
559                 break;
560             case Token::Kind::TK_NEQ:
561                 equality = false;
562                 break;
563             default:
564                 return nullptr;
565         }
566 
567         switch (left->compareConstant(*right)) {
568             case Expression::ComparisonResult::kNotEqual:
569                 equality = !equality;
570                 [[fallthrough]];
571 
572             case Expression::ComparisonResult::kEqual:
573                 return Literal::MakeBool(context, line, equality);
574 
575             case Expression::ComparisonResult::kUnknown:
576                 return nullptr;
577         }
578     }
579 
580     // We aren't able to constant-fold.
581     return nullptr;
582 }
583 
584 }  // namespace SkSL
585