• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/SkSLConstantFolder.h"
9 
10 #include "include/core/SkTypes.h"
11 #include "include/private/base/SkFloatingPoint.h"
12 #include "include/private/base/SkTArray.h"
13 #include "src/sksl/SkSLAnalysis.h"
14 #include "src/sksl/SkSLContext.h"
15 #include "src/sksl/SkSLErrorReporter.h"
16 #include "src/sksl/SkSLPosition.h"
17 #include "src/sksl/SkSLProgramSettings.h"
18 #include "src/sksl/ir/SkSLBinaryExpression.h"
19 #include "src/sksl/ir/SkSLConstructorCompound.h"
20 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
21 #include "src/sksl/ir/SkSLConstructorSplat.h"
22 #include "src/sksl/ir/SkSLExpression.h"
23 #include "src/sksl/ir/SkSLLiteral.h"
24 #include "src/sksl/ir/SkSLModifierFlags.h"
25 #include "src/sksl/ir/SkSLPrefixExpression.h"
26 #include "src/sksl/ir/SkSLType.h"
27 #include "src/sksl/ir/SkSLVariable.h"
28 #include "src/sksl/ir/SkSLVariableReference.h"
29 
30 #include <cstdint>
31 #include <float.h>
32 #include <limits>
33 #include <optional>
34 #include <string>
35 #include <utility>
36 
37 using namespace skia_private;
38 
39 namespace SkSL {
40 
is_vec_or_mat(const Type & type)41 static bool is_vec_or_mat(const Type& type) {
42     switch (type.typeKind()) {
43         case Type::TypeKind::kMatrix:
44         case Type::TypeKind::kVector:
45             return true;
46 
47         default:
48             return false;
49     }
50 }
51 
eliminate_no_op_boolean(Position pos,const Expression & left,Operator op,const Expression & right)52 static std::unique_ptr<Expression> eliminate_no_op_boolean(Position pos,
53                                                            const Expression& left,
54                                                            Operator op,
55                                                            const Expression& right) {
56     bool rightVal = right.as<Literal>().boolValue();
57 
58     // Detect no-op Boolean expressions and optimize them away.
59     if ((op.kind() == Operator::Kind::LOGICALAND && rightVal)  ||  // (expr && true)  -> (expr)
60         (op.kind() == Operator::Kind::LOGICALOR  && !rightVal) ||  // (expr || false) -> (expr)
61         (op.kind() == Operator::Kind::LOGICALXOR && !rightVal) ||  // (expr ^^ false) -> (expr)
62         (op.kind() == Operator::Kind::EQEQ       && rightVal)  ||  // (expr == true)  -> (expr)
63         (op.kind() == Operator::Kind::NEQ        && !rightVal)) {  // (expr != false) -> (expr)
64 
65         return left.clone(pos);
66     }
67 
68     return nullptr;
69 }
70 
short_circuit_boolean(Position pos,const Expression & left,Operator op,const Expression & right)71 static std::unique_ptr<Expression> short_circuit_boolean(Position pos,
72                                                          const Expression& left,
73                                                          Operator op,
74                                                          const Expression& right) {
75     bool leftVal = left.as<Literal>().boolValue();
76 
77     // When the literal is on the left, we can sometimes eliminate the other expression entirely.
78     if ((op.kind() == Operator::Kind::LOGICALAND && !leftVal) ||  // (false && expr) -> (false)
79         (op.kind() == Operator::Kind::LOGICALOR  && leftVal)) {   // (true  || expr) -> (true)
80 
81         return left.clone(pos);
82     }
83 
84     // We can't eliminate the right-side expression via short-circuit, but we might still be able to
85     // simplify away a no-op expression.
86     return eliminate_no_op_boolean(pos, right, op, left);
87 }
88 
simplify_constant_equality(const Context & context,Position pos,const Expression & left,Operator op,const Expression & right)89 static std::unique_ptr<Expression> simplify_constant_equality(const Context& context,
90                                                               Position pos,
91                                                               const Expression& left,
92                                                               Operator op,
93                                                               const Expression& right) {
94     if (op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ) {
95         bool equality = (op.kind() == Operator::Kind::EQEQ);
96 
97         switch (left.compareConstant(right)) {
98             case Expression::ComparisonResult::kNotEqual:
99                 equality = !equality;
100                 [[fallthrough]];
101 
102             case Expression::ComparisonResult::kEqual:
103                 return Literal::MakeBool(context, pos, equality);
104 
105             case Expression::ComparisonResult::kUnknown:
106                 break;
107         }
108     }
109     return nullptr;
110 }
111 
simplify_matrix_multiplication(const Context & context,Position pos,const Expression & left,const Expression & right,int leftColumns,int leftRows,int rightColumns,int rightRows)112 static std::unique_ptr<Expression> simplify_matrix_multiplication(const Context& context,
113                                                                   Position pos,
114                                                                   const Expression& left,
115                                                                   const Expression& right,
116                                                                   int leftColumns,
117                                                                   int leftRows,
118                                                                   int rightColumns,
119                                                                   int rightRows) {
120     const Type& componentType = left.type().componentType();
121     SkASSERT(componentType.matches(right.type().componentType()));
122 
123     // Fetch the left matrix.
124     double leftVals[4][4];
125     for (int c = 0; c < leftColumns; ++c) {
126         for (int r = 0; r < leftRows; ++r) {
127             leftVals[c][r] = *left.getConstantValue((c * leftRows) + r);
128         }
129     }
130     // Fetch the right matrix.
131     double rightVals[4][4];
132     for (int c = 0; c < rightColumns; ++c) {
133         for (int r = 0; r < rightRows; ++r) {
134             rightVals[c][r] = *right.getConstantValue((c * rightRows) + r);
135         }
136     }
137 
138     SkASSERT(leftColumns == rightRows);
139     int outColumns   = rightColumns,
140         outRows      = leftRows;
141 
142     double args[16];
143     int argIndex = 0;
144     for (int c = 0; c < outColumns; ++c) {
145         for (int r = 0; r < outRows; ++r) {
146             // Compute a dot product for this position.
147             double val = 0;
148             for (int dotIdx = 0; dotIdx < leftColumns; ++dotIdx) {
149                 val += leftVals[dotIdx][r] * rightVals[c][dotIdx];
150             }
151 
152             if (val >= -FLT_MAX && val <= FLT_MAX) {
153                 args[argIndex++] = val;
154             } else {
155                 // The value is outside the 32-bit float range, or is NaN; do not optimize.
156                 return nullptr;
157             }
158         }
159     }
160 
161     if (outColumns == 1) {
162         // Matrix-times-vector conceptually makes a 1-column N-row matrix, but we return vecN.
163         std::swap(outColumns, outRows);
164     }
165 
166     const Type& resultType = componentType.toCompound(context, outColumns, outRows);
167     return ConstructorCompound::MakeFromConstants(context, pos, resultType, args);
168 }
169 
simplify_matrix_times_matrix(const Context & context,Position pos,const Expression & left,const Expression & right)170 static std::unique_ptr<Expression> simplify_matrix_times_matrix(const Context& context,
171                                                                 Position pos,
172                                                                 const Expression& left,
173                                                                 const Expression& right) {
174     const Type& leftType = left.type();
175     const Type& rightType = right.type();
176 
177     SkASSERT(leftType.isMatrix());
178     SkASSERT(rightType.isMatrix());
179 
180     return simplify_matrix_multiplication(context, pos, left, right,
181                                           leftType.columns(), leftType.rows(),
182                                           rightType.columns(), rightType.rows());
183 }
184 
simplify_vector_times_matrix(const Context & context,Position pos,const Expression & left,const Expression & right)185 static std::unique_ptr<Expression> simplify_vector_times_matrix(const Context& context,
186                                                                 Position pos,
187                                                                 const Expression& left,
188                                                                 const Expression& right) {
189     const Type& leftType = left.type();
190     const Type& rightType = right.type();
191 
192     SkASSERT(leftType.isVector());
193     SkASSERT(rightType.isMatrix());
194 
195     return simplify_matrix_multiplication(context, pos, left, right,
196                                           /*leftColumns=*/leftType.columns(), /*leftRows=*/1,
197                                           rightType.columns(), rightType.rows());
198 }
199 
simplify_matrix_times_vector(const Context & context,Position pos,const Expression & left,const Expression & right)200 static std::unique_ptr<Expression> simplify_matrix_times_vector(const Context& context,
201                                                                 Position pos,
202                                                                 const Expression& left,
203                                                                 const Expression& right) {
204     const Type& leftType = left.type();
205     const Type& rightType = right.type();
206 
207     SkASSERT(leftType.isMatrix());
208     SkASSERT(rightType.isVector());
209 
210     return simplify_matrix_multiplication(context, pos, left, right,
211                                           leftType.columns(), leftType.rows(),
212                                           /*rightColumns=*/1, /*rightRows=*/rightType.columns());
213 }
214 
simplify_componentwise(const Context & context,Position pos,const Expression & left,Operator op,const Expression & right)215 static std::unique_ptr<Expression> simplify_componentwise(const Context& context,
216                                                           Position pos,
217                                                           const Expression& left,
218                                                           Operator op,
219                                                           const Expression& right) {
220     SkASSERT(is_vec_or_mat(left.type()));
221     SkASSERT(left.type().matches(right.type()));
222     const Type& type = left.type();
223 
224     // Handle equality operations: == !=
225     if (std::unique_ptr<Expression> result = simplify_constant_equality(context, pos, left, op,
226             right)) {
227         return result;
228     }
229 
230     // Handle floating-point arithmetic: + - * /
231     using FoldFn = double (*)(double, double);
232     FoldFn foldFn;
233     switch (op.kind()) {
234         case Operator::Kind::PLUS:  foldFn = +[](double a, double b) { return a + b; }; break;
235         case Operator::Kind::MINUS: foldFn = +[](double a, double b) { return a - b; }; break;
236         case Operator::Kind::STAR:  foldFn = +[](double a, double b) { return a * b; }; break;
237         case Operator::Kind::SLASH: foldFn = +[](double a, double b) { return a / b; }; break;
238         default:
239             return nullptr;
240     }
241 
242     const Type& componentType = type.componentType();
243     SkASSERT(componentType.isNumber());
244 
245     double minimumValue = componentType.minimumValue();
246     double maximumValue = componentType.maximumValue();
247 
248     double args[16];
249     int numSlots = type.slotCount();
250     for (int i = 0; i < numSlots; i++) {
251         double value = foldFn(*left.getConstantValue(i), *right.getConstantValue(i));
252         if (value < minimumValue || value > maximumValue) {
253             return nullptr;
254         }
255         args[i] = value;
256     }
257     return ConstructorCompound::MakeFromConstants(context, pos, type, args);
258 }
259 
splat_scalar(const Context & context,const Expression & scalar,const Type & type)260 static std::unique_ptr<Expression> splat_scalar(const Context& context,
261                                                 const Expression& scalar,
262                                                 const Type& type) {
263     if (type.isVector()) {
264         return ConstructorSplat::Make(context, scalar.fPosition, type, scalar.clone());
265     }
266     if (type.isMatrix()) {
267         int numSlots = type.slotCount();
268         ExpressionArray splatMatrix;
269         splatMatrix.reserve_exact(numSlots);
270         for (int index = 0; index < numSlots; ++index) {
271             splatMatrix.push_back(scalar.clone());
272         }
273         return ConstructorCompound::Make(context, scalar.fPosition, type, std::move(splatMatrix));
274     }
275     SkDEBUGFAILF("unsupported type %s", type.description().c_str());
276     return nullptr;
277 }
278 
cast_expression(const Context & context,Position pos,const Expression & expr,const Type & type)279 static std::unique_ptr<Expression> cast_expression(const Context& context,
280                                                    Position pos,
281                                                    const Expression& expr,
282                                                    const Type& type) {
283     SkASSERT(type.componentType().matches(expr.type().componentType()));
284     if (expr.type().isScalar()) {
285         if (type.isMatrix()) {
286             return ConstructorDiagonalMatrix::Make(context, pos, type, expr.clone());
287         }
288         if (type.isVector()) {
289             return ConstructorSplat::Make(context, pos, type, expr.clone());
290         }
291     }
292     if (type.matches(expr.type())) {
293         return expr.clone(pos);
294     }
295     // We can't cast matrices into vectors or vice-versa.
296     return nullptr;
297 }
298 
zero_expression(const Context & context,Position pos,const Type & type)299 static std::unique_ptr<Expression> zero_expression(const Context& context,
300                                                    Position pos,
301                                                    const Type& type) {
302     std::unique_ptr<Expression> zero = Literal::Make(pos, 0.0, &type.componentType());
303     if (type.isScalar()) {
304         return zero;
305     }
306     if (type.isVector()) {
307         return ConstructorSplat::Make(context, pos, type, std::move(zero));
308     }
309     if (type.isMatrix()) {
310         return ConstructorDiagonalMatrix::Make(context, pos, type, std::move(zero));
311     }
312     SkDEBUGFAILF("unsupported type %s", type.description().c_str());
313     return nullptr;
314 }
315 
negate_expression(const Context & context,Position pos,const Expression & expr,const Type & type)316 static std::unique_ptr<Expression> negate_expression(const Context& context,
317                                                      Position pos,
318                                                      const Expression& expr,
319                                                      const Type& type) {
320     std::unique_ptr<Expression> ctor = cast_expression(context, pos, expr, type);
321     return ctor ? PrefixExpression::Make(context, pos, Operator::Kind::MINUS, std::move(ctor))
322                 : nullptr;
323 }
324 
GetConstantInt(const Expression & value,SKSL_INT * out)325 bool ConstantFolder::GetConstantInt(const Expression& value, SKSL_INT* out) {
326     const Expression* expr = GetConstantValueForVariable(value);
327     if (!expr->isIntLiteral()) {
328         return false;
329     }
330     *out = expr->as<Literal>().intValue();
331     return true;
332 }
333 
GetConstantValue(const Expression & value,double * out)334 bool ConstantFolder::GetConstantValue(const Expression& value, double* out) {
335     const Expression* expr = GetConstantValueForVariable(value);
336     if (!expr->is<Literal>()) {
337         return false;
338     }
339     *out = expr->as<Literal>().value();
340     return true;
341 }
342 
contains_constant_zero(const Expression & expr)343 static bool contains_constant_zero(const Expression& expr) {
344     int numSlots = expr.type().slotCount();
345     for (int index = 0; index < numSlots; ++index) {
346         std::optional<double> slotVal = expr.getConstantValue(index);
347         if (slotVal.has_value() && *slotVal == 0.0) {
348             return true;
349         }
350     }
351     return false;
352 }
353 
IsConstantSplat(const Expression & expr,double value)354 bool ConstantFolder::IsConstantSplat(const Expression& expr, double value) {
355     int numSlots = expr.type().slotCount();
356     for (int index = 0; index < numSlots; ++index) {
357         std::optional<double> slotVal = expr.getConstantValue(index);
358         if (!slotVal.has_value() || *slotVal != value) {
359             return false;
360         }
361     }
362     return true;
363 }
364 
365 // Returns true if the expression is a square diagonal matrix containing `value`.
is_constant_diagonal(const Expression & expr,double value)366 static bool is_constant_diagonal(const Expression& expr, double value) {
367     SkASSERT(expr.type().isMatrix());
368     int columns = expr.type().columns();
369     int rows = expr.type().rows();
370     if (columns != rows) {
371         return false;
372     }
373     int slotIdx = 0;
374     for (int c = 0; c < columns; ++c) {
375         for (int r = 0; r < rows; ++r) {
376             double expectation = (c == r) ? value : 0;
377             std::optional<double> slotVal = expr.getConstantValue(slotIdx++);
378             if (!slotVal.has_value() || *slotVal != expectation) {
379                 return false;
380             }
381         }
382     }
383     return true;
384 }
385 
386 // Returns true if the expression is a scalar, vector, or diagonal matrix containing `value`.
is_constant_value(const Expression & expr,double value)387 static bool is_constant_value(const Expression& expr, double value) {
388     return expr.type().isMatrix() ? is_constant_diagonal(expr, value)
389                                   : ConstantFolder::IsConstantSplat(expr, value);
390 }
391 
392 // The expression represents the right-hand side of a division op. If the division can be
393 // strength-reduced into multiplication by a reciprocal, returns that reciprocal as an expression.
394 // Note that this only supports literal values with safe-to-use reciprocals, and returns null if
395 // Expression contains anything else.
make_reciprocal_expression(const Context & context,const Expression & right)396 static std::unique_ptr<Expression> make_reciprocal_expression(const Context& context,
397                                                               const Expression& right) {
398     if (right.type().isMatrix() || !right.type().componentType().isFloat()) {
399         return nullptr;
400     }
401     // Verify that each slot contains a finite, non-zero literal, take its reciprocal.
402     double values[4];
403     int nslots = right.type().slotCount();
404     for (int index = 0; index < nslots; ++index) {
405         std::optional<double> value = right.getConstantValue(index);
406         if (!value) {
407             return nullptr;
408         }
409         *value = sk_ieee_double_divide(1.0, *value);
410         if (*value >= -FLT_MAX && *value <= FLT_MAX && *value != 0.0) {
411             // The reciprocal can be represented safely as a finite 32-bit float.
412             values[index] = *value;
413         } else {
414             // The value is outside the 32-bit float range, or is NaN; do not optimize.
415             return nullptr;
416         }
417     }
418     // Turn the expression array into a compound constructor. (If this is a single-slot expression,
419     // this will return the literal as-is.)
420     return ConstructorCompound::MakeFromConstants(context, right.fPosition, right.type(), values);
421 }
422 
error_on_divide_by_zero(const Context & context,Position pos,Operator op,const Expression & right)423 static bool error_on_divide_by_zero(const Context& context, Position pos, Operator op,
424                                     const Expression& right) {
425     switch (op.kind()) {
426         case Operator::Kind::SLASH:
427         case Operator::Kind::SLASHEQ:
428         case Operator::Kind::PERCENT:
429         case Operator::Kind::PERCENTEQ:
430             if (contains_constant_zero(right)) {
431                 context.fErrors->error(pos, "division by zero");
432                 return true;
433             }
434             return false;
435         default:
436             return false;
437     }
438 }
439 
// Follows a chain of read-only references to `const` variables down to their initializer.
// Returns the final expression if Analysis::IsCompileTimeConstant accepts it, otherwise null.
// A non-variable `inExpr` is tested directly.
const Expression* ConstantFolder::GetConstantValueOrNull(const Expression& inExpr) {
    const Expression* current = &inExpr;
    while (current->is<VariableReference>()) {
        const VariableReference& ref = current->as<VariableReference>();
        // Only plain reads can be replaced by the variable's value.
        if (ref.refKind() != VariableRefKind::kRead) {
            return nullptr;
        }
        const Variable& var = *ref.variable();
        if (!var.modifierFlags().isConst()) {
            return nullptr;
        }
#ifdef SKSL_EXT
        // NOTE(review): layout(constant_id) variables are excluded; presumably their value can be
        // specialized after compilation — confirm.
        if (var.layout().fFlags & SkSL::LayoutFlag::kConstantId) {
            return nullptr;
        }
#endif
        current = var.initialValue();
        if (current == nullptr) {
            // Generally, const variables must have initial values. However, function parameters
            // are an exception; they can be const but won't have an initial value.
            return nullptr;
        }
    }
    return Analysis::IsCompileTimeConstant(*current) ? current : nullptr;
}
465 
// Returns the compile-time constant found by GetConstantValueOrNull, or `inExpr` itself when
// there is none. Never returns null.
const Expression* ConstantFolder::GetConstantValueForVariable(const Expression& inExpr) {
    if (const Expression* constant = GetConstantValueOrNull(inExpr)) {
        return constant;
    }
    return &inExpr;
}
470 
MakeConstantValueForVariable(Position pos,std::unique_ptr<Expression> inExpr)471 std::unique_ptr<Expression> ConstantFolder::MakeConstantValueForVariable(
472         Position pos, std::unique_ptr<Expression> inExpr) {
473     const Expression* expr = GetConstantValueOrNull(*inExpr);
474     return expr ? expr->clone(pos) : std::move(inExpr);
475 }
476 
is_scalar_op_matrix(const Expression & left,const Expression & right)477 static bool is_scalar_op_matrix(const Expression& left, const Expression& right) {
478     return left.type().isScalar() && right.type().isMatrix();
479 }
480 
is_matrix_op_scalar(const Expression & left,const Expression & right)481 static bool is_matrix_op_scalar(const Expression& left, const Expression& right) {
482     return is_scalar_op_matrix(right, left);
483 }
484 
simplify_arithmetic(const Context & context,Position pos,const Expression & left,Operator op,const Expression & right,const Type & resultType)485 static std::unique_ptr<Expression> simplify_arithmetic(const Context& context,
486                                                        Position pos,
487                                                        const Expression& left,
488                                                        Operator op,
489                                                        const Expression& right,
490                                                        const Type& resultType) {
491     switch (op.kind()) {
492         case Operator::Kind::PLUS:
493             if (!is_scalar_op_matrix(left, right) &&
494                 ConstantFolder::IsConstantSplat(right, 0.0)) {  // x + 0
495                 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
496                                                                        resultType)) {
497                     return expr;
498                 }
499             }
500             if (!is_matrix_op_scalar(left, right) &&
501                 ConstantFolder::IsConstantSplat(left, 0.0)) {  // 0 + x
502                 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, right,
503                                                                        resultType)) {
504                     return expr;
505                 }
506             }
507             break;
508 
509         case Operator::Kind::STAR:
510             if (is_constant_value(right, 1.0)) {  // x * 1
511                 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
512                                                                        resultType)) {
513                     return expr;
514                 }
515             }
516             if (is_constant_value(left, 1.0)) {   // 1 * x
517                 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, right,
518                                                                        resultType)) {
519                     return expr;
520                 }
521             }
522             if (is_constant_value(right, 0.0) && !Analysis::HasSideEffects(left)) {  // x * 0
523                 return zero_expression(context, pos, resultType);
524             }
525             if (is_constant_value(left, 0.0) && !Analysis::HasSideEffects(right)) {  // 0 * x
526                 return zero_expression(context, pos, resultType);
527             }
528             if (is_constant_value(right, -1.0)) {  // x * -1 (to `-x`)
529                 if (std::unique_ptr<Expression> expr = negate_expression(context, pos, left,
530                                                                          resultType)) {
531                     return expr;
532                 }
533             }
534             if (is_constant_value(left, -1.0)) {  // -1 * x (to `-x`)
535                 if (std::unique_ptr<Expression> expr = negate_expression(context, pos, right,
536                                                                          resultType)) {
537                     return expr;
538                 }
539             }
540             break;
541 
542         case Operator::Kind::MINUS:
543             if (!is_scalar_op_matrix(left, right) &&
544                 ConstantFolder::IsConstantSplat(right, 0.0)) {  // x - 0
545                 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
546                                                                        resultType)) {
547                     return expr;
548                 }
549             }
550             if (!is_matrix_op_scalar(left, right) &&
551                 ConstantFolder::IsConstantSplat(left, 0.0)) {  // 0 - x
552                 if (std::unique_ptr<Expression> expr = negate_expression(context, pos, right,
553                                                                          resultType)) {
554                     return expr;
555                 }
556             }
557             break;
558 
559         case Operator::Kind::SLASH:
560             if (!is_scalar_op_matrix(left, right) &&
561                 ConstantFolder::IsConstantSplat(right, 1.0)) {  // x / 1
562                 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
563                                                                        resultType)) {
564                     return expr;
565                 }
566             }
567             if (!left.type().isMatrix()) {  // convert `x / 2` into `x * 0.5`
568                 if (std::unique_ptr<Expression> expr = make_reciprocal_expression(context, right)) {
569                     return BinaryExpression::Make(context, pos, left.clone(), Operator::Kind::STAR,
570                                                   std::move(expr));
571                 }
572             }
573             break;
574 
575         case Operator::Kind::PLUSEQ:
576         case Operator::Kind::MINUSEQ:
577             if (ConstantFolder::IsConstantSplat(right, 0.0)) {  // x += 0, x -= 0
578                 if (std::unique_ptr<Expression> var = cast_expression(context, pos, left,
579                                                                       resultType)) {
580                     Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
581                     return var;
582                 }
583             }
584             break;
585 
586         case Operator::Kind::STAREQ:
587             if (is_constant_value(right, 1.0)) {  // x *= 1
588                 if (std::unique_ptr<Expression> var = cast_expression(context, pos, left,
589                                                                       resultType)) {
590                     Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
591                     return var;
592                 }
593             }
594             break;
595 
596         case Operator::Kind::SLASHEQ:
597             if (ConstantFolder::IsConstantSplat(right, 1.0)) {  // x /= 1
598                 if (std::unique_ptr<Expression> var = cast_expression(context, pos, left,
599                                                                       resultType)) {
600                     Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
601                     return var;
602                 }
603             }
604             if (std::unique_ptr<Expression> expr = make_reciprocal_expression(context, right)) {
605                 return BinaryExpression::Make(context, pos, left.clone(), Operator::Kind::STAREQ,
606                                               std::move(expr));
607             }
608             break;
609 
610         default:
611             break;
612     }
613 
614     return nullptr;
615 }
616 
617 // The expression must be scalar, and represents the right-hand side of a division op. It can
618 // contain anything, not just literal values. This returns the binary expression `1.0 / expr`. The
619 // expression might be further simplified by the constant folding, if possible.
one_over_scalar(const Context & context,const Expression & right)620 static std::unique_ptr<Expression> one_over_scalar(const Context& context,
621                                                    const Expression& right) {
622     SkASSERT(right.type().isScalar());
623     Position pos = right.fPosition;
624     return BinaryExpression::Make(context, pos,
625                                   Literal::Make(pos, 1.0, &right.type()),
626                                   Operator::Kind::SLASH,
627                                   right.clone());
628 }
629 
simplify_matrix_division(const Context & context,Position pos,const Expression & left,Operator op,const Expression & right,const Type & resultType)630 static std::unique_ptr<Expression> simplify_matrix_division(const Context& context,
631                                                             Position pos,
632                                                             const Expression& left,
633                                                             Operator op,
634                                                             const Expression& right,
635                                                             const Type& resultType) {
636     // Convert matrix-over-scalar `x /= y` into `x *= (1.0 / y)`. This generates better
637     // code in SPIR-V and Metal, and should be roughly equivalent elsewhere.
638     switch (op.kind()) {
639         case OperatorKind::SLASH:
640         case OperatorKind::SLASHEQ:
641             if (left.type().isMatrix() && right.type().isScalar()) {
642                 Operator multiplyOp = op.isAssignment() ? OperatorKind::STAREQ
643                                                         : OperatorKind::STAR;
644                 return BinaryExpression::Make(context, pos,
645                                               left.clone(),
646                                               multiplyOp,
647                                               one_over_scalar(context, right));
648             }
649             break;
650 
651         default:
652             break;
653     }
654 
655     return nullptr;
656 }
657 
fold_expression(Position pos,double result,const Type * resultType)658 static std::unique_ptr<Expression> fold_expression(Position pos,
659                                                    double result,
660                                                    const Type* resultType) {
661     if (resultType->isNumber()) {
662         if (result >= resultType->minimumValue() && result <= resultType->maximumValue()) {
663             // This result will fit inside its type.
664         } else {
665             // The value is outside the range or is NaN (all if-checks fail); do not optimize.
666             return nullptr;
667         }
668     }
669 
670     return Literal::Make(pos, result, resultType);
671 }
672 
fold_two_constants(const Context & context,Position pos,const Expression * left,Operator op,const Expression * right,const Type & resultType)673 static std::unique_ptr<Expression> fold_two_constants(const Context& context,
674                                                       Position pos,
675                                                       const Expression* left,
676                                                       Operator op,
677                                                       const Expression* right,
678                                                       const Type& resultType) {
679     SkASSERT(Analysis::IsCompileTimeConstant(*left));
680     SkASSERT(Analysis::IsCompileTimeConstant(*right));
681     const Type& leftType = left->type();
682     const Type& rightType = right->type();
683 
684     // Handle pairs of integer literals.
685     if (left->isIntLiteral() && right->isIntLiteral()) {
686         using SKSL_UINT = uint64_t;
687         SKSL_INT leftVal  = left->as<Literal>().intValue();
688         SKSL_INT rightVal = right->as<Literal>().intValue();
689 
690         // Note that fold_expression returns null if the result would overflow its type.
691         #define RESULT(Op)   fold_expression(pos, (SKSL_INT)(leftVal) Op \
692                                                   (SKSL_INT)(rightVal), &resultType)
693         #define URESULT(Op)  fold_expression(pos, (SKSL_INT)((SKSL_UINT)(leftVal) Op \
694                                                   (SKSL_UINT)(rightVal)), &resultType)
695         switch (op.kind()) {
696             case Operator::Kind::PLUS:       return URESULT(+);
697             case Operator::Kind::MINUS:      return URESULT(-);
698             case Operator::Kind::STAR:       return URESULT(*);
699             case Operator::Kind::SLASH:
700                 if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
701                     context.fErrors->error(pos, "arithmetic overflow");
702                     return nullptr;
703                 }
704                 return RESULT(/);
705 
706             case Operator::Kind::PERCENT:
707                 if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
708                     context.fErrors->error(pos, "arithmetic overflow");
709                     return nullptr;
710                 }
711                 return RESULT(%);
712 
713             case Operator::Kind::BITWISEAND: return RESULT(&);
714             case Operator::Kind::BITWISEOR:  return RESULT(|);
715             case Operator::Kind::BITWISEXOR: return RESULT(^);
716             case Operator::Kind::EQEQ:       return RESULT(==);
717             case Operator::Kind::NEQ:        return RESULT(!=);
718             case Operator::Kind::GT:         return RESULT(>);
719             case Operator::Kind::GTEQ:       return RESULT(>=);
720             case Operator::Kind::LT:         return RESULT(<);
721             case Operator::Kind::LTEQ:       return RESULT(<=);
722             case Operator::Kind::SHL:
723                 if (rightVal >= 0 && rightVal <= 31) {
724                     // Left-shifting a negative (or really, any signed) value is undefined behavior
725                     // in C++, but not in GLSL. Do the shift on unsigned values to avoid triggering
726                     // an UBSAN error.
727                     return URESULT(<<);
728                 }
729                 context.fErrors->error(pos, "shift value out of range");
730                 return nullptr;
731 
732             case Operator::Kind::SHR:
733                 if (rightVal >= 0 && rightVal <= 31) {
734                     return RESULT(>>);
735                 }
736                 context.fErrors->error(pos, "shift value out of range");
737                 return nullptr;
738 
739             default:
740                 break;
741         }
742         #undef RESULT
743         #undef URESULT
744 
745         return nullptr;
746     }
747 
748     // Handle pairs of floating-point literals.
749     if (left->isFloatLiteral() && right->isFloatLiteral()) {
750         SKSL_FLOAT leftVal  = left->as<Literal>().floatValue();
751         SKSL_FLOAT rightVal = right->as<Literal>().floatValue();
752 
753         #define RESULT(Op) fold_expression(pos, leftVal Op rightVal, &resultType)
754         switch (op.kind()) {
755             case Operator::Kind::PLUS:  return RESULT(+);
756             case Operator::Kind::MINUS: return RESULT(-);
757             case Operator::Kind::STAR:  return RESULT(*);
758             case Operator::Kind::SLASH: return RESULT(/);
759             case Operator::Kind::EQEQ:  return RESULT(==);
760             case Operator::Kind::NEQ:   return RESULT(!=);
761             case Operator::Kind::GT:    return RESULT(>);
762             case Operator::Kind::GTEQ:  return RESULT(>=);
763             case Operator::Kind::LT:    return RESULT(<);
764             case Operator::Kind::LTEQ:  return RESULT(<=);
765             default:                    break;
766         }
767         #undef RESULT
768 
769         return nullptr;
770     }
771 
772     // Perform matrix multiplication.
773     if (op.kind() == Operator::Kind::STAR) {
774         if (leftType.isMatrix() && rightType.isMatrix()) {
775             return simplify_matrix_times_matrix(context, pos, *left, *right);
776         }
777         if (leftType.isVector() && rightType.isMatrix()) {
778             return simplify_vector_times_matrix(context, pos, *left, *right);
779         }
780         if (leftType.isMatrix() && rightType.isVector()) {
781             return simplify_matrix_times_vector(context, pos, *left, *right);
782         }
783     }
784 
785     // Perform constant folding on pairs of vectors/matrices.
786     if (is_vec_or_mat(leftType) && leftType.matches(rightType)) {
787         return simplify_componentwise(context, pos, *left, op, *right);
788     }
789 
790     // Perform constant folding on vectors/matrices against scalars, e.g.: half4(2) + 2
791     if (rightType.isScalar() && is_vec_or_mat(leftType) &&
792         leftType.componentType().matches(rightType)) {
793         return simplify_componentwise(context, pos,
794                                       *left, op, *splat_scalar(context, *right, left->type()));
795     }
796 
797     // Perform constant folding on scalars against vectors/matrices, e.g.: 2 + half4(2)
798     if (leftType.isScalar() && is_vec_or_mat(rightType) &&
799         rightType.componentType().matches(leftType)) {
800         return simplify_componentwise(context, pos,
801                                       *splat_scalar(context, *left, right->type()), op, *right);
802     }
803 
804     // Perform constant folding on pairs of matrices, arrays or structs.
805     if ((leftType.isMatrix() && rightType.isMatrix()) ||
806         (leftType.isArray() && rightType.isArray()) ||
807         (leftType.isStruct() && rightType.isStruct())) {
808         return simplify_constant_equality(context, pos, *left, op, *right);
809     }
810 
811     // We aren't able to constant-fold these expressions.
812     return nullptr;
813 }
814 
Simplify(const Context & context,Position pos,const Expression & leftExpr,Operator op,const Expression & rightExpr,const Type & resultType)815 std::unique_ptr<Expression> ConstantFolder::Simplify(const Context& context,
816                                                      Position pos,
817                                                      const Expression& leftExpr,
818                                                      Operator op,
819                                                      const Expression& rightExpr,
820                                                      const Type& resultType) {
821     // Replace constant variables with their literal values.
822     const Expression* left = GetConstantValueForVariable(leftExpr);
823     const Expression* right = GetConstantValueForVariable(rightExpr);
824 
825     // If this is the assignment operator, and both sides are the same trivial expression, this is
826     // self-assignment (i.e., `var = var`) and can be reduced to just a variable reference (`var`).
827     // This can happen when other parts of the assignment are optimized away.
828     if (op.kind() == Operator::Kind::EQ && Analysis::IsSameExpressionTree(*left, *right)) {
829         return right->clone(pos);
830     }
831 
832     // Simplify the expression when both sides are constant Boolean literals.
833     if (left->isBoolLiteral() && right->isBoolLiteral()) {
834         bool leftVal  = left->as<Literal>().boolValue();
835         bool rightVal = right->as<Literal>().boolValue();
836         bool result;
837         switch (op.kind()) {
838             case Operator::Kind::LOGICALAND: result = leftVal && rightVal; break;
839             case Operator::Kind::LOGICALOR:  result = leftVal || rightVal; break;
840             case Operator::Kind::LOGICALXOR: result = leftVal ^  rightVal; break;
841             case Operator::Kind::EQEQ:       result = leftVal == rightVal; break;
842             case Operator::Kind::NEQ:        result = leftVal != rightVal; break;
843             default: return nullptr;
844         }
845         return Literal::MakeBool(context, pos, result);
846     }
847 
848     // If the left side is a Boolean literal, apply short-circuit optimizations.
849     if (left->isBoolLiteral()) {
850         return short_circuit_boolean(pos, *left, op, *right);
851     }
852 
853     // If the right side is a Boolean literal...
854     if (right->isBoolLiteral()) {
855         // ... and the left side has no side effects...
856         if (!Analysis::HasSideEffects(*left)) {
857             // We can reverse the expressions and short-circuit optimizations are still valid.
858             return short_circuit_boolean(pos, *right, op, *left);
859         }
860 
861         // We can't use short-circuiting, but we can still optimize away no-op Boolean expressions.
862         return eliminate_no_op_boolean(pos, *left, op, *right);
863     }
864 
865     if (op.kind() == Operator::Kind::EQEQ && Analysis::IsSameExpressionTree(*left, *right)) {
866         // With == comparison, if both sides are the same trivial expression, this is self-
867         // comparison and is always true. (We are not concerned with NaN.)
868         return Literal::MakeBool(context, pos, /*value=*/true);
869     }
870 
871     if (op.kind() == Operator::Kind::NEQ && Analysis::IsSameExpressionTree(*left, *right)) {
872         // With != comparison, if both sides are the same trivial expression, this is self-
873         // comparison and is always false. (We are not concerned with NaN.)
874         return Literal::MakeBool(context, pos, /*value=*/false);
875     }
876 
877     if (error_on_divide_by_zero(context, pos, op, *right)) {
878         return nullptr;
879     }
880 
881     // Perform full constant folding when both sides are compile-time constants.
882     bool leftSideIsConstant = Analysis::IsCompileTimeConstant(*left);
883     bool rightSideIsConstant = Analysis::IsCompileTimeConstant(*right);
884     if (leftSideIsConstant && rightSideIsConstant) {
885         return fold_two_constants(context, pos, left, op, right, resultType);
886     }
887 
888     if (context.fConfig->fSettings.fOptimize) {
889         // If just one side is constant, we might still be able to simplify arithmetic expressions
890         // like `x * 1`, `x *= 1`, `x + 0`, `x * 0`, `0 / x`, etc.
891         if (leftSideIsConstant || rightSideIsConstant) {
892             if (std::unique_ptr<Expression> expr = simplify_arithmetic(context, pos, *left, op,
893                                                                        *right, resultType)) {
894                 return expr;
895             }
896         }
897 
898         // We can simplify some forms of matrix division even when neither side is constant.
899         if (std::unique_ptr<Expression> expr = simplify_matrix_division(context, pos, *left, op,
900                                                                         *right, resultType)) {
901             return expr;
902         }
903     }
904 
905     // We aren't able to constant-fold.
906     return nullptr;
907 }
908 
909 }  // namespace SkSL
910