/*
 * Copyright 2020 Google LLC
 *
 * Use of this source code is governed by a BSD-style license that can be
 * found in the LICENSE file.
 */

#include "src/sksl/SkSLConstantFolder.h"

#include "include/core/SkTypes.h"
#include "include/private/SkSLModifiers.h"
#include "include/private/base/SkFloatingPoint.h"
#include "include/private/base/SkTArray.h"
#include "include/sksl/SkSLErrorReporter.h"
#include "include/sksl/SkSLPosition.h"
#include "src/sksl/SkSLAnalysis.h"
#include "src/sksl/SkSLContext.h"
#include "src/sksl/SkSLProgramSettings.h"
#include "src/sksl/ir/SkSLBinaryExpression.h"
#include "src/sksl/ir/SkSLConstructorCompound.h"
#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
#include "src/sksl/ir/SkSLConstructorSplat.h"
#include "src/sksl/ir/SkSLExpression.h"
#include "src/sksl/ir/SkSLLiteral.h"
#include "src/sksl/ir/SkSLPrefixExpression.h"
#include "src/sksl/ir/SkSLType.h"
#include "src/sksl/ir/SkSLVariable.h"
#include "src/sksl/ir/SkSLVariableReference.h"

#include <cfloat>
#include <cstdint>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <utility>

namespace SkSL {

// Returns true if `type` is a vector type or a matrix type.
static bool is_vec_or_mat(const Type& type) {
    switch (type.typeKind()) {
        case Type::TypeKind::kMatrix:
        case Type::TypeKind::kVector:
            return true;

        default:
            return false;
    }
}

// `right` must be a Boolean literal. If `left <op> right` is a no-op that evaluates to `left`
// (e.g. `expr && true`), returns a clone of `left` at `pos`; otherwise returns null.
static std::unique_ptr<Expression> eliminate_no_op_boolean(Position pos,
                                                           const Expression& left,
                                                           Operator op,
                                                           const Expression& right) {
    bool rightVal = right.as<Literal>().boolValue();

    // Detect no-op Boolean expressions and optimize them away.
    if ((op.kind() == Operator::Kind::LOGICALAND && rightVal)  ||  // (expr && true)  -> (expr)
        (op.kind() == Operator::Kind::LOGICALOR  && !rightVal) ||  // (expr || false) -> (expr)
        (op.kind() == Operator::Kind::LOGICALXOR && !rightVal) ||  // (expr ^^ false) -> (expr)
        (op.kind() == Operator::Kind::EQEQ       && rightVal)  ||  // (expr == true)  -> (expr)
        (op.kind() == Operator::Kind::NEQ        && !rightVal)) {  // (expr != false) -> (expr)

        return left.clone(pos);
    }

    return nullptr;
}

// `left` must be a Boolean literal. Applies short-circuit simplifications (`false && expr` and
// `true || expr` collapse to the literal); otherwise tries no-op elimination on `right`.
// Returns null if nothing applies.
static std::unique_ptr<Expression> short_circuit_boolean(Position pos,
                                                         const Expression& left,
                                                         Operator op,
                                                         const Expression& right) {
    bool leftVal = left.as<Literal>().boolValue();

    // When the literal is on the left, we can sometimes eliminate the other expression entirely.
    if ((op.kind() == Operator::Kind::LOGICALAND && !leftVal) ||  // (false && expr) -> (false)
        (op.kind() == Operator::Kind::LOGICALOR  && leftVal)) {   // (true  || expr) -> (true)

        return left.clone(pos);
    }

    // We can't eliminate the right-side expression via short-circuit, but we might still be able to
    // simplify away a no-op expression. The operands are passed swapped; this is safe because every
    // operator eliminate_no_op_boolean recognizes (&&, ||, ^^, ==, !=) is commutative.
    return eliminate_no_op_boolean(pos, right, op, left);
}

// For `==` and `!=`, compares two constant expressions and returns a Boolean literal when the
// comparison result is known. Returns null for other operators or when the result is unknown.
static std::unique_ptr<Expression> simplify_constant_equality(const Context& context,
                                                              Position pos,
                                                              const Expression& left,
                                                              Operator op,
                                                              const Expression& right) {
    if (op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ) {
        bool equality = (op.kind() == Operator::Kind::EQEQ);

        switch (left.compareConstant(right)) {
            case Expression::ComparisonResult::kNotEqual:
                equality = !equality;
                [[fallthrough]];

            case Expression::ComparisonResult::kEqual:
                return Literal::MakeBool(context, pos, equality);

            case Expression::ComparisonResult::kUnknown:
                break;
        }
    }
    return nullptr;
}

// Folds the product of two constant operands viewed as matrices of the given dimensions. Slots are
// read column-major (slot index = column * rows + row). The callers describe a vector operand as a
// 1-row (left) or 1-column (right) matrix. Returns null if any result component does not fit in
// the component type's range, matching the checks in simplify_componentwise and fold_expression.
static std::unique_ptr<Expression> simplify_matrix_multiplication(const Context& context,
                                                                  Position pos,
                                                                  const Expression& left,
                                                                  const Expression& right,
                                                                  int leftColumns,
                                                                  int leftRows,
                                                                  int rightColumns,
                                                                  int rightRows) {
    const Type& componentType = left.type().componentType();
    SkASSERT(componentType.matches(right.type().componentType()));

    // Fetch the left matrix.
    double leftVals[4][4];
    for (int c = 0; c < leftColumns; ++c) {
        for (int r = 0; r < leftRows; ++r) {
            leftVals[c][r] = *left.getConstantValue((c * leftRows) + r);
        }
    }
    // Fetch the right matrix.
    double rightVals[4][4];
    for (int c = 0; c < rightColumns; ++c) {
        for (int r = 0; r < rightRows; ++r) {
            rightVals[c][r] = *right.getConstantValue((c * rightRows) + r);
        }
    }

    SkASSERT(leftColumns == rightRows);
    int outColumns = rightColumns,
        outRows    = leftRows;

    double minimumValue = componentType.minimumValue();
    double maximumValue = componentType.maximumValue();

    ExpressionArray args;
    args.reserve_back(outColumns * outRows);
    for (int c = 0; c < outColumns; ++c) {
        for (int r = 0; r < outRows; ++r) {
            // Compute a dot product for this position.
            double val = 0;
            for (int dotIdx = 0; dotIdx < leftColumns; ++dotIdx) {
                val += leftVals[dotIdx][r] * rightVals[c][dotIdx];
            }
            // Written so that NaN also fails the check. An out-of-range product cannot be
            // represented as a literal of the component type, so leave the multiply unfolded.
            if (!(val >= minimumValue && val <= maximumValue)) {
                return nullptr;
            }
            args.push_back(Literal::Make(pos, val, &componentType));
        }
    }

    if (outColumns == 1) {
        // Matrix-times-vector conceptually makes a 1-column N-row matrix, but we return vecN.
        std::swap(outColumns, outRows);
    }

    const Type& resultType = componentType.toCompound(context, outColumns, outRows);
    return ConstructorCompound::Make(context, pos, resultType, std::move(args));
}

// Folds `matrix * matrix` for two constant matrices.
static std::unique_ptr<Expression> simplify_matrix_times_matrix(const Context& context,
                                                                Position pos,
                                                                const Expression& left,
                                                                const Expression& right) {
    const Type& leftType = left.type();
    const Type& rightType = right.type();

    SkASSERT(leftType.isMatrix());
    SkASSERT(rightType.isMatrix());

    return simplify_matrix_multiplication(context, pos, left, right,
                                          leftType.columns(), leftType.rows(),
                                          rightType.columns(), rightType.rows());
}

// Folds `vector * matrix`; the vector is treated as a matrix with a single row.
static std::unique_ptr<Expression> simplify_vector_times_matrix(const Context& context,
                                                                Position pos,
                                                                const Expression& left,
                                                                const Expression& right) {
    const Type& leftType = left.type();
    const Type& rightType = right.type();

    SkASSERT(leftType.isVector());
    SkASSERT(rightType.isMatrix());

    return simplify_matrix_multiplication(context, pos, left, right,
                                          /*leftColumns=*/leftType.columns(), /*leftRows=*/1,
                                          rightType.columns(), rightType.rows());
}

// Folds `matrix * vector`; the vector is treated as a matrix with a single column.
static std::unique_ptr<Expression> simplify_matrix_times_vector(const Context& context,
                                                                Position pos,
                                                                const Expression& left,
                                                                const Expression& right) {
    const Type& leftType = left.type();
    const Type& rightType = right.type();

    SkASSERT(leftType.isMatrix());
    SkASSERT(rightType.isVector());

    return simplify_matrix_multiplication(context, pos, left, right,
                                          leftType.columns(), leftType.rows(),
                                          /*rightColumns=*/1, /*rightRows=*/rightType.columns());
}

// Folds a slot-by-slot operation between two constant vectors/matrices of the same type.
// Handles `==`/`!=` and the arithmetic operators + - * /. Returns null for any other operator,
// or if any folded component falls outside the component type's range.
static std::unique_ptr<Expression> simplify_componentwise(const Context& context,
                                                          Position pos,
                                                          const Expression& left,
                                                          Operator op,
                                                          const Expression& right) {
    SkASSERT(is_vec_or_mat(left.type()));
    SkASSERT(left.type().matches(right.type()));
    const Type& type = left.type();

    // Handle equality operations: == !=
    if (std::unique_ptr<Expression> result = simplify_constant_equality(context, pos, left, op,
                                                                        right)) {
        return result;
    }

    // Handle floating-point arithmetic: + - * /
    using FoldFn = double (*)(double, double);
    FoldFn foldFn;
    switch (op.kind()) {
        case Operator::Kind::PLUS:  foldFn = +[](double a, double b) { return a + b; }; break;
        case Operator::Kind::MINUS: foldFn = +[](double a, double b) { return a - b; }; break;
        case Operator::Kind::STAR:  foldFn = +[](double a, double b) { return a * b; }; break;
        case Operator::Kind::SLASH: foldFn = +[](double a, double b) { return a / b; }; break;
        default:
            return nullptr;
    }

    const Type& componentType = type.componentType();
    SkASSERT(componentType.isNumber());

    double minimumValue = componentType.minimumValue();
    double maximumValue = componentType.maximumValue();

    ExpressionArray args;
    int numSlots = type.slotCount();
    args.reserve_back(numSlots);
    for (int i = 0; i < numSlots; i++) {
        double value = foldFn(*left.getConstantValue(i), *right.getConstantValue(i));
        // Written so that NaN also fails the check (same form as fold_expression).
        if (!(value >= minimumValue && value <= maximumValue)) {
            return nullptr;
        }
        args.push_back(Literal::Make(pos, value, &componentType));
    }
    return ConstructorCompound::Make(context, pos, type, std::move(args));
}

// Expands a scalar into a vector splat, or into a matrix with the scalar in every slot (not just
// the diagonal), so it can be folded slot-by-slot against a value of `type`.
static std::unique_ptr<Expression> splat_scalar(const Context& context,
                                                const Expression& scalar,
                                                const Type& type) {
    if (type.isVector()) {
        return ConstructorSplat::Make(context, scalar.fPosition, type, scalar.clone());
    }
    if (type.isMatrix()) {
        int numSlots = type.slotCount();
        ExpressionArray splatMatrix;
        splatMatrix.reserve_back(numSlots);
        for (int index = 0; index < numSlots; ++index) {
            splatMatrix.push_back(scalar.clone());
        }
        return ConstructorCompound::Make(context, scalar.fPosition, type, std::move(splatMatrix));
    }
    SkDEBUGFAILF("unsupported type %s", type.description().c_str());
    return nullptr;
}

// Converts `expr` to `type`: a scalar becomes a diagonal matrix or a vector splat; an expression
// already of `type` is cloned. Returns null otherwise (e.g. matrix <-> vector).
static std::unique_ptr<Expression> cast_expression(const Context& context,
                                                   Position pos,
                                                   const Expression& expr,
                                                   const Type& type) {
    SkASSERT(type.componentType().matches(expr.type().componentType()));
    if (expr.type().isScalar()) {
        if (type.isMatrix()) {
            return ConstructorDiagonalMatrix::Make(context, pos, type, expr.clone());
        }
        if (type.isVector()) {
            return ConstructorSplat::Make(context, pos, type, expr.clone());
        }
    }
    if (type.matches(expr.type())) {
        return expr.clone(pos);
    }
    // We can't cast matrices into vectors or vice-versa.
    return nullptr;
}

// Returns an all-zero value of the scalar, vector or matrix `type`.
static std::unique_ptr<Expression> zero_expression(const Context& context,
                                                   Position pos,
                                                   const Type& type) {
    std::unique_ptr<Expression> zero = Literal::Make(pos, 0.0, &type.componentType());
    if (type.isScalar()) {
        return zero;
    }
    if (type.isVector()) {
        return ConstructorSplat::Make(context, pos, type, std::move(zero));
    }
    if (type.isMatrix()) {
        // A diagonal matrix built from zero has zero in every slot.
        return ConstructorDiagonalMatrix::Make(context, pos, type, std::move(zero));
    }
    SkDEBUGFAILF("unsupported type %s", type.description().c_str());
    return nullptr;
}

// Returns `-expr` converted to `type` (see cast_expression), or null if the conversion fails.
static std::unique_ptr<Expression> negate_expression(const Context& context,
                                                     Position pos,
                                                     const Expression& expr,
                                                     const Type& type) {
    std::unique_ptr<Expression> ctor = cast_expression(context, pos, expr, type);
    return ctor ? PrefixExpression::Make(context, pos, Operator::Kind::MINUS, std::move(ctor))
                : nullptr;
}

bool ConstantFolder::GetConstantInt(const Expression& value, SKSL_INT* out) {
    const Expression* expr = GetConstantValueForVariable(value);
    if (!expr->isIntLiteral()) {
        return false;
    }
    *out = expr->as<Literal>().intValue();
    return true;
}

bool ConstantFolder::GetConstantValue(const Expression& value, double* out) {
    const Expression* expr = GetConstantValueForVariable(value);
    if (!expr->is<Literal>()) {
        return false;
    }
    *out = expr->as<Literal>().value();
    return true;
}

// Returns true if any slot of `expr` is known to be the constant 0.
static bool contains_constant_zero(const Expression& expr) {
    int numSlots = expr.type().slotCount();
    for (int index = 0; index < numSlots; ++index) {
        std::optional<double> slotVal = expr.getConstantValue(index);
        if (slotVal.has_value() && *slotVal == 0.0) {
            return true;
        }
    }
    return false;
}

// Returns true if the expression contains `value` in every slot.
static bool is_constant_splat(const Expression& expr, double value) { int numSlots = expr.type().slotCount(); for (int index = 0; index < numSlots; ++index) { std::optional slotVal = expr.getConstantValue(index); if (!slotVal.has_value() || *slotVal != value) { return false; } } return true; } // Returns true if the expression is a square diagonal matrix containing `value`. static bool is_constant_diagonal(const Expression& expr, double value) { SkASSERT(expr.type().isMatrix()); int columns = expr.type().columns(); int rows = expr.type().rows(); if (columns != rows) { return false; } int slotIdx = 0; for (int c = 0; c < columns; ++c) { for (int r = 0; r < rows; ++r) { double expectation = (c == r) ? value : 0; std::optional slotVal = expr.getConstantValue(slotIdx++); if (!slotVal.has_value() || *slotVal != expectation) { return false; } } } return true; } // Returns true if the expression is a scalar, vector, or diagonal matrix containing `value`. static bool is_constant_value(const Expression& expr, double value) { return expr.type().isMatrix() ? is_constant_diagonal(expr, value) : is_constant_splat(expr, value); } // The expression represents the right-hand side of a division op. If the division can be // strength-reduced into multiplication by a reciprocal, returns that reciprocal as an expression. // Note that this only supports literal values with safe-to-use reciprocals, and returns null if // Expression contains anything else. static std::unique_ptr make_reciprocal_expression(const Context& context, const Expression& right) { if (right.type().isMatrix() || !right.type().componentType().isFloat()) { return nullptr; } // Verify that each slot contains a finite, non-zero literal, take its reciprocal. 
int nslots = right.type().slotCount(); SkSTArray<4, double> values; for (int index = 0; index < nslots; ++index) { std::optional value = right.getConstantValue(index); if (!value) { return nullptr; } *value = sk_ieee_double_divide(1.0, *value); if (*value >= -FLT_MAX && *value <= FLT_MAX && *value != 0.0) { // The reciprocal can be represented safely as a finite 32-bit float. values.push_back(*value); } else { // The value is outside the 32-bit float range, or is NaN; do not optimize. return nullptr; } } // Convert our reciprocal values to Literals. ExpressionArray exprs; exprs.reserve_back(nslots); for (double value : values) { exprs.push_back(Literal::Make(right.fPosition, value, &right.type().componentType())); } // Turn the expression array into a compound constructor. (If this is a single-slot expression, // this will return the literal as-is.) return ConstructorCompound::Make(context, right.fPosition, right.type(), std::move(exprs)); } static bool error_on_divide_by_zero(const Context& context, Position pos, Operator op, const Expression& right) { switch (op.kind()) { case Operator::Kind::SLASH: case Operator::Kind::SLASHEQ: case Operator::Kind::PERCENT: case Operator::Kind::PERCENTEQ: if (contains_constant_zero(right)) { context.fErrors->error(pos, "division by zero"); return true; } return false; default: return false; } } const Expression* ConstantFolder::GetConstantValueOrNullForVariable(const Expression& inExpr) { for (const Expression* expr = &inExpr;;) { if (!expr->is()) { break; } const VariableReference& varRef = expr->as(); if (varRef.refKind() != VariableRefKind::kRead) { break; } const Variable& var = *varRef.variable(); if (!(var.modifiers().fFlags & Modifiers::kConst_Flag)) { break; } expr = var.initialValue(); if (!expr) { // Function parameters can be const but won't have an initial value. break; } if (Analysis::IsCompileTimeConstant(*expr)) { return expr; } } // We didn't find a compile-time constant at the end. 
return nullptr; } const Expression* ConstantFolder::GetConstantValueForVariable(const Expression& inExpr) { const Expression* expr = GetConstantValueOrNullForVariable(inExpr); return expr ? expr : &inExpr; } std::unique_ptr ConstantFolder::MakeConstantValueForVariable( Position pos, std::unique_ptr inExpr) { const Expression* expr = GetConstantValueOrNullForVariable(*inExpr); return expr ? expr->clone(pos) : std::move(inExpr); } static bool is_scalar_op_matrix(const Expression& left, const Expression& right) { return left.type().isScalar() && right.type().isMatrix(); } static bool is_matrix_op_scalar(const Expression& left, const Expression& right) { return is_scalar_op_matrix(right, left); } static std::unique_ptr simplify_arithmetic(const Context& context, Position pos, const Expression& left, Operator op, const Expression& right, const Type& resultType) { switch (op.kind()) { case Operator::Kind::PLUS: if (!is_scalar_op_matrix(left, right) && is_constant_splat(right, 0.0)) { // x + 0 if (std::unique_ptr expr = cast_expression(context, pos, left, resultType)) { return expr; } } if (!is_matrix_op_scalar(left, right) && is_constant_splat(left, 0.0)) { // 0 + x if (std::unique_ptr expr = cast_expression(context, pos, right, resultType)) { return expr; } } break; case Operator::Kind::STAR: if (is_constant_value(right, 1.0)) { // x * 1 if (std::unique_ptr expr = cast_expression(context, pos, left, resultType)) { return expr; } } if (is_constant_value(left, 1.0)) { // 1 * x if (std::unique_ptr expr = cast_expression(context, pos, right, resultType)) { return expr; } } if (is_constant_value(right, 0.0) && !Analysis::HasSideEffects(left)) { // x * 0 return zero_expression(context, pos, resultType); } if (is_constant_value(left, 0.0) && !Analysis::HasSideEffects(right)) { // 0 * x return zero_expression(context, pos, resultType); } if (is_constant_value(right, -1.0)) { // x * -1 (to `-x`) if (std::unique_ptr expr = negate_expression(context, pos, left, resultType)) { 
return expr; } } if (is_constant_value(left, -1.0)) { // -1 * x (to `-x`) if (std::unique_ptr expr = negate_expression(context, pos, right, resultType)) { return expr; } } break; case Operator::Kind::MINUS: if (!is_scalar_op_matrix(left, right) && is_constant_splat(right, 0.0)) { // x - 0 if (std::unique_ptr expr = cast_expression(context, pos, left, resultType)) { return expr; } } if (!is_matrix_op_scalar(left, right) && is_constant_splat(left, 0.0)) { // 0 - x if (std::unique_ptr expr = negate_expression(context, pos, right, resultType)) { return expr; } } break; case Operator::Kind::SLASH: if (!is_scalar_op_matrix(left, right) && is_constant_splat(right, 1.0)) { // x / 1 if (std::unique_ptr expr = cast_expression(context, pos, left, resultType)) { return expr; } } if (!left.type().isMatrix()) { // convert `x / 2` into `x * 0.5` if (std::unique_ptr expr = make_reciprocal_expression(context, right)) { return BinaryExpression::Make(context, pos, left.clone(), Operator::Kind::STAR, std::move(expr)); } } break; case Operator::Kind::PLUSEQ: case Operator::Kind::MINUSEQ: if (is_constant_splat(right, 0.0)) { // x += 0, x -= 0 if (std::unique_ptr var = cast_expression(context, pos, left, resultType)) { Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead); return var; } } break; case Operator::Kind::STAREQ: if (is_constant_value(right, 1.0)) { // x *= 1 if (std::unique_ptr var = cast_expression(context, pos, left, resultType)) { Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead); return var; } } break; case Operator::Kind::SLASHEQ: if (is_constant_splat(right, 1.0)) { // x /= 1 if (std::unique_ptr var = cast_expression(context, pos, left, resultType)) { Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead); return var; } } if (std::unique_ptr expr = make_reciprocal_expression(context, right)) { return BinaryExpression::Make(context, pos, left.clone(), Operator::Kind::STAREQ, std::move(expr)); } break; default: break; } return 
nullptr; } // The expression must be scalar, and represents the right-hand side of a division op. It can // contain anything, not just literal values. This returns the binary expression `1.0 / expr`. The // expression might be further simplified by the constant folding, if possible. static std::unique_ptr one_over_scalar(const Context& context, const Expression& right) { SkASSERT(right.type().isScalar()); Position pos = right.fPosition; return BinaryExpression::Make(context, pos, Literal::Make(pos, 1.0, &right.type()), Operator::Kind::SLASH, right.clone()); } static std::unique_ptr simplify_matrix_division(const Context& context, Position pos, const Expression& left, Operator op, const Expression& right, const Type& resultType) { // Convert matrix-over-scalar `x /= y` into `x *= (1.0 / y)`. This generates better // code in SPIR-V and Metal, and should be roughly equivalent elsewhere. switch (op.kind()) { case OperatorKind::SLASH: case OperatorKind::SLASHEQ: if (left.type().isMatrix() && right.type().isScalar()) { Operator multiplyOp = op.isAssignment() ? OperatorKind::STAREQ : OperatorKind::STAR; return BinaryExpression::Make(context, pos, left.clone(), multiplyOp, one_over_scalar(context, right)); } break; default: break; } return nullptr; } static std::unique_ptr fold_expression(Position pos, double result, const Type* resultType) { if (resultType->isNumber()) { if (result >= resultType->minimumValue() && result <= resultType->maximumValue()) { // This result will fit inside its type. } else { // The value is outside the range or is NaN (all if-checks fail); do not optimize. return nullptr; } } return Literal::Make(pos, result, resultType); } std::unique_ptr ConstantFolder::Simplify(const Context& context, Position pos, const Expression& leftExpr, Operator op, const Expression& rightExpr, const Type& resultType) { // Replace constant variables with their literal values. 
const Expression* left = GetConstantValueForVariable(leftExpr); const Expression* right = GetConstantValueForVariable(rightExpr); // If this is the assignment operator, and both sides are the same trivial expression, this is // self-assignment (i.e., `var = var`) and can be reduced to just a variable reference (`var`). // This can happen when other parts of the assignment are optimized away. if (op.kind() == Operator::Kind::EQ && Analysis::IsSameExpressionTree(*left, *right)) { return right->clone(pos); } // Simplify the expression when both sides are constant Boolean literals. if (left->isBoolLiteral() && right->isBoolLiteral()) { bool leftVal = left->as().boolValue(); bool rightVal = right->as().boolValue(); bool result; switch (op.kind()) { case Operator::Kind::LOGICALAND: result = leftVal && rightVal; break; case Operator::Kind::LOGICALOR: result = leftVal || rightVal; break; case Operator::Kind::LOGICALXOR: result = leftVal ^ rightVal; break; case Operator::Kind::EQEQ: result = leftVal == rightVal; break; case Operator::Kind::NEQ: result = leftVal != rightVal; break; default: return nullptr; } return Literal::MakeBool(context, pos, result); } // If the left side is a Boolean literal, apply short-circuit optimizations. if (left->isBoolLiteral()) { return short_circuit_boolean(pos, *left, op, *right); } // If the right side is a Boolean literal... if (right->isBoolLiteral()) { // ... and the left side has no side effects... if (!Analysis::HasSideEffects(*left)) { // We can reverse the expressions and short-circuit optimizations are still valid. return short_circuit_boolean(pos, *right, op, *left); } // We can't use short-circuiting, but we can still optimize away no-op Boolean expressions. return eliminate_no_op_boolean(pos, *left, op, *right); } if (op.kind() == Operator::Kind::EQEQ && Analysis::IsSameExpressionTree(*left, *right)) { // With == comparison, if both sides are the same trivial expression, this is self- // comparison and is always true. 
(We are not concerned with NaN.) return Literal::MakeBool(context, pos, /*value=*/true); } if (op.kind() == Operator::Kind::NEQ && Analysis::IsSameExpressionTree(*left, *right)) { // With != comparison, if both sides are the same trivial expression, this is self- // comparison and is always false. (We are not concerned with NaN.) return Literal::MakeBool(context, pos, /*value=*/false); } if (error_on_divide_by_zero(context, pos, op, *right)) { return nullptr; } // Perform full constant folding when both sides are compile-time constants. const Type& leftType = left->type(); const Type& rightType = right->type(); bool leftSideIsConstant = Analysis::IsCompileTimeConstant(*left); bool rightSideIsConstant = Analysis::IsCompileTimeConstant(*right); if (leftSideIsConstant && rightSideIsConstant) { // Handle pairs of integer literals. if (left->isIntLiteral() && right->isIntLiteral()) { using SKSL_UINT = uint64_t; SKSL_INT leftVal = left->as().intValue(); SKSL_INT rightVal = right->as().intValue(); // Note that fold_expression returns null if the result would overflow its type. 
#define RESULT(Op) fold_expression(pos, (SKSL_INT)(leftVal) Op \ (SKSL_INT)(rightVal), &resultType) #define URESULT(Op) fold_expression(pos, (SKSL_INT)((SKSL_UINT)(leftVal) Op \ (SKSL_UINT)(rightVal)), &resultType) switch (op.kind()) { case Operator::Kind::PLUS: return URESULT(+); case Operator::Kind::MINUS: return URESULT(-); case Operator::Kind::STAR: return URESULT(*); case Operator::Kind::SLASH: if (leftVal == std::numeric_limits::min() && rightVal == -1) { context.fErrors->error(pos, "arithmetic overflow"); return nullptr; } return RESULT(/); case Operator::Kind::PERCENT: if (leftVal == std::numeric_limits::min() && rightVal == -1) { context.fErrors->error(pos, "arithmetic overflow"); return nullptr; } return RESULT(%); case Operator::Kind::BITWISEAND: return RESULT(&); case Operator::Kind::BITWISEOR: return RESULT(|); case Operator::Kind::BITWISEXOR: return RESULT(^); case Operator::Kind::EQEQ: return RESULT(==); case Operator::Kind::NEQ: return RESULT(!=); case Operator::Kind::GT: return RESULT(>); case Operator::Kind::GTEQ: return RESULT(>=); case Operator::Kind::LT: return RESULT(<); case Operator::Kind::LTEQ: return RESULT(<=); case Operator::Kind::SHL: if (rightVal >= 0 && rightVal <= 31) { // Left-shifting a negative (or really, any signed) value is undefined // behavior in C++, but not in GLSL. Do the shift on unsigned values to avoid // triggering an UBSAN error. return URESULT(<<); } context.fErrors->error(pos, "shift value out of range"); return nullptr; case Operator::Kind::SHR: if (rightVal >= 0 && rightVal <= 31) { return RESULT(>>); } context.fErrors->error(pos, "shift value out of range"); return nullptr; default: return nullptr; } #undef RESULT #undef URESULT } // Handle pairs of floating-point literals. 
if (left->isFloatLiteral() && right->isFloatLiteral()) { SKSL_FLOAT leftVal = left->as().floatValue(); SKSL_FLOAT rightVal = right->as().floatValue(); #define RESULT(Op) fold_expression(pos, leftVal Op rightVal, &resultType) switch (op.kind()) { case Operator::Kind::PLUS: return RESULT(+); case Operator::Kind::MINUS: return RESULT(-); case Operator::Kind::STAR: return RESULT(*); case Operator::Kind::SLASH: return RESULT(/); case Operator::Kind::EQEQ: return RESULT(==); case Operator::Kind::NEQ: return RESULT(!=); case Operator::Kind::GT: return RESULT(>); case Operator::Kind::GTEQ: return RESULT(>=); case Operator::Kind::LT: return RESULT(<); case Operator::Kind::LTEQ: return RESULT(<=); default: return nullptr; } #undef RESULT } // Perform matrix multiplication. if (op.kind() == Operator::Kind::STAR) { if (leftType.isMatrix() && rightType.isMatrix()) { return simplify_matrix_times_matrix(context, pos, *left, *right); } if (leftType.isVector() && rightType.isMatrix()) { return simplify_vector_times_matrix(context, pos, *left, *right); } if (leftType.isMatrix() && rightType.isVector()) { return simplify_matrix_times_vector(context, pos, *left, *right); } } // Perform constant folding on pairs of vectors/matrices. 
if (is_vec_or_mat(leftType) && leftType.matches(rightType)) { return simplify_componentwise(context, pos, *left, op, *right); } // Perform constant folding on vectors/matrices against scalars, e.g.: half4(2) + 2 if (rightType.isScalar() && is_vec_or_mat(leftType) && leftType.componentType().matches(rightType)) { return simplify_componentwise(context, pos, *left, op, *splat_scalar(context, *right, left->type())); } // Perform constant folding on scalars against vectors/matrices, e.g.: 2 + half4(2) if (leftType.isScalar() && is_vec_or_mat(rightType) && rightType.componentType().matches(leftType)) { return simplify_componentwise(context, pos, *splat_scalar(context, *left, right->type()), op, *right); } // Perform constant folding on pairs of matrices, arrays or structs. if ((leftType.isMatrix() && rightType.isMatrix()) || (leftType.isArray() && rightType.isArray()) || (leftType.isStruct() && rightType.isStruct())) { return simplify_constant_equality(context, pos, *left, op, *right); } } if (context.fConfig->fSettings.fOptimize) { // If just one side is constant, we might still be able to simplify arithmetic expressions // like `x * 1`, `x *= 1`, `x + 0`, `x * 0`, `0 / x`, etc. if (leftSideIsConstant || rightSideIsConstant) { if (std::unique_ptr expr = simplify_arithmetic(context, pos, *left, op, *right, resultType)) { return expr; } } // We can simplify some forms of matrix division even when neither side is constant. if (std::unique_ptr expr = simplify_matrix_division(context, pos, *left, op, *right, resultType)) { return expr; } } // We aren't able to constant-fold. return nullptr; } } // namespace SkSL