1 /*
2 * Copyright 2020 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/SkSLConstantFolder.h"
9
10 #include "include/core/SkTypes.h"
11 #include "include/private/base/SkFloatingPoint.h"
12 #include "include/private/base/SkTArray.h"
13 #include "src/sksl/SkSLAnalysis.h"
14 #include "src/sksl/SkSLContext.h"
15 #include "src/sksl/SkSLErrorReporter.h"
16 #include "src/sksl/SkSLPosition.h"
17 #include "src/sksl/SkSLProgramSettings.h"
18 #include "src/sksl/ir/SkSLBinaryExpression.h"
19 #include "src/sksl/ir/SkSLConstructorCompound.h"
20 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
21 #include "src/sksl/ir/SkSLConstructorSplat.h"
22 #include "src/sksl/ir/SkSLExpression.h"
23 #include "src/sksl/ir/SkSLLiteral.h"
24 #include "src/sksl/ir/SkSLModifierFlags.h"
25 #include "src/sksl/ir/SkSLPrefixExpression.h"
26 #include "src/sksl/ir/SkSLType.h"
27 #include "src/sksl/ir/SkSLVariable.h"
28 #include "src/sksl/ir/SkSLVariableReference.h"
29
30 #include <cstdint>
31 #include <float.h>
32 #include <limits>
33 #include <optional>
34 #include <string>
35 #include <utility>
36
37 using namespace skia_private;
38
39 namespace SkSL {
40
is_vec_or_mat(const Type & type)41 static bool is_vec_or_mat(const Type& type) {
42 switch (type.typeKind()) {
43 case Type::TypeKind::kMatrix:
44 case Type::TypeKind::kVector:
45 return true;
46
47 default:
48 return false;
49 }
50 }
51
eliminate_no_op_boolean(Position pos,const Expression & left,Operator op,const Expression & right)52 static std::unique_ptr<Expression> eliminate_no_op_boolean(Position pos,
53 const Expression& left,
54 Operator op,
55 const Expression& right) {
56 bool rightVal = right.as<Literal>().boolValue();
57
58 // Detect no-op Boolean expressions and optimize them away.
59 if ((op.kind() == Operator::Kind::LOGICALAND && rightVal) || // (expr && true) -> (expr)
60 (op.kind() == Operator::Kind::LOGICALOR && !rightVal) || // (expr || false) -> (expr)
61 (op.kind() == Operator::Kind::LOGICALXOR && !rightVal) || // (expr ^^ false) -> (expr)
62 (op.kind() == Operator::Kind::EQEQ && rightVal) || // (expr == true) -> (expr)
63 (op.kind() == Operator::Kind::NEQ && !rightVal)) { // (expr != false) -> (expr)
64
65 return left.clone(pos);
66 }
67
68 return nullptr;
69 }
70
short_circuit_boolean(Position pos,const Expression & left,Operator op,const Expression & right)71 static std::unique_ptr<Expression> short_circuit_boolean(Position pos,
72 const Expression& left,
73 Operator op,
74 const Expression& right) {
75 bool leftVal = left.as<Literal>().boolValue();
76
77 // When the literal is on the left, we can sometimes eliminate the other expression entirely.
78 if ((op.kind() == Operator::Kind::LOGICALAND && !leftVal) || // (false && expr) -> (false)
79 (op.kind() == Operator::Kind::LOGICALOR && leftVal)) { // (true || expr) -> (true)
80
81 return left.clone(pos);
82 }
83
84 // We can't eliminate the right-side expression via short-circuit, but we might still be able to
85 // simplify away a no-op expression.
86 return eliminate_no_op_boolean(pos, right, op, left);
87 }
88
simplify_constant_equality(const Context & context,Position pos,const Expression & left,Operator op,const Expression & right)89 static std::unique_ptr<Expression> simplify_constant_equality(const Context& context,
90 Position pos,
91 const Expression& left,
92 Operator op,
93 const Expression& right) {
94 if (op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ) {
95 bool equality = (op.kind() == Operator::Kind::EQEQ);
96
97 switch (left.compareConstant(right)) {
98 case Expression::ComparisonResult::kNotEqual:
99 equality = !equality;
100 [[fallthrough]];
101
102 case Expression::ComparisonResult::kEqual:
103 return Literal::MakeBool(context, pos, equality);
104
105 case Expression::ComparisonResult::kUnknown:
106 break;
107 }
108 }
109 return nullptr;
110 }
111
// Constant-folds the linear-algebra product `left * right`. Each operand is described as a
// (columns x rows) grid; the callers below pass vectors in as degenerate matrices (a 1-row left
// operand for vector*matrix, a 1-column right operand for matrix*vector). The element at column
// `c`, row `r` is read from slot `c * rows + r`.
//
// The result has `rightColumns` columns and `leftRows` rows, filled in the same slot order. A
// 1-column result (matrix*vector) is returned as a vector. Returns null if any element of the
// product is NaN or lies outside the 32-bit float range.
static std::unique_ptr<Expression> simplify_matrix_multiplication(const Context& context,
                                                                  Position pos,
                                                                  const Expression& left,
                                                                  const Expression& right,
                                                                  int leftColumns,
                                                                  int leftRows,
                                                                  int rightColumns,
                                                                  int rightRows) {
    const Type& componentType = left.type().componentType();
    SkASSERT(componentType.matches(right.type().componentType()));

    // Fetch the left matrix.
    // NOTE(review): relies on every dimension being <= 4 (fixed 4x4 scratch arrays). It also
    // relies on both operands being fully constant, because the optionals are dereferenced
    // unchecked; fold_two_constants asserts this.
    double leftVals[4][4];
    for (int c = 0; c < leftColumns; ++c) {
        for (int r = 0; r < leftRows; ++r) {
            leftVals[c][r] = *left.getConstantValue((c * leftRows) + r);
        }
    }
    // Fetch the right matrix.
    double rightVals[4][4];
    for (int c = 0; c < rightColumns; ++c) {
        for (int r = 0; r < rightRows; ++r) {
            rightVals[c][r] = *right.getConstantValue((c * rightRows) + r);
        }
    }

    // The inner dimensions must agree for the product to be defined.
    SkASSERT(leftColumns == rightRows);
    int outColumns = rightColumns,
        outRows = leftRows;

    // Output is produced column by column, so args[] uses the same c * rows + r slot layout.
    double args[16];
    int argIndex = 0;
    for (int c = 0; c < outColumns; ++c) {
        for (int r = 0; r < outRows; ++r) {
            // Compute a dot product for this position: row `r` of left with column `c` of right.
            double val = 0;
            for (int dotIdx = 0; dotIdx < leftColumns; ++dotIdx) {
                val += leftVals[dotIdx][r] * rightVals[c][dotIdx];
            }

            // Positive-form range test: NaN fails both comparisons and is rejected as well.
            if (val >= -FLT_MAX && val <= FLT_MAX) {
                args[argIndex++] = val;
            } else {
                // The value is outside the 32-bit float range, or is NaN; do not optimize.
                return nullptr;
            }
        }
    }

    if (outColumns == 1) {
        // Matrix-times-vector conceptually makes a 1-column N-row matrix, but we return vecN.
        std::swap(outColumns, outRows);
    }

    // NOTE(review): for vector*matrix, outRows == 1 here; presumably toCompound maps a 1-row
    // shape to a vector type. Confirm in Type::toCompound.
    const Type& resultType = componentType.toCompound(context, outColumns, outRows);
    return ConstructorCompound::MakeFromConstants(context, pos, resultType, args);
}
169
simplify_matrix_times_matrix(const Context & context,Position pos,const Expression & left,const Expression & right)170 static std::unique_ptr<Expression> simplify_matrix_times_matrix(const Context& context,
171 Position pos,
172 const Expression& left,
173 const Expression& right) {
174 const Type& leftType = left.type();
175 const Type& rightType = right.type();
176
177 SkASSERT(leftType.isMatrix());
178 SkASSERT(rightType.isMatrix());
179
180 return simplify_matrix_multiplication(context, pos, left, right,
181 leftType.columns(), leftType.rows(),
182 rightType.columns(), rightType.rows());
183 }
184
simplify_vector_times_matrix(const Context & context,Position pos,const Expression & left,const Expression & right)185 static std::unique_ptr<Expression> simplify_vector_times_matrix(const Context& context,
186 Position pos,
187 const Expression& left,
188 const Expression& right) {
189 const Type& leftType = left.type();
190 const Type& rightType = right.type();
191
192 SkASSERT(leftType.isVector());
193 SkASSERT(rightType.isMatrix());
194
195 return simplify_matrix_multiplication(context, pos, left, right,
196 /*leftColumns=*/leftType.columns(), /*leftRows=*/1,
197 rightType.columns(), rightType.rows());
198 }
199
simplify_matrix_times_vector(const Context & context,Position pos,const Expression & left,const Expression & right)200 static std::unique_ptr<Expression> simplify_matrix_times_vector(const Context& context,
201 Position pos,
202 const Expression& left,
203 const Expression& right) {
204 const Type& leftType = left.type();
205 const Type& rightType = right.type();
206
207 SkASSERT(leftType.isMatrix());
208 SkASSERT(rightType.isVector());
209
210 return simplify_matrix_multiplication(context, pos, left, right,
211 leftType.columns(), leftType.rows(),
212 /*rightColumns=*/1, /*rightRows=*/rightType.columns());
213 }
214
simplify_componentwise(const Context & context,Position pos,const Expression & left,Operator op,const Expression & right)215 static std::unique_ptr<Expression> simplify_componentwise(const Context& context,
216 Position pos,
217 const Expression& left,
218 Operator op,
219 const Expression& right) {
220 SkASSERT(is_vec_or_mat(left.type()));
221 SkASSERT(left.type().matches(right.type()));
222 const Type& type = left.type();
223
224 // Handle equality operations: == !=
225 if (std::unique_ptr<Expression> result = simplify_constant_equality(context, pos, left, op,
226 right)) {
227 return result;
228 }
229
230 // Handle floating-point arithmetic: + - * /
231 using FoldFn = double (*)(double, double);
232 FoldFn foldFn;
233 switch (op.kind()) {
234 case Operator::Kind::PLUS: foldFn = +[](double a, double b) { return a + b; }; break;
235 case Operator::Kind::MINUS: foldFn = +[](double a, double b) { return a - b; }; break;
236 case Operator::Kind::STAR: foldFn = +[](double a, double b) { return a * b; }; break;
237 case Operator::Kind::SLASH: foldFn = +[](double a, double b) { return a / b; }; break;
238 default:
239 return nullptr;
240 }
241
242 const Type& componentType = type.componentType();
243 SkASSERT(componentType.isNumber());
244
245 double minimumValue = componentType.minimumValue();
246 double maximumValue = componentType.maximumValue();
247
248 double args[16];
249 int numSlots = type.slotCount();
250 for (int i = 0; i < numSlots; i++) {
251 double value = foldFn(*left.getConstantValue(i), *right.getConstantValue(i));
252 if (value < minimumValue || value > maximumValue) {
253 return nullptr;
254 }
255 args[i] = value;
256 }
257 return ConstructorCompound::MakeFromConstants(context, pos, type, args);
258 }
259
splat_scalar(const Context & context,const Expression & scalar,const Type & type)260 static std::unique_ptr<Expression> splat_scalar(const Context& context,
261 const Expression& scalar,
262 const Type& type) {
263 if (type.isVector()) {
264 return ConstructorSplat::Make(context, scalar.fPosition, type, scalar.clone());
265 }
266 if (type.isMatrix()) {
267 int numSlots = type.slotCount();
268 ExpressionArray splatMatrix;
269 splatMatrix.reserve_exact(numSlots);
270 for (int index = 0; index < numSlots; ++index) {
271 splatMatrix.push_back(scalar.clone());
272 }
273 return ConstructorCompound::Make(context, scalar.fPosition, type, std::move(splatMatrix));
274 }
275 SkDEBUGFAILF("unsupported type %s", type.description().c_str());
276 return nullptr;
277 }
278
cast_expression(const Context & context,Position pos,const Expression & expr,const Type & type)279 static std::unique_ptr<Expression> cast_expression(const Context& context,
280 Position pos,
281 const Expression& expr,
282 const Type& type) {
283 SkASSERT(type.componentType().matches(expr.type().componentType()));
284 if (expr.type().isScalar()) {
285 if (type.isMatrix()) {
286 return ConstructorDiagonalMatrix::Make(context, pos, type, expr.clone());
287 }
288 if (type.isVector()) {
289 return ConstructorSplat::Make(context, pos, type, expr.clone());
290 }
291 }
292 if (type.matches(expr.type())) {
293 return expr.clone(pos);
294 }
295 // We can't cast matrices into vectors or vice-versa.
296 return nullptr;
297 }
298
zero_expression(const Context & context,Position pos,const Type & type)299 static std::unique_ptr<Expression> zero_expression(const Context& context,
300 Position pos,
301 const Type& type) {
302 std::unique_ptr<Expression> zero = Literal::Make(pos, 0.0, &type.componentType());
303 if (type.isScalar()) {
304 return zero;
305 }
306 if (type.isVector()) {
307 return ConstructorSplat::Make(context, pos, type, std::move(zero));
308 }
309 if (type.isMatrix()) {
310 return ConstructorDiagonalMatrix::Make(context, pos, type, std::move(zero));
311 }
312 SkDEBUGFAILF("unsupported type %s", type.description().c_str());
313 return nullptr;
314 }
315
negate_expression(const Context & context,Position pos,const Expression & expr,const Type & type)316 static std::unique_ptr<Expression> negate_expression(const Context& context,
317 Position pos,
318 const Expression& expr,
319 const Type& type) {
320 std::unique_ptr<Expression> ctor = cast_expression(context, pos, expr, type);
321 return ctor ? PrefixExpression::Make(context, pos, Operator::Kind::MINUS, std::move(ctor))
322 : nullptr;
323 }
324
GetConstantInt(const Expression & value,SKSL_INT * out)325 bool ConstantFolder::GetConstantInt(const Expression& value, SKSL_INT* out) {
326 const Expression* expr = GetConstantValueForVariable(value);
327 if (!expr->isIntLiteral()) {
328 return false;
329 }
330 *out = expr->as<Literal>().intValue();
331 return true;
332 }
333
GetConstantValue(const Expression & value,double * out)334 bool ConstantFolder::GetConstantValue(const Expression& value, double* out) {
335 const Expression* expr = GetConstantValueForVariable(value);
336 if (!expr->is<Literal>()) {
337 return false;
338 }
339 *out = expr->as<Literal>().value();
340 return true;
341 }
342
contains_constant_zero(const Expression & expr)343 static bool contains_constant_zero(const Expression& expr) {
344 int numSlots = expr.type().slotCount();
345 for (int index = 0; index < numSlots; ++index) {
346 std::optional<double> slotVal = expr.getConstantValue(index);
347 if (slotVal.has_value() && *slotVal == 0.0) {
348 return true;
349 }
350 }
351 return false;
352 }
353
IsConstantSplat(const Expression & expr,double value)354 bool ConstantFolder::IsConstantSplat(const Expression& expr, double value) {
355 int numSlots = expr.type().slotCount();
356 for (int index = 0; index < numSlots; ++index) {
357 std::optional<double> slotVal = expr.getConstantValue(index);
358 if (!slotVal.has_value() || *slotVal != value) {
359 return false;
360 }
361 }
362 return true;
363 }
364
365 // Returns true if the expression is a square diagonal matrix containing `value`.
is_constant_diagonal(const Expression & expr,double value)366 static bool is_constant_diagonal(const Expression& expr, double value) {
367 SkASSERT(expr.type().isMatrix());
368 int columns = expr.type().columns();
369 int rows = expr.type().rows();
370 if (columns != rows) {
371 return false;
372 }
373 int slotIdx = 0;
374 for (int c = 0; c < columns; ++c) {
375 for (int r = 0; r < rows; ++r) {
376 double expectation = (c == r) ? value : 0;
377 std::optional<double> slotVal = expr.getConstantValue(slotIdx++);
378 if (!slotVal.has_value() || *slotVal != expectation) {
379 return false;
380 }
381 }
382 }
383 return true;
384 }
385
386 // Returns true if the expression is a scalar, vector, or diagonal matrix containing `value`.
is_constant_value(const Expression & expr,double value)387 static bool is_constant_value(const Expression& expr, double value) {
388 return expr.type().isMatrix() ? is_constant_diagonal(expr, value)
389 : ConstantFolder::IsConstantSplat(expr, value);
390 }
391
392 // The expression represents the right-hand side of a division op. If the division can be
393 // strength-reduced into multiplication by a reciprocal, returns that reciprocal as an expression.
394 // Note that this only supports literal values with safe-to-use reciprocals, and returns null if
395 // Expression contains anything else.
make_reciprocal_expression(const Context & context,const Expression & right)396 static std::unique_ptr<Expression> make_reciprocal_expression(const Context& context,
397 const Expression& right) {
398 if (right.type().isMatrix() || !right.type().componentType().isFloat()) {
399 return nullptr;
400 }
401 // Verify that each slot contains a finite, non-zero literal, take its reciprocal.
402 double values[4];
403 int nslots = right.type().slotCount();
404 for (int index = 0; index < nslots; ++index) {
405 std::optional<double> value = right.getConstantValue(index);
406 if (!value) {
407 return nullptr;
408 }
409 *value = sk_ieee_double_divide(1.0, *value);
410 if (*value >= -FLT_MAX && *value <= FLT_MAX && *value != 0.0) {
411 // The reciprocal can be represented safely as a finite 32-bit float.
412 values[index] = *value;
413 } else {
414 // The value is outside the 32-bit float range, or is NaN; do not optimize.
415 return nullptr;
416 }
417 }
418 // Turn the expression array into a compound constructor. (If this is a single-slot expression,
419 // this will return the literal as-is.)
420 return ConstructorCompound::MakeFromConstants(context, right.fPosition, right.type(), values);
421 }
422
error_on_divide_by_zero(const Context & context,Position pos,Operator op,const Expression & right)423 static bool error_on_divide_by_zero(const Context& context, Position pos, Operator op,
424 const Expression& right) {
425 switch (op.kind()) {
426 case Operator::Kind::SLASH:
427 case Operator::Kind::SLASHEQ:
428 case Operator::Kind::PERCENT:
429 case Operator::Kind::PERCENTEQ:
430 if (contains_constant_zero(right)) {
431 context.fErrors->error(pos, "division by zero");
432 return true;
433 }
434 return false;
435 default:
436 return false;
437 }
438 }
439
// Follows a chain of read references to `const` variables back to the initializer at its end,
// e.g. `a` with `const int a = b; const int b = 5;` resolves to `5`.
// Returns that initializer if it is a compile-time constant. Returns null if the chain contains:
//  - a non-read reference;
//  - a non-const variable;
//  - a variable without an initializer;
// or if it ends in something that is not a compile-time constant.
const Expression* ConstantFolder::GetConstantValueOrNull(const Expression& inExpr) {
    const Expression* expr = &inExpr;
    for (;;) {
        if (!expr->is<VariableReference>()) {
            // End of the chain: only a genuine compile-time constant is acceptable.
            return Analysis::IsCompileTimeConstant(*expr) ? expr : nullptr;
        }
        const VariableReference& ref = expr->as<VariableReference>();
        if (ref.refKind() != VariableRefKind::kRead) {
            return nullptr;
        }
        const Variable& var = *ref.variable();
        if (!var.modifierFlags().isConst()) {
            return nullptr;
        }
#ifdef SKSL_EXT
        // NOTE(review): presumably `constant_id` variables are specialization constants whose
        // value may be overridden later, so their initializer is not substituted. Confirm.
        if (var.layout().fFlags & SkSL::LayoutFlag::kConstantId) {
            return nullptr;
        }
#endif
        expr = var.initialValue();
        if (expr == nullptr) {
            // Generally, const variables must have initial values. However, function parameters
            // are an exception; they can be const but won't have an initial value.
            return nullptr;
        }
    }
}
465
// Like GetConstantValueOrNull, but returns `inExpr` itself instead of null when it does not
// resolve to a compile-time constant.
const Expression* ConstantFolder::GetConstantValueForVariable(const Expression& inExpr) {
    if (const Expression* constant = GetConstantValueOrNull(inExpr)) {
        return constant;
    }
    return &inExpr;
}
470
MakeConstantValueForVariable(Position pos,std::unique_ptr<Expression> inExpr)471 std::unique_ptr<Expression> ConstantFolder::MakeConstantValueForVariable(
472 Position pos, std::unique_ptr<Expression> inExpr) {
473 const Expression* expr = GetConstantValueOrNull(*inExpr);
474 return expr ? expr->clone(pos) : std::move(inExpr);
475 }
476
is_scalar_op_matrix(const Expression & left,const Expression & right)477 static bool is_scalar_op_matrix(const Expression& left, const Expression& right) {
478 return left.type().isScalar() && right.type().isMatrix();
479 }
480
is_matrix_op_scalar(const Expression & left,const Expression & right)481 static bool is_matrix_op_scalar(const Expression& left, const Expression& right) {
482 return is_scalar_op_matrix(right, left);
483 }
484
// Applies algebraic identities to `left op right` when one operand is a known constant.
// Simplify only calls this, with optimization enabled, when exactly one side is a compile-time
// constant. Rewrites performed:
//   x + 0, 0 + x, x - 0, x * 1, 1 * x, x / 1       -> x
//   0 - x, x * -1, -1 * x                           -> -x
//   x * 0, 0 * x (only if x has no side effects)    -> 0
//   x / c (non-matrix x, constant float c)          -> x * (1 / c)
//   x += 0, x -= 0, x *= 1, x /= 1                  -> x, re-tagged as a read reference
//   x /= c (constant float c)                       -> x *= (1 / c)
// For `*` and `*=`, a matrix operand counts as 1, 0 or -1 when it is that multiple of the identity
// (see is_constant_value). Results are converted to `resultType`; a rewrite is skipped when that
// conversion is impossible. Returns null if no rewrite applies.
static std::unique_ptr<Expression> simplify_arithmetic(const Context& context,
                                                       Position pos,
                                                       const Expression& left,
                                                       Operator op,
                                                       const Expression& right,
                                                       const Type& resultType) {
    switch (op.kind()) {
        case Operator::Kind::PLUS:
            // The scalar/matrix guards here (and in MINUS/SLASH) exist because cast_expression
            // widens a scalar into a *diagonal* matrix, whereas `scalar OP matrix` applies the
            // scalar to every element, so the cast would produce the wrong value.
            if (!is_scalar_op_matrix(left, right) &&
                ConstantFolder::IsConstantSplat(right, 0.0)) {  // x + 0
                if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
                                                                       resultType)) {
                    return expr;
                }
            }
            if (!is_matrix_op_scalar(left, right) &&
                ConstantFolder::IsConstantSplat(left, 0.0)) {  // 0 + x
                if (std::unique_ptr<Expression> expr = cast_expression(context, pos, right,
                                                                       resultType)) {
                    return expr;
                }
            }
            break;

        case Operator::Kind::STAR:
            // No scalar/matrix guard is needed here: `s * c*I` equals the diagonal matrix `s*c`,
            // which is exactly what the scalar-to-diagonal cast produces.
            if (is_constant_value(right, 1.0)) {  // x * 1
                if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
                                                                       resultType)) {
                    return expr;
                }
            }
            if (is_constant_value(left, 1.0)) {  // 1 * x
                if (std::unique_ptr<Expression> expr = cast_expression(context, pos, right,
                                                                       resultType)) {
                    return expr;
                }
            }
            // Dropping `x` entirely is only safe when evaluating it has no side effects.
            if (is_constant_value(right, 0.0) && !Analysis::HasSideEffects(left)) {  // x * 0
                return zero_expression(context, pos, resultType);
            }
            if (is_constant_value(left, 0.0) && !Analysis::HasSideEffects(right)) {  // 0 * x
                return zero_expression(context, pos, resultType);
            }
            if (is_constant_value(right, -1.0)) {  // x * -1 (to `-x`)
                if (std::unique_ptr<Expression> expr = negate_expression(context, pos, left,
                                                                         resultType)) {
                    return expr;
                }
            }
            if (is_constant_value(left, -1.0)) {  // -1 * x (to `-x`)
                if (std::unique_ptr<Expression> expr = negate_expression(context, pos, right,
                                                                         resultType)) {
                    return expr;
                }
            }
            break;

        case Operator::Kind::MINUS:
            if (!is_scalar_op_matrix(left, right) &&
                ConstantFolder::IsConstantSplat(right, 0.0)) {  // x - 0
                if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
                                                                       resultType)) {
                    return expr;
                }
            }
            if (!is_matrix_op_scalar(left, right) &&
                ConstantFolder::IsConstantSplat(left, 0.0)) {  // 0 - x
                if (std::unique_ptr<Expression> expr = negate_expression(context, pos, right,
                                                                         resultType)) {
                    return expr;
                }
            }
            break;

        case Operator::Kind::SLASH:
            if (!is_scalar_op_matrix(left, right) &&
                ConstantFolder::IsConstantSplat(right, 1.0)) {  // x / 1
                if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
                                                                       resultType)) {
                    return expr;
                }
            }
            // Matrix dividends are skipped here; Simplify handles matrix-by-scalar division
            // separately through simplify_matrix_division.
            if (!left.type().isMatrix()) {  // convert `x / 2` into `x * 0.5`
                if (std::unique_ptr<Expression> expr = make_reciprocal_expression(context, right)) {
                    return BinaryExpression::Make(context, pos, left.clone(), Operator::Kind::STAR,
                                                  std::move(expr));
                }
            }
            break;

        case Operator::Kind::PLUSEQ:
        case Operator::Kind::MINUSEQ:
            // The assignment disappears, so the target reference becomes a plain read.
            if (ConstantFolder::IsConstantSplat(right, 0.0)) {  // x += 0, x -= 0
                if (std::unique_ptr<Expression> var = cast_expression(context, pos, left,
                                                                      resultType)) {
                    Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
                    return var;
                }
            }
            break;

        case Operator::Kind::STAREQ:
            if (is_constant_value(right, 1.0)) {  // x *= 1
                if (std::unique_ptr<Expression> var = cast_expression(context, pos, left,
                                                                      resultType)) {
                    Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
                    return var;
                }
            }
            break;

        case Operator::Kind::SLASHEQ:
            if (ConstantFolder::IsConstantSplat(right, 1.0)) {  // x /= 1
                if (std::unique_ptr<Expression> var = cast_expression(context, pos, left,
                                                                      resultType)) {
                    Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
                    return var;
                }
            }
            // Convert `x /= 2` into `x *= 0.5` (the assignment is kept).
            if (std::unique_ptr<Expression> expr = make_reciprocal_expression(context, right)) {
                return BinaryExpression::Make(context, pos, left.clone(), Operator::Kind::STAREQ,
                                              std::move(expr));
            }
            break;

        default:
            break;
    }

    return nullptr;
}
616
617 // The expression must be scalar, and represents the right-hand side of a division op. It can
618 // contain anything, not just literal values. This returns the binary expression `1.0 / expr`. The
619 // expression might be further simplified by the constant folding, if possible.
one_over_scalar(const Context & context,const Expression & right)620 static std::unique_ptr<Expression> one_over_scalar(const Context& context,
621 const Expression& right) {
622 SkASSERT(right.type().isScalar());
623 Position pos = right.fPosition;
624 return BinaryExpression::Make(context, pos,
625 Literal::Make(pos, 1.0, &right.type()),
626 Operator::Kind::SLASH,
627 right.clone());
628 }
629
simplify_matrix_division(const Context & context,Position pos,const Expression & left,Operator op,const Expression & right,const Type & resultType)630 static std::unique_ptr<Expression> simplify_matrix_division(const Context& context,
631 Position pos,
632 const Expression& left,
633 Operator op,
634 const Expression& right,
635 const Type& resultType) {
636 // Convert matrix-over-scalar `x /= y` into `x *= (1.0 / y)`. This generates better
637 // code in SPIR-V and Metal, and should be roughly equivalent elsewhere.
638 switch (op.kind()) {
639 case OperatorKind::SLASH:
640 case OperatorKind::SLASHEQ:
641 if (left.type().isMatrix() && right.type().isScalar()) {
642 Operator multiplyOp = op.isAssignment() ? OperatorKind::STAREQ
643 : OperatorKind::STAR;
644 return BinaryExpression::Make(context, pos,
645 left.clone(),
646 multiplyOp,
647 one_over_scalar(context, right));
648 }
649 break;
650
651 default:
652 break;
653 }
654
655 return nullptr;
656 }
657
fold_expression(Position pos,double result,const Type * resultType)658 static std::unique_ptr<Expression> fold_expression(Position pos,
659 double result,
660 const Type* resultType) {
661 if (resultType->isNumber()) {
662 if (result >= resultType->minimumValue() && result <= resultType->maximumValue()) {
663 // This result will fit inside its type.
664 } else {
665 // The value is outside the range or is NaN (all if-checks fail); do not optimize.
666 return nullptr;
667 }
668 }
669
670 return Literal::Make(pos, result, resultType);
671 }
672
// Folds `left op right` when both operands are compile-time constants (asserted). Cases are
// tried in this order:
//   1. integer scalar pair: arithmetic, bitwise, comparison and shift operators;
//   2. float scalar pair: arithmetic and comparison operators;
//   3. `*` between matrix/matrix, vector/matrix or matrix/vector (a linear-algebra product);
//   4. same-typed vectors/matrices, component-wise (see simplify_componentwise);
//   5. vector/matrix with a scalar of the same component type, either order, with the scalar
//      broadcast to every slot;
//   6. `==` / `!=` between matrices, arrays or structs.
// Returns null when the operation cannot be folded, or when the result does not fit
// `resultType` (fold_expression). Some integer cases also report an error through
// `context.fErrors` and return null:
//   - SKSL_INT-minimum `/` or `%` by -1 ("arithmetic overflow");
//   - a shift amount outside [0, 31] ("shift value out of range").
static std::unique_ptr<Expression> fold_two_constants(const Context& context,
                                                      Position pos,
                                                      const Expression* left,
                                                      Operator op,
                                                      const Expression* right,
                                                      const Type& resultType) {
    SkASSERT(Analysis::IsCompileTimeConstant(*left));
    SkASSERT(Analysis::IsCompileTimeConstant(*right));
    const Type& leftType = left->type();
    const Type& rightType = right->type();

    // Handle pairs of integer literals.
    if (left->isIntLiteral() && right->isIntLiteral()) {
        using SKSL_UINT = uint64_t;
        SKSL_INT leftVal = left->as<Literal>().intValue();
        SKSL_INT rightVal = right->as<Literal>().intValue();

        // Note that fold_expression returns null if the result would overflow its type.
        // URESULT performs the operation in unsigned 64-bit arithmetic, where wraparound is
        // well-defined in C++ (signed overflow would be undefined behavior), then reinterprets
        // the result as signed.
        #define RESULT(Op)   fold_expression(pos, (SKSL_INT)(leftVal) Op \
                                                  (SKSL_INT)(rightVal), &resultType)
        #define URESULT(Op)  fold_expression(pos, (SKSL_INT)((SKSL_UINT)(leftVal) Op \
                                                  (SKSL_UINT)(rightVal)), &resultType)
        switch (op.kind()) {
            case Operator::Kind::PLUS:       return URESULT(+);
            case Operator::Kind::MINUS:      return URESULT(-);
            case Operator::Kind::STAR:       return URESULT(*);
            // Division by zero is rejected by Simplify (error_on_divide_by_zero) before this
            // function is called. The min / -1 case overflows signed division and is rejected
            // here.
            case Operator::Kind::SLASH:
                if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
                    context.fErrors->error(pos, "arithmetic overflow");
                    return nullptr;
                }
                return RESULT(/);

            case Operator::Kind::PERCENT:
                if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
                    context.fErrors->error(pos, "arithmetic overflow");
                    return nullptr;
                }
                return RESULT(%);

            case Operator::Kind::BITWISEAND: return RESULT(&);
            case Operator::Kind::BITWISEOR:  return RESULT(|);
            case Operator::Kind::BITWISEXOR: return RESULT(^);
            // Comparisons produce 0 or 1, which fold_expression wraps in a `resultType` literal.
            case Operator::Kind::EQEQ:       return RESULT(==);
            case Operator::Kind::NEQ:        return RESULT(!=);
            case Operator::Kind::GT:         return RESULT(>);
            case Operator::Kind::GTEQ:       return RESULT(>=);
            case Operator::Kind::LT:         return RESULT(<);
            case Operator::Kind::LTEQ:       return RESULT(<=);
            case Operator::Kind::SHL:
                if (rightVal >= 0 && rightVal <= 31) {
                    // Left-shifting a negative (or really, any signed) value is undefined behavior
                    // in C++, but not in GLSL. Do the shift on unsigned values to avoid triggering
                    // an UBSAN error.
                    return URESULT(<<);
                }
                context.fErrors->error(pos, "shift value out of range");
                return nullptr;

            case Operator::Kind::SHR:
                if (rightVal >= 0 && rightVal <= 31) {
                    return RESULT(>>);
                }
                context.fErrors->error(pos, "shift value out of range");
                return nullptr;

            default:
                break;
        }
        #undef RESULT
        #undef URESULT

        return nullptr;
    }

    // Handle pairs of floating-point literals.
    if (left->isFloatLiteral() && right->isFloatLiteral()) {
        SKSL_FLOAT leftVal = left->as<Literal>().floatValue();
        SKSL_FLOAT rightVal = right->as<Literal>().floatValue();

        #define RESULT(Op) fold_expression(pos, leftVal Op rightVal, &resultType)
        switch (op.kind()) {
            case Operator::Kind::PLUS:  return RESULT(+);
            case Operator::Kind::MINUS: return RESULT(-);
            case Operator::Kind::STAR:  return RESULT(*);
            case Operator::Kind::SLASH: return RESULT(/);
            case Operator::Kind::EQEQ:  return RESULT(==);
            case Operator::Kind::NEQ:   return RESULT(!=);
            case Operator::Kind::GT:    return RESULT(>);
            case Operator::Kind::GTEQ:  return RESULT(>=);
            case Operator::Kind::LT:    return RESULT(<);
            case Operator::Kind::LTEQ:  return RESULT(<=);
            default:                    break;
        }
        #undef RESULT

        return nullptr;
    }

    // Perform matrix multiplication. This must come before the component-wise case below,
    // because `*` on matrices is a linear-algebra product, not element-wise.
    if (op.kind() == Operator::Kind::STAR) {
        if (leftType.isMatrix() && rightType.isMatrix()) {
            return simplify_matrix_times_matrix(context, pos, *left, *right);
        }
        if (leftType.isVector() && rightType.isMatrix()) {
            return simplify_vector_times_matrix(context, pos, *left, *right);
        }
        if (leftType.isMatrix() && rightType.isVector()) {
            return simplify_matrix_times_vector(context, pos, *left, *right);
        }
    }

    // Perform constant folding on pairs of vectors/matrices.
    if (is_vec_or_mat(leftType) && leftType.matches(rightType)) {
        return simplify_componentwise(context, pos, *left, op, *right);
    }

    // Perform constant folding on vectors/matrices against scalars, e.g.: half4(2) + 2
    if (rightType.isScalar() && is_vec_or_mat(leftType) &&
        leftType.componentType().matches(rightType)) {
        return simplify_componentwise(context, pos,
                                      *left, op, *splat_scalar(context, *right, left->type()));
    }

    // Perform constant folding on scalars against vectors/matrices, e.g.: 2 + half4(2)
    if (leftType.isScalar() && is_vec_or_mat(rightType) &&
        rightType.componentType().matches(leftType)) {
        return simplify_componentwise(context, pos,
                                      *splat_scalar(context, *left, right->type()), op, *right);
    }

    // Perform constant folding on pairs of matrices, arrays or structs.
    // Only equality (`==` / `!=`) can be folded here.
    if ((leftType.isMatrix() && rightType.isMatrix()) ||
        (leftType.isArray() && rightType.isArray()) ||
        (leftType.isStruct() && rightType.isStruct())) {
        return simplify_constant_equality(context, pos, *left, op, *right);
    }

    // We aren't able to constant-fold these expressions.
    return nullptr;
}
814
// Attempts to simplify the binary expression `leftExpr op rightExpr`, whose type is `resultType`.
// Returns a replacement expression, or null if no simplification applies; the caller then keeps
// the original expression.
// May report errors through `context.fErrors` (division by zero; from fold_two_constants,
// arithmetic overflow and out-of-range shifts). Null is returned in those cases.
//
// The rules below run in a fixed order, and earlier rules take precedence over later ones.
std::unique_ptr<Expression> ConstantFolder::Simplify(const Context& context,
                                                     Position pos,
                                                     const Expression& leftExpr,
                                                     Operator op,
                                                     const Expression& rightExpr,
                                                     const Type& resultType) {
    // Replace constant variables with their literal values.
    const Expression* left = GetConstantValueForVariable(leftExpr);
    const Expression* right = GetConstantValueForVariable(rightExpr);

    // If this is the assignment operator, and both sides are the same trivial expression, this is
    // self-assignment (i.e., `var = var`) and can be reduced to just a variable reference (`var`).
    // This can happen when other parts of the assignment are optimized away.
    if (op.kind() == Operator::Kind::EQ && Analysis::IsSameExpressionTree(*left, *right)) {
        return right->clone(pos);
    }

    // Simplify the expression when both sides are constant Boolean literals.
    if (left->isBoolLiteral() && right->isBoolLiteral()) {
        bool leftVal = left->as<Literal>().boolValue();
        bool rightVal = right->as<Literal>().boolValue();
        bool result;
        switch (op.kind()) {
            case Operator::Kind::LOGICALAND: result = leftVal && rightVal; break;
            case Operator::Kind::LOGICALOR:  result = leftVal || rightVal; break;
            case Operator::Kind::LOGICALXOR: result = leftVal ^ rightVal; break;
            case Operator::Kind::EQEQ:       result = leftVal == rightVal; break;
            case Operator::Kind::NEQ:        result = leftVal != rightVal; break;
            default: return nullptr;
        }
        return Literal::MakeBool(context, pos, result);
    }

    // If the left side is a Boolean literal, apply short-circuit optimizations.
    if (left->isBoolLiteral()) {
        return short_circuit_boolean(pos, *left, op, *right);
    }

    // If the right side is a Boolean literal...
    if (right->isBoolLiteral()) {
        // ... and the left side has no side effects...
        if (!Analysis::HasSideEffects(*left)) {
            // We can reverse the expressions and short-circuit optimizations are still valid.
            return short_circuit_boolean(pos, *right, op, *left);
        }

        // We can't use short-circuiting, but we can still optimize away no-op Boolean expressions.
        return eliminate_no_op_boolean(pos, *left, op, *right);
    }

    if (op.kind() == Operator::Kind::EQEQ && Analysis::IsSameExpressionTree(*left, *right)) {
        // With == comparison, if both sides are the same trivial expression, this is self-
        // comparison and is always true. (We are not concerned with NaN.)
        return Literal::MakeBool(context, pos, /*value=*/true);
    }

    if (op.kind() == Operator::Kind::NEQ && Analysis::IsSameExpressionTree(*left, *right)) {
        // With != comparison, if both sides are the same trivial expression, this is self-
        // comparison and is always false. (We are not concerned with NaN.)
        return Literal::MakeBool(context, pos, /*value=*/false);
    }

    // Reject a known-zero divisor for / and % before any folding is attempted. The integer
    // folding in fold_two_constants relies on this check having run first.
    if (error_on_divide_by_zero(context, pos, op, *right)) {
        return nullptr;
    }

    // Perform full constant folding when both sides are compile-time constants.
    bool leftSideIsConstant = Analysis::IsCompileTimeConstant(*left);
    bool rightSideIsConstant = Analysis::IsCompileTimeConstant(*right);
    if (leftSideIsConstant && rightSideIsConstant) {
        return fold_two_constants(context, pos, left, op, right, resultType);
    }

    // The remaining rewrites are optional optimizations, applied only when the optimizer is on.
    if (context.fConfig->fSettings.fOptimize) {
        // If just one side is constant, we might still be able to simplify arithmetic expressions
        // like `x * 1`, `x *= 1`, `x + 0`, `x * 0`, `x / 2`, etc.
        if (leftSideIsConstant || rightSideIsConstant) {
            if (std::unique_ptr<Expression> expr = simplify_arithmetic(context, pos, *left, op,
                                                                       *right, resultType)) {
                return expr;
            }
        }

        // We can simplify some forms of matrix division even when neither side is constant.
        if (std::unique_ptr<Expression> expr = simplify_matrix_division(context, pos, *left, op,
                                                                        *right, resultType)) {
            return expr;
        }
    }

    // We aren't able to constant-fold.
    return nullptr;
}
908
909 } // namespace SkSL
910