• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/SkSLConstantFolder.h"
9 
10 #include "include/core/SkTypes.h"
11 #include "include/private/SkSLModifiers.h"
12 #include "include/private/base/SkFloatingPoint.h"
13 #include "include/private/base/SkTArray.h"
14 #include "include/sksl/SkSLErrorReporter.h"
15 #include "include/sksl/SkSLPosition.h"
16 #include "src/sksl/SkSLAnalysis.h"
17 #include "src/sksl/SkSLContext.h"
18 #include "src/sksl/SkSLProgramSettings.h"
19 #include "src/sksl/ir/SkSLBinaryExpression.h"
20 #include "src/sksl/ir/SkSLConstructorCompound.h"
21 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
22 #include "src/sksl/ir/SkSLConstructorSplat.h"
23 #include "src/sksl/ir/SkSLExpression.h"
24 #include "src/sksl/ir/SkSLLiteral.h"
25 #include "src/sksl/ir/SkSLPrefixExpression.h"
26 #include "src/sksl/ir/SkSLType.h"
27 #include "src/sksl/ir/SkSLVariable.h"
28 #include "src/sksl/ir/SkSLVariableReference.h"
29 
30 #include <cstdint>
31 #include <float.h>
32 #include <limits>
33 #include <optional>
34 #include <string>
35 #include <utility>
36 
37 namespace SkSL {
38 
is_vec_or_mat(const Type & type)39 static bool is_vec_or_mat(const Type& type) {
40     switch (type.typeKind()) {
41         case Type::TypeKind::kMatrix:
42         case Type::TypeKind::kVector:
43             return true;
44 
45         default:
46             return false;
47     }
48 }
49 
eliminate_no_op_boolean(Position pos,const Expression & left,Operator op,const Expression & right)50 static std::unique_ptr<Expression> eliminate_no_op_boolean(Position pos,
51                                                            const Expression& left,
52                                                            Operator op,
53                                                            const Expression& right) {
54     bool rightVal = right.as<Literal>().boolValue();
55 
56     // Detect no-op Boolean expressions and optimize them away.
57     if ((op.kind() == Operator::Kind::LOGICALAND && rightVal)  ||  // (expr && true)  -> (expr)
58         (op.kind() == Operator::Kind::LOGICALOR  && !rightVal) ||  // (expr || false) -> (expr)
59         (op.kind() == Operator::Kind::LOGICALXOR && !rightVal) ||  // (expr ^^ false) -> (expr)
60         (op.kind() == Operator::Kind::EQEQ       && rightVal)  ||  // (expr == true)  -> (expr)
61         (op.kind() == Operator::Kind::NEQ        && !rightVal)) {  // (expr != false) -> (expr)
62 
63         return left.clone(pos);
64     }
65 
66     return nullptr;
67 }
68 
short_circuit_boolean(Position pos,const Expression & left,Operator op,const Expression & right)69 static std::unique_ptr<Expression> short_circuit_boolean(Position pos,
70                                                          const Expression& left,
71                                                          Operator op,
72                                                          const Expression& right) {
73     bool leftVal = left.as<Literal>().boolValue();
74 
75     // When the literal is on the left, we can sometimes eliminate the other expression entirely.
76     if ((op.kind() == Operator::Kind::LOGICALAND && !leftVal) ||  // (false && expr) -> (false)
77         (op.kind() == Operator::Kind::LOGICALOR  && leftVal)) {   // (true  || expr) -> (true)
78 
79         return left.clone(pos);
80     }
81 
82     // We can't eliminate the right-side expression via short-circuit, but we might still be able to
83     // simplify away a no-op expression.
84     return eliminate_no_op_boolean(pos, right, op, left);
85 }
86 
simplify_constant_equality(const Context & context,Position pos,const Expression & left,Operator op,const Expression & right)87 static std::unique_ptr<Expression> simplify_constant_equality(const Context& context,
88                                                               Position pos,
89                                                               const Expression& left,
90                                                               Operator op,
91                                                               const Expression& right) {
92     if (op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ) {
93         bool equality = (op.kind() == Operator::Kind::EQEQ);
94 
95         switch (left.compareConstant(right)) {
96             case Expression::ComparisonResult::kNotEqual:
97                 equality = !equality;
98                 [[fallthrough]];
99 
100             case Expression::ComparisonResult::kEqual:
101                 return Literal::MakeBool(context, pos, equality);
102 
103             case Expression::ComparisonResult::kUnknown:
104                 break;
105         }
106     }
107     return nullptr;
108 }
109 
simplify_matrix_multiplication(const Context & context,Position pos,const Expression & left,const Expression & right,int leftColumns,int leftRows,int rightColumns,int rightRows)110 static std::unique_ptr<Expression> simplify_matrix_multiplication(const Context& context,
111                                                                   Position pos,
112                                                                   const Expression& left,
113                                                                   const Expression& right,
114                                                                   int leftColumns,
115                                                                   int leftRows,
116                                                                   int rightColumns,
117                                                                   int rightRows) {
118     const Type& componentType = left.type().componentType();
119     SkASSERT(componentType.matches(right.type().componentType()));
120 
121     // Fetch the left matrix.
122     double leftVals[4][4];
123     for (int c = 0; c < leftColumns; ++c) {
124         for (int r = 0; r < leftRows; ++r) {
125             leftVals[c][r] = *left.getConstantValue((c * leftRows) + r);
126         }
127     }
128     // Fetch the right matrix.
129     double rightVals[4][4];
130     for (int c = 0; c < rightColumns; ++c) {
131         for (int r = 0; r < rightRows; ++r) {
132             rightVals[c][r] = *right.getConstantValue((c * rightRows) + r);
133         }
134     }
135 
136     SkASSERT(leftColumns == rightRows);
137     int outColumns   = rightColumns,
138         outRows      = leftRows;
139 
140     ExpressionArray args;
141     args.reserve_back(outColumns * outRows);
142     for (int c = 0; c < outColumns; ++c) {
143         for (int r = 0; r < outRows; ++r) {
144             // Compute a dot product for this position.
145             double val = 0;
146             for (int dotIdx = 0; dotIdx < leftColumns; ++dotIdx) {
147                 val += leftVals[dotIdx][r] * rightVals[c][dotIdx];
148             }
149             args.push_back(Literal::Make(pos, val, &componentType));
150         }
151     }
152 
153     if (outColumns == 1) {
154         // Matrix-times-vector conceptually makes a 1-column N-row matrix, but we return vecN.
155         std::swap(outColumns, outRows);
156     }
157 
158     const Type& resultType = componentType.toCompound(context, outColumns, outRows);
159     return ConstructorCompound::Make(context, pos, resultType, std::move(args));
160 }
161 
simplify_matrix_times_matrix(const Context & context,Position pos,const Expression & left,const Expression & right)162 static std::unique_ptr<Expression> simplify_matrix_times_matrix(const Context& context,
163                                                                 Position pos,
164                                                                 const Expression& left,
165                                                                 const Expression& right) {
166     const Type& leftType = left.type();
167     const Type& rightType = right.type();
168 
169     SkASSERT(leftType.isMatrix());
170     SkASSERT(rightType.isMatrix());
171 
172     return simplify_matrix_multiplication(context, pos, left, right,
173                                           leftType.columns(), leftType.rows(),
174                                           rightType.columns(), rightType.rows());
175 }
176 
simplify_vector_times_matrix(const Context & context,Position pos,const Expression & left,const Expression & right)177 static std::unique_ptr<Expression> simplify_vector_times_matrix(const Context& context,
178                                                                 Position pos,
179                                                                 const Expression& left,
180                                                                 const Expression& right) {
181     const Type& leftType = left.type();
182     const Type& rightType = right.type();
183 
184     SkASSERT(leftType.isVector());
185     SkASSERT(rightType.isMatrix());
186 
187     return simplify_matrix_multiplication(context, pos, left, right,
188                                           /*leftColumns=*/leftType.columns(), /*leftRows=*/1,
189                                           rightType.columns(), rightType.rows());
190 }
191 
simplify_matrix_times_vector(const Context & context,Position pos,const Expression & left,const Expression & right)192 static std::unique_ptr<Expression> simplify_matrix_times_vector(const Context& context,
193                                                                 Position pos,
194                                                                 const Expression& left,
195                                                                 const Expression& right) {
196     const Type& leftType = left.type();
197     const Type& rightType = right.type();
198 
199     SkASSERT(leftType.isMatrix());
200     SkASSERT(rightType.isVector());
201 
202     return simplify_matrix_multiplication(context, pos, left, right,
203                                           leftType.columns(), leftType.rows(),
204                                           /*rightColumns=*/1, /*rightRows=*/rightType.columns());
205 }
206 
simplify_componentwise(const Context & context,Position pos,const Expression & left,Operator op,const Expression & right)207 static std::unique_ptr<Expression> simplify_componentwise(const Context& context,
208                                                           Position pos,
209                                                           const Expression& left,
210                                                           Operator op,
211                                                           const Expression& right) {
212     SkASSERT(is_vec_or_mat(left.type()));
213     SkASSERT(left.type().matches(right.type()));
214     const Type& type = left.type();
215 
216     // Handle equality operations: == !=
217     if (std::unique_ptr<Expression> result = simplify_constant_equality(context, pos, left, op,
218             right)) {
219         return result;
220     }
221 
222     // Handle floating-point arithmetic: + - * /
223     using FoldFn = double (*)(double, double);
224     FoldFn foldFn;
225     switch (op.kind()) {
226         case Operator::Kind::PLUS:  foldFn = +[](double a, double b) { return a + b; }; break;
227         case Operator::Kind::MINUS: foldFn = +[](double a, double b) { return a - b; }; break;
228         case Operator::Kind::STAR:  foldFn = +[](double a, double b) { return a * b; }; break;
229         case Operator::Kind::SLASH: foldFn = +[](double a, double b) { return a / b; }; break;
230         default:
231             return nullptr;
232     }
233 
234     const Type& componentType = type.componentType();
235     SkASSERT(componentType.isNumber());
236 
237     double minimumValue = componentType.minimumValue();
238     double maximumValue = componentType.maximumValue();
239 
240     ExpressionArray args;
241     int numSlots = type.slotCount();
242     args.reserve_back(numSlots);
243     for (int i = 0; i < numSlots; i++) {
244         double value = foldFn(*left.getConstantValue(i), *right.getConstantValue(i));
245         if (value < minimumValue || value > maximumValue) {
246             return nullptr;
247         }
248 
249         args.push_back(Literal::Make(pos, value, &componentType));
250     }
251     return ConstructorCompound::Make(context, pos, type, std::move(args));
252 }
253 
splat_scalar(const Context & context,const Expression & scalar,const Type & type)254 static std::unique_ptr<Expression> splat_scalar(const Context& context,
255                                                 const Expression& scalar,
256                                                 const Type& type) {
257     if (type.isVector()) {
258         return ConstructorSplat::Make(context, scalar.fPosition, type, scalar.clone());
259     }
260     if (type.isMatrix()) {
261         int numSlots = type.slotCount();
262         ExpressionArray splatMatrix;
263         splatMatrix.reserve_back(numSlots);
264         for (int index = 0; index < numSlots; ++index) {
265             splatMatrix.push_back(scalar.clone());
266         }
267         return ConstructorCompound::Make(context, scalar.fPosition, type, std::move(splatMatrix));
268     }
269     SkDEBUGFAILF("unsupported type %s", type.description().c_str());
270     return nullptr;
271 }
272 
cast_expression(const Context & context,Position pos,const Expression & expr,const Type & type)273 static std::unique_ptr<Expression> cast_expression(const Context& context,
274                                                    Position pos,
275                                                    const Expression& expr,
276                                                    const Type& type) {
277     SkASSERT(type.componentType().matches(expr.type().componentType()));
278     if (expr.type().isScalar()) {
279         if (type.isMatrix()) {
280             return ConstructorDiagonalMatrix::Make(context, pos, type, expr.clone());
281         }
282         if (type.isVector()) {
283             return ConstructorSplat::Make(context, pos, type, expr.clone());
284         }
285     }
286     if (type.matches(expr.type())) {
287         return expr.clone(pos);
288     }
289     // We can't cast matrices into vectors or vice-versa.
290     return nullptr;
291 }
292 
zero_expression(const Context & context,Position pos,const Type & type)293 static std::unique_ptr<Expression> zero_expression(const Context& context,
294                                                    Position pos,
295                                                    const Type& type) {
296     std::unique_ptr<Expression> zero = Literal::Make(pos, 0.0, &type.componentType());
297     if (type.isScalar()) {
298         return zero;
299     }
300     if (type.isVector()) {
301         return ConstructorSplat::Make(context, pos, type, std::move(zero));
302     }
303     if (type.isMatrix()) {
304         return ConstructorDiagonalMatrix::Make(context, pos, type, std::move(zero));
305     }
306     SkDEBUGFAILF("unsupported type %s", type.description().c_str());
307     return nullptr;
308 }
309 
negate_expression(const Context & context,Position pos,const Expression & expr,const Type & type)310 static std::unique_ptr<Expression> negate_expression(const Context& context,
311                                                      Position pos,
312                                                      const Expression& expr,
313                                                      const Type& type) {
314     std::unique_ptr<Expression> ctor = cast_expression(context, pos, expr, type);
315     return ctor ? PrefixExpression::Make(context, pos, Operator::Kind::MINUS, std::move(ctor))
316                 : nullptr;
317 }
318 
GetConstantInt(const Expression & value,SKSL_INT * out)319 bool ConstantFolder::GetConstantInt(const Expression& value, SKSL_INT* out) {
320     const Expression* expr = GetConstantValueForVariable(value);
321     if (!expr->isIntLiteral()) {
322         return false;
323     }
324     *out = expr->as<Literal>().intValue();
325     return true;
326 }
327 
GetConstantValue(const Expression & value,double * out)328 bool ConstantFolder::GetConstantValue(const Expression& value, double* out) {
329     const Expression* expr = GetConstantValueForVariable(value);
330     if (!expr->is<Literal>()) {
331         return false;
332     }
333     *out = expr->as<Literal>().value();
334     return true;
335 }
336 
contains_constant_zero(const Expression & expr)337 static bool contains_constant_zero(const Expression& expr) {
338     int numSlots = expr.type().slotCount();
339     for (int index = 0; index < numSlots; ++index) {
340         std::optional<double> slotVal = expr.getConstantValue(index);
341         if (slotVal.has_value() && *slotVal == 0.0) {
342             return true;
343         }
344     }
345     return false;
346 }
347 
348 // Returns true if the expression contains `value` in every slot.
is_constant_splat(const Expression & expr,double value)349 static bool is_constant_splat(const Expression& expr, double value) {
350     int numSlots = expr.type().slotCount();
351     for (int index = 0; index < numSlots; ++index) {
352         std::optional<double> slotVal = expr.getConstantValue(index);
353         if (!slotVal.has_value() || *slotVal != value) {
354             return false;
355         }
356     }
357     return true;
358 }
359 
360 // Returns true if the expression is a square diagonal matrix containing `value`.
is_constant_diagonal(const Expression & expr,double value)361 static bool is_constant_diagonal(const Expression& expr, double value) {
362     SkASSERT(expr.type().isMatrix());
363     int columns = expr.type().columns();
364     int rows = expr.type().rows();
365     if (columns != rows) {
366         return false;
367     }
368     int slotIdx = 0;
369     for (int c = 0; c < columns; ++c) {
370         for (int r = 0; r < rows; ++r) {
371             double expectation = (c == r) ? value : 0;
372             std::optional<double> slotVal = expr.getConstantValue(slotIdx++);
373             if (!slotVal.has_value() || *slotVal != expectation) {
374                 return false;
375             }
376         }
377     }
378     return true;
379 }
380 
381 // Returns true if the expression is a scalar, vector, or diagonal matrix containing `value`.
is_constant_value(const Expression & expr,double value)382 static bool is_constant_value(const Expression& expr, double value) {
383     return expr.type().isMatrix() ? is_constant_diagonal(expr, value)
384                                   : is_constant_splat(expr, value);
385 }
386 
387 // The expression represents the right-hand side of a division op. If the division can be
388 // strength-reduced into multiplication by a reciprocal, returns that reciprocal as an expression.
389 // Note that this only supports literal values with safe-to-use reciprocals, and returns null if
390 // Expression contains anything else.
make_reciprocal_expression(const Context & context,const Expression & right)391 static std::unique_ptr<Expression> make_reciprocal_expression(const Context& context,
392                                                               const Expression& right) {
393     if (right.type().isMatrix() || !right.type().componentType().isFloat()) {
394         return nullptr;
395     }
396     // Verify that each slot contains a finite, non-zero literal, take its reciprocal.
397     int nslots = right.type().slotCount();
398     SkSTArray<4, double> values;
399     for (int index = 0; index < nslots; ++index) {
400         std::optional<double> value = right.getConstantValue(index);
401         if (!value) {
402             return nullptr;
403         }
404         *value = sk_ieee_double_divide(1.0, *value);
405         if (*value >= -FLT_MAX && *value <= FLT_MAX && *value != 0.0) {
406             // The reciprocal can be represented safely as a finite 32-bit float.
407             values.push_back(*value);
408         } else {
409             // The value is outside the 32-bit float range, or is NaN; do not optimize.
410             return nullptr;
411         }
412     }
413     // Convert our reciprocal values to Literals.
414     ExpressionArray exprs;
415     exprs.reserve_back(nslots);
416     for (double value : values) {
417         exprs.push_back(Literal::Make(right.fPosition, value, &right.type().componentType()));
418     }
419     // Turn the expression array into a compound constructor. (If this is a single-slot expression,
420     // this will return the literal as-is.)
421     return ConstructorCompound::Make(context, right.fPosition, right.type(), std::move(exprs));
422 }
423 
error_on_divide_by_zero(const Context & context,Position pos,Operator op,const Expression & right)424 static bool error_on_divide_by_zero(const Context& context, Position pos, Operator op,
425                                     const Expression& right) {
426     switch (op.kind()) {
427         case Operator::Kind::SLASH:
428         case Operator::Kind::SLASHEQ:
429         case Operator::Kind::PERCENT:
430         case Operator::Kind::PERCENTEQ:
431             if (contains_constant_zero(right)) {
432                 context.fErrors->error(pos, "division by zero");
433                 return true;
434             }
435             return false;
436         default:
437             return false;
438     }
439 }
440 
// Follows a chain of read-only references to `const` variables, starting at `inExpr`, and returns
// the first initial value that is a compile-time constant. Returns null if the chain ends anywhere
// else: at a non-variable expression, a non-read reference, a non-const variable, or a variable
// without an initial value.
const Expression* ConstantFolder::GetConstantValueOrNullForVariable(const Expression& inExpr) {
    const Expression* expr = &inExpr;

    // A null `expr` means the last variable had no initial value (e.g. a const function
    // parameter), which ends the search.
    while (expr && expr->is<VariableReference>()) {
        const VariableReference& ref = expr->as<VariableReference>();
        if (ref.refKind() != VariableRefKind::kRead) {
            return nullptr;
        }

        const Variable& variable = *ref.variable();
        const bool isConst = (variable.modifiers().fFlags & Modifiers::kConst_Flag);
        if (!isConst) {
            return nullptr;
        }

        expr = variable.initialValue();
        if (expr && Analysis::IsCompileTimeConstant(*expr)) {
            return expr;
        }
    }

    // We didn't find a compile-time constant at the end.
    return nullptr;
}
466 
// Returns the compile-time constant that `inExpr` resolves to through const variables, or
// `inExpr` itself when there is none. Never returns null.
const Expression* ConstantFolder::GetConstantValueForVariable(const Expression& inExpr) {
    if (const Expression* constant = GetConstantValueOrNullForVariable(inExpr)) {
        return constant;
    }
    return &inExpr;
}
471 
MakeConstantValueForVariable(Position pos,std::unique_ptr<Expression> inExpr)472 std::unique_ptr<Expression> ConstantFolder::MakeConstantValueForVariable(
473         Position pos, std::unique_ptr<Expression> inExpr) {
474     const Expression* expr = GetConstantValueOrNullForVariable(*inExpr);
475     return expr ? expr->clone(pos) : std::move(inExpr);
476 }
477 
is_scalar_op_matrix(const Expression & left,const Expression & right)478 static bool is_scalar_op_matrix(const Expression& left, const Expression& right) {
479     return left.type().isScalar() && right.type().isMatrix();
480 }
481 
is_matrix_op_scalar(const Expression & left,const Expression & right)482 static bool is_matrix_op_scalar(const Expression& left, const Expression& right) {
483     return is_scalar_op_matrix(right, left);
484 }
485 
simplify_arithmetic(const Context & context,Position pos,const Expression & left,Operator op,const Expression & right,const Type & resultType)486 static std::unique_ptr<Expression> simplify_arithmetic(const Context& context,
487                                                        Position pos,
488                                                        const Expression& left,
489                                                        Operator op,
490                                                        const Expression& right,
491                                                        const Type& resultType) {
492     switch (op.kind()) {
493         case Operator::Kind::PLUS:
494             if (!is_scalar_op_matrix(left, right) && is_constant_splat(right, 0.0)) {  // x + 0
495                 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
496                                                                        resultType)) {
497                     return expr;
498                 }
499             }
500             if (!is_matrix_op_scalar(left, right) && is_constant_splat(left, 0.0)) {   // 0 + x
501                 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, right,
502                                                                        resultType)) {
503                     return expr;
504                 }
505             }
506             break;
507 
508         case Operator::Kind::STAR:
509             if (is_constant_value(right, 1.0)) {  // x * 1
510                 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
511                                                                        resultType)) {
512                     return expr;
513                 }
514             }
515             if (is_constant_value(left, 1.0)) {   // 1 * x
516                 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, right,
517                                                                        resultType)) {
518                     return expr;
519                 }
520             }
521             if (is_constant_value(right, 0.0) && !Analysis::HasSideEffects(left)) {  // x * 0
522                 return zero_expression(context, pos, resultType);
523             }
524             if (is_constant_value(left, 0.0) && !Analysis::HasSideEffects(right)) {  // 0 * x
525                 return zero_expression(context, pos, resultType);
526             }
527             if (is_constant_value(right, -1.0)) {  // x * -1 (to `-x`)
528                 if (std::unique_ptr<Expression> expr = negate_expression(context, pos, left,
529                                                                          resultType)) {
530                     return expr;
531                 }
532             }
533             if (is_constant_value(left, -1.0)) {  // -1 * x (to `-x`)
534                 if (std::unique_ptr<Expression> expr = negate_expression(context, pos, right,
535                                                                          resultType)) {
536                     return expr;
537                 }
538             }
539             break;
540 
541         case Operator::Kind::MINUS:
542             if (!is_scalar_op_matrix(left, right) && is_constant_splat(right, 0.0)) {  // x - 0
543                 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
544                                                                        resultType)) {
545                     return expr;
546                 }
547             }
548             if (!is_matrix_op_scalar(left, right) && is_constant_splat(left, 0.0)) {   // 0 - x
549                 if (std::unique_ptr<Expression> expr = negate_expression(context, pos, right,
550                                                                          resultType)) {
551                     return expr;
552                 }
553             }
554             break;
555 
556         case Operator::Kind::SLASH:
557             if (!is_scalar_op_matrix(left, right) && is_constant_splat(right, 1.0)) {  // x / 1
558                 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
559                                                                        resultType)) {
560                     return expr;
561                 }
562             }
563             if (!left.type().isMatrix()) {  // convert `x / 2` into `x * 0.5`
564                 if (std::unique_ptr<Expression> expr = make_reciprocal_expression(context, right)) {
565                     return BinaryExpression::Make(context, pos, left.clone(), Operator::Kind::STAR,
566                                                   std::move(expr));
567                 }
568             }
569             break;
570 
571         case Operator::Kind::PLUSEQ:
572         case Operator::Kind::MINUSEQ:
573             if (is_constant_splat(right, 0.0)) {  // x += 0, x -= 0
574                 if (std::unique_ptr<Expression> var = cast_expression(context, pos, left,
575                                                                       resultType)) {
576                     Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
577                     return var;
578                 }
579             }
580             break;
581 
582         case Operator::Kind::STAREQ:
583             if (is_constant_value(right, 1.0)) {  // x *= 1
584                 if (std::unique_ptr<Expression> var = cast_expression(context, pos, left,
585                                                                       resultType)) {
586                     Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
587                     return var;
588                 }
589             }
590             break;
591 
592         case Operator::Kind::SLASHEQ:
593             if (is_constant_splat(right, 1.0)) {  // x /= 1
594                 if (std::unique_ptr<Expression> var = cast_expression(context, pos, left,
595                                                                       resultType)) {
596                     Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
597                     return var;
598                 }
599             }
600             if (std::unique_ptr<Expression> expr = make_reciprocal_expression(context, right)) {
601                 return BinaryExpression::Make(context, pos, left.clone(), Operator::Kind::STAREQ,
602                                               std::move(expr));
603             }
604             break;
605 
606         default:
607             break;
608     }
609 
610     return nullptr;
611 }
612 
613 // The expression must be scalar, and represents the right-hand side of a division op. It can
614 // contain anything, not just literal values. This returns the binary expression `1.0 / expr`. The
615 // expression might be further simplified by the constant folding, if possible.
one_over_scalar(const Context & context,const Expression & right)616 static std::unique_ptr<Expression> one_over_scalar(const Context& context,
617                                                    const Expression& right) {
618     SkASSERT(right.type().isScalar());
619     Position pos = right.fPosition;
620     return BinaryExpression::Make(context, pos,
621                                   Literal::Make(pos, 1.0, &right.type()),
622                                   Operator::Kind::SLASH,
623                                   right.clone());
624 }
625 
simplify_matrix_division(const Context & context,Position pos,const Expression & left,Operator op,const Expression & right,const Type & resultType)626 static std::unique_ptr<Expression> simplify_matrix_division(const Context& context,
627                                                             Position pos,
628                                                             const Expression& left,
629                                                             Operator op,
630                                                             const Expression& right,
631                                                             const Type& resultType) {
632     // Convert matrix-over-scalar `x /= y` into `x *= (1.0 / y)`. This generates better
633     // code in SPIR-V and Metal, and should be roughly equivalent elsewhere.
634     switch (op.kind()) {
635         case OperatorKind::SLASH:
636         case OperatorKind::SLASHEQ:
637             if (left.type().isMatrix() && right.type().isScalar()) {
638                 Operator multiplyOp = op.isAssignment() ? OperatorKind::STAREQ
639                                                         : OperatorKind::STAR;
640                 return BinaryExpression::Make(context, pos,
641                                               left.clone(),
642                                               multiplyOp,
643                                               one_over_scalar(context, right));
644             }
645             break;
646 
647         default:
648             break;
649     }
650 
651     return nullptr;
652 }
653 
fold_expression(Position pos,double result,const Type * resultType)654 static std::unique_ptr<Expression> fold_expression(Position pos,
655                                                    double result,
656                                                    const Type* resultType) {
657     if (resultType->isNumber()) {
658         if (result >= resultType->minimumValue() && result <= resultType->maximumValue()) {
659             // This result will fit inside its type.
660         } else {
661             // The value is outside the range or is NaN (all if-checks fail); do not optimize.
662             return nullptr;
663         }
664     }
665 
666     return Literal::Make(pos, result, resultType);
667 }
668 
Simplify(const Context & context,Position pos,const Expression & leftExpr,Operator op,const Expression & rightExpr,const Type & resultType)669 std::unique_ptr<Expression> ConstantFolder::Simplify(const Context& context,
670                                                      Position pos,
671                                                      const Expression& leftExpr,
672                                                      Operator op,
673                                                      const Expression& rightExpr,
674                                                      const Type& resultType) {
675     // Replace constant variables with their literal values.
676     const Expression* left = GetConstantValueForVariable(leftExpr);
677     const Expression* right = GetConstantValueForVariable(rightExpr);
678 
679     // If this is the assignment operator, and both sides are the same trivial expression, this is
680     // self-assignment (i.e., `var = var`) and can be reduced to just a variable reference (`var`).
681     // This can happen when other parts of the assignment are optimized away.
682     if (op.kind() == Operator::Kind::EQ && Analysis::IsSameExpressionTree(*left, *right)) {
683         return right->clone(pos);
684     }
685 
686     // Simplify the expression when both sides are constant Boolean literals.
687     if (left->isBoolLiteral() && right->isBoolLiteral()) {
688         bool leftVal  = left->as<Literal>().boolValue();
689         bool rightVal = right->as<Literal>().boolValue();
690         bool result;
691         switch (op.kind()) {
692             case Operator::Kind::LOGICALAND: result = leftVal && rightVal; break;
693             case Operator::Kind::LOGICALOR:  result = leftVal || rightVal; break;
694             case Operator::Kind::LOGICALXOR: result = leftVal ^  rightVal; break;
695             case Operator::Kind::EQEQ:       result = leftVal == rightVal; break;
696             case Operator::Kind::NEQ:        result = leftVal != rightVal; break;
697             default: return nullptr;
698         }
699         return Literal::MakeBool(context, pos, result);
700     }
701 
702     // If the left side is a Boolean literal, apply short-circuit optimizations.
703     if (left->isBoolLiteral()) {
704         return short_circuit_boolean(pos, *left, op, *right);
705     }
706 
707     // If the right side is a Boolean literal...
708     if (right->isBoolLiteral()) {
709         // ... and the left side has no side effects...
710         if (!Analysis::HasSideEffects(*left)) {
711             // We can reverse the expressions and short-circuit optimizations are still valid.
712             return short_circuit_boolean(pos, *right, op, *left);
713         }
714 
715         // We can't use short-circuiting, but we can still optimize away no-op Boolean expressions.
716         return eliminate_no_op_boolean(pos, *left, op, *right);
717     }
718 
719     if (op.kind() == Operator::Kind::EQEQ && Analysis::IsSameExpressionTree(*left, *right)) {
720         // With == comparison, if both sides are the same trivial expression, this is self-
721         // comparison and is always true. (We are not concerned with NaN.)
722         return Literal::MakeBool(context, pos, /*value=*/true);
723     }
724 
725     if (op.kind() == Operator::Kind::NEQ && Analysis::IsSameExpressionTree(*left, *right)) {
726         // With != comparison, if both sides are the same trivial expression, this is self-
727         // comparison and is always false. (We are not concerned with NaN.)
728         return Literal::MakeBool(context, pos, /*value=*/false);
729     }
730 
731     if (error_on_divide_by_zero(context, pos, op, *right)) {
732         return nullptr;
733     }
734 
735     // Perform full constant folding when both sides are compile-time constants.
736     const Type& leftType = left->type();
737     const Type& rightType = right->type();
738     bool leftSideIsConstant = Analysis::IsCompileTimeConstant(*left);
739     bool rightSideIsConstant = Analysis::IsCompileTimeConstant(*right);
740 
741     if (leftSideIsConstant && rightSideIsConstant) {
742         // Handle pairs of integer literals.
743         if (left->isIntLiteral() && right->isIntLiteral()) {
744             using SKSL_UINT = uint64_t;
745             SKSL_INT leftVal  = left->as<Literal>().intValue();
746             SKSL_INT rightVal = right->as<Literal>().intValue();
747 
748             // Note that fold_expression returns null if the result would overflow its type.
749             #define RESULT(Op)   fold_expression(pos, (SKSL_INT)(leftVal) Op \
750                                                       (SKSL_INT)(rightVal), &resultType)
751             #define URESULT(Op)  fold_expression(pos, (SKSL_INT)((SKSL_UINT)(leftVal) Op \
752                                                       (SKSL_UINT)(rightVal)), &resultType)
753             switch (op.kind()) {
754                 case Operator::Kind::PLUS:       return URESULT(+);
755                 case Operator::Kind::MINUS:      return URESULT(-);
756                 case Operator::Kind::STAR:       return URESULT(*);
757                 case Operator::Kind::SLASH:
758                     if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
759                         context.fErrors->error(pos, "arithmetic overflow");
760                         return nullptr;
761                     }
762                     return RESULT(/);
763                 case Operator::Kind::PERCENT:
764                     if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
765                         context.fErrors->error(pos, "arithmetic overflow");
766                         return nullptr;
767                     }
768                     return RESULT(%);
769                 case Operator::Kind::BITWISEAND: return RESULT(&);
770                 case Operator::Kind::BITWISEOR:  return RESULT(|);
771                 case Operator::Kind::BITWISEXOR: return RESULT(^);
772                 case Operator::Kind::EQEQ:       return RESULT(==);
773                 case Operator::Kind::NEQ:        return RESULT(!=);
774                 case Operator::Kind::GT:         return RESULT(>);
775                 case Operator::Kind::GTEQ:       return RESULT(>=);
776                 case Operator::Kind::LT:         return RESULT(<);
777                 case Operator::Kind::LTEQ:       return RESULT(<=);
778                 case Operator::Kind::SHL:
779                     if (rightVal >= 0 && rightVal <= 31) {
780                         // Left-shifting a negative (or really, any signed) value is undefined
781                         // behavior in C++, but not in GLSL. Do the shift on unsigned values to avoid
782                         // triggering an UBSAN error.
783                         return URESULT(<<);
784                     }
785                     context.fErrors->error(pos, "shift value out of range");
786                     return nullptr;
787                 case Operator::Kind::SHR:
788                     if (rightVal >= 0 && rightVal <= 31) {
789                         return RESULT(>>);
790                     }
791                     context.fErrors->error(pos, "shift value out of range");
792                     return nullptr;
793 
794                 default:
795                     return nullptr;
796             }
797             #undef RESULT
798             #undef URESULT
799         }
800 
801         // Handle pairs of floating-point literals.
802         if (left->isFloatLiteral() && right->isFloatLiteral()) {
803             SKSL_FLOAT leftVal  = left->as<Literal>().floatValue();
804             SKSL_FLOAT rightVal = right->as<Literal>().floatValue();
805 
806             #define RESULT(Op) fold_expression(pos, leftVal Op rightVal, &resultType)
807             switch (op.kind()) {
808                 case Operator::Kind::PLUS:  return RESULT(+);
809                 case Operator::Kind::MINUS: return RESULT(-);
810                 case Operator::Kind::STAR:  return RESULT(*);
811                 case Operator::Kind::SLASH: return RESULT(/);
812                 case Operator::Kind::EQEQ:  return RESULT(==);
813                 case Operator::Kind::NEQ:   return RESULT(!=);
814                 case Operator::Kind::GT:    return RESULT(>);
815                 case Operator::Kind::GTEQ:  return RESULT(>=);
816                 case Operator::Kind::LT:    return RESULT(<);
817                 case Operator::Kind::LTEQ:  return RESULT(<=);
818                 default:                    return nullptr;
819             }
820             #undef RESULT
821         }
822 
823         // Perform matrix multiplication.
824         if (op.kind() == Operator::Kind::STAR) {
825             if (leftType.isMatrix() && rightType.isMatrix()) {
826                 return simplify_matrix_times_matrix(context, pos, *left, *right);
827             }
828             if (leftType.isVector() && rightType.isMatrix()) {
829                 return simplify_vector_times_matrix(context, pos, *left, *right);
830             }
831             if (leftType.isMatrix() && rightType.isVector()) {
832                 return simplify_matrix_times_vector(context, pos, *left, *right);
833             }
834         }
835 
836         // Perform constant folding on pairs of vectors/matrices.
837         if (is_vec_or_mat(leftType) && leftType.matches(rightType)) {
838             return simplify_componentwise(context, pos, *left, op, *right);
839         }
840 
841         // Perform constant folding on vectors/matrices against scalars, e.g.: half4(2) + 2
842         if (rightType.isScalar() && is_vec_or_mat(leftType) &&
843             leftType.componentType().matches(rightType)) {
844             return simplify_componentwise(context, pos,
845                                           *left, op, *splat_scalar(context, *right, left->type()));
846         }
847 
848         // Perform constant folding on scalars against vectors/matrices, e.g.: 2 + half4(2)
849         if (leftType.isScalar() && is_vec_or_mat(rightType) &&
850             rightType.componentType().matches(leftType)) {
851             return simplify_componentwise(context, pos,
852                                           *splat_scalar(context, *left, right->type()), op, *right);
853         }
854 
855         // Perform constant folding on pairs of matrices, arrays or structs.
856         if ((leftType.isMatrix() && rightType.isMatrix()) ||
857             (leftType.isArray() && rightType.isArray()) ||
858             (leftType.isStruct() && rightType.isStruct())) {
859             return simplify_constant_equality(context, pos, *left, op, *right);
860         }
861     }
862 
863     if (context.fConfig->fSettings.fOptimize) {
864         // If just one side is constant, we might still be able to simplify arithmetic expressions
865         // like `x * 1`, `x *= 1`, `x + 0`, `x * 0`, `0 / x`, etc.
866         if (leftSideIsConstant || rightSideIsConstant) {
867             if (std::unique_ptr<Expression> expr = simplify_arithmetic(context, pos, *left, op,
868                                                                        *right, resultType)) {
869                 return expr;
870             }
871         }
872 
873         // We can simplify some forms of matrix division even when neither side is constant.
874         if (std::unique_ptr<Expression> expr = simplify_matrix_division(context, pos, *left, op,
875                                                                         *right, resultType)) {
876             return expr;
877         }
878     }
879 
880     // We aren't able to constant-fold.
881     return nullptr;
882 }
883 
884 }  // namespace SkSL
885