1 /*
2 * Copyright 2020 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/SkSLConstantFolder.h"
9
10 #include "include/core/SkTypes.h"
11 #include "include/private/SkSLModifiers.h"
12 #include "include/private/base/SkFloatingPoint.h"
13 #include "include/private/base/SkTArray.h"
14 #include "include/sksl/SkSLErrorReporter.h"
15 #include "include/sksl/SkSLPosition.h"
16 #include "src/sksl/SkSLAnalysis.h"
17 #include "src/sksl/SkSLContext.h"
18 #include "src/sksl/SkSLProgramSettings.h"
19 #include "src/sksl/ir/SkSLBinaryExpression.h"
20 #include "src/sksl/ir/SkSLConstructorCompound.h"
21 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
22 #include "src/sksl/ir/SkSLConstructorSplat.h"
23 #include "src/sksl/ir/SkSLExpression.h"
24 #include "src/sksl/ir/SkSLLiteral.h"
25 #include "src/sksl/ir/SkSLPrefixExpression.h"
26 #include "src/sksl/ir/SkSLType.h"
27 #include "src/sksl/ir/SkSLVariable.h"
28 #include "src/sksl/ir/SkSLVariableReference.h"
29
30 #include <cstdint>
31 #include <float.h>
32 #include <limits>
33 #include <optional>
34 #include <string>
35 #include <utility>
36
37 namespace SkSL {
38
is_vec_or_mat(const Type & type)39 static bool is_vec_or_mat(const Type& type) {
40 switch (type.typeKind()) {
41 case Type::TypeKind::kMatrix:
42 case Type::TypeKind::kVector:
43 return true;
44
45 default:
46 return false;
47 }
48 }
49
eliminate_no_op_boolean(Position pos,const Expression & left,Operator op,const Expression & right)50 static std::unique_ptr<Expression> eliminate_no_op_boolean(Position pos,
51 const Expression& left,
52 Operator op,
53 const Expression& right) {
54 bool rightVal = right.as<Literal>().boolValue();
55
56 // Detect no-op Boolean expressions and optimize them away.
57 if ((op.kind() == Operator::Kind::LOGICALAND && rightVal) || // (expr && true) -> (expr)
58 (op.kind() == Operator::Kind::LOGICALOR && !rightVal) || // (expr || false) -> (expr)
59 (op.kind() == Operator::Kind::LOGICALXOR && !rightVal) || // (expr ^^ false) -> (expr)
60 (op.kind() == Operator::Kind::EQEQ && rightVal) || // (expr == true) -> (expr)
61 (op.kind() == Operator::Kind::NEQ && !rightVal)) { // (expr != false) -> (expr)
62
63 return left.clone(pos);
64 }
65
66 return nullptr;
67 }
68
short_circuit_boolean(Position pos,const Expression & left,Operator op,const Expression & right)69 static std::unique_ptr<Expression> short_circuit_boolean(Position pos,
70 const Expression& left,
71 Operator op,
72 const Expression& right) {
73 bool leftVal = left.as<Literal>().boolValue();
74
75 // When the literal is on the left, we can sometimes eliminate the other expression entirely.
76 if ((op.kind() == Operator::Kind::LOGICALAND && !leftVal) || // (false && expr) -> (false)
77 (op.kind() == Operator::Kind::LOGICALOR && leftVal)) { // (true || expr) -> (true)
78
79 return left.clone(pos);
80 }
81
82 // We can't eliminate the right-side expression via short-circuit, but we might still be able to
83 // simplify away a no-op expression.
84 return eliminate_no_op_boolean(pos, right, op, left);
85 }
86
simplify_constant_equality(const Context & context,Position pos,const Expression & left,Operator op,const Expression & right)87 static std::unique_ptr<Expression> simplify_constant_equality(const Context& context,
88 Position pos,
89 const Expression& left,
90 Operator op,
91 const Expression& right) {
92 if (op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ) {
93 bool equality = (op.kind() == Operator::Kind::EQEQ);
94
95 switch (left.compareConstant(right)) {
96 case Expression::ComparisonResult::kNotEqual:
97 equality = !equality;
98 [[fallthrough]];
99
100 case Expression::ComparisonResult::kEqual:
101 return Literal::MakeBool(context, pos, equality);
102
103 case Expression::ComparisonResult::kUnknown:
104 break;
105 }
106 }
107 return nullptr;
108 }
109
simplify_matrix_multiplication(const Context & context,Position pos,const Expression & left,const Expression & right,int leftColumns,int leftRows,int rightColumns,int rightRows)110 static std::unique_ptr<Expression> simplify_matrix_multiplication(const Context& context,
111 Position pos,
112 const Expression& left,
113 const Expression& right,
114 int leftColumns,
115 int leftRows,
116 int rightColumns,
117 int rightRows) {
118 const Type& componentType = left.type().componentType();
119 SkASSERT(componentType.matches(right.type().componentType()));
120
121 // Fetch the left matrix.
122 double leftVals[4][4];
123 for (int c = 0; c < leftColumns; ++c) {
124 for (int r = 0; r < leftRows; ++r) {
125 leftVals[c][r] = *left.getConstantValue((c * leftRows) + r);
126 }
127 }
128 // Fetch the right matrix.
129 double rightVals[4][4];
130 for (int c = 0; c < rightColumns; ++c) {
131 for (int r = 0; r < rightRows; ++r) {
132 rightVals[c][r] = *right.getConstantValue((c * rightRows) + r);
133 }
134 }
135
136 SkASSERT(leftColumns == rightRows);
137 int outColumns = rightColumns,
138 outRows = leftRows;
139
140 ExpressionArray args;
141 args.reserve_back(outColumns * outRows);
142 for (int c = 0; c < outColumns; ++c) {
143 for (int r = 0; r < outRows; ++r) {
144 // Compute a dot product for this position.
145 double val = 0;
146 for (int dotIdx = 0; dotIdx < leftColumns; ++dotIdx) {
147 val += leftVals[dotIdx][r] * rightVals[c][dotIdx];
148 }
149 args.push_back(Literal::Make(pos, val, &componentType));
150 }
151 }
152
153 if (outColumns == 1) {
154 // Matrix-times-vector conceptually makes a 1-column N-row matrix, but we return vecN.
155 std::swap(outColumns, outRows);
156 }
157
158 const Type& resultType = componentType.toCompound(context, outColumns, outRows);
159 return ConstructorCompound::Make(context, pos, resultType, std::move(args));
160 }
161
simplify_matrix_times_matrix(const Context & context,Position pos,const Expression & left,const Expression & right)162 static std::unique_ptr<Expression> simplify_matrix_times_matrix(const Context& context,
163 Position pos,
164 const Expression& left,
165 const Expression& right) {
166 const Type& leftType = left.type();
167 const Type& rightType = right.type();
168
169 SkASSERT(leftType.isMatrix());
170 SkASSERT(rightType.isMatrix());
171
172 return simplify_matrix_multiplication(context, pos, left, right,
173 leftType.columns(), leftType.rows(),
174 rightType.columns(), rightType.rows());
175 }
176
simplify_vector_times_matrix(const Context & context,Position pos,const Expression & left,const Expression & right)177 static std::unique_ptr<Expression> simplify_vector_times_matrix(const Context& context,
178 Position pos,
179 const Expression& left,
180 const Expression& right) {
181 const Type& leftType = left.type();
182 const Type& rightType = right.type();
183
184 SkASSERT(leftType.isVector());
185 SkASSERT(rightType.isMatrix());
186
187 return simplify_matrix_multiplication(context, pos, left, right,
188 /*leftColumns=*/leftType.columns(), /*leftRows=*/1,
189 rightType.columns(), rightType.rows());
190 }
191
simplify_matrix_times_vector(const Context & context,Position pos,const Expression & left,const Expression & right)192 static std::unique_ptr<Expression> simplify_matrix_times_vector(const Context& context,
193 Position pos,
194 const Expression& left,
195 const Expression& right) {
196 const Type& leftType = left.type();
197 const Type& rightType = right.type();
198
199 SkASSERT(leftType.isMatrix());
200 SkASSERT(rightType.isVector());
201
202 return simplify_matrix_multiplication(context, pos, left, right,
203 leftType.columns(), leftType.rows(),
204 /*rightColumns=*/1, /*rightRows=*/rightType.columns());
205 }
206
simplify_componentwise(const Context & context,Position pos,const Expression & left,Operator op,const Expression & right)207 static std::unique_ptr<Expression> simplify_componentwise(const Context& context,
208 Position pos,
209 const Expression& left,
210 Operator op,
211 const Expression& right) {
212 SkASSERT(is_vec_or_mat(left.type()));
213 SkASSERT(left.type().matches(right.type()));
214 const Type& type = left.type();
215
216 // Handle equality operations: == !=
217 if (std::unique_ptr<Expression> result = simplify_constant_equality(context, pos, left, op,
218 right)) {
219 return result;
220 }
221
222 // Handle floating-point arithmetic: + - * /
223 using FoldFn = double (*)(double, double);
224 FoldFn foldFn;
225 switch (op.kind()) {
226 case Operator::Kind::PLUS: foldFn = +[](double a, double b) { return a + b; }; break;
227 case Operator::Kind::MINUS: foldFn = +[](double a, double b) { return a - b; }; break;
228 case Operator::Kind::STAR: foldFn = +[](double a, double b) { return a * b; }; break;
229 case Operator::Kind::SLASH: foldFn = +[](double a, double b) { return a / b; }; break;
230 default:
231 return nullptr;
232 }
233
234 const Type& componentType = type.componentType();
235 SkASSERT(componentType.isNumber());
236
237 double minimumValue = componentType.minimumValue();
238 double maximumValue = componentType.maximumValue();
239
240 ExpressionArray args;
241 int numSlots = type.slotCount();
242 args.reserve_back(numSlots);
243 for (int i = 0; i < numSlots; i++) {
244 double value = foldFn(*left.getConstantValue(i), *right.getConstantValue(i));
245 if (value < minimumValue || value > maximumValue) {
246 return nullptr;
247 }
248
249 args.push_back(Literal::Make(pos, value, &componentType));
250 }
251 return ConstructorCompound::Make(context, pos, type, std::move(args));
252 }
253
splat_scalar(const Context & context,const Expression & scalar,const Type & type)254 static std::unique_ptr<Expression> splat_scalar(const Context& context,
255 const Expression& scalar,
256 const Type& type) {
257 if (type.isVector()) {
258 return ConstructorSplat::Make(context, scalar.fPosition, type, scalar.clone());
259 }
260 if (type.isMatrix()) {
261 int numSlots = type.slotCount();
262 ExpressionArray splatMatrix;
263 splatMatrix.reserve_back(numSlots);
264 for (int index = 0; index < numSlots; ++index) {
265 splatMatrix.push_back(scalar.clone());
266 }
267 return ConstructorCompound::Make(context, scalar.fPosition, type, std::move(splatMatrix));
268 }
269 SkDEBUGFAILF("unsupported type %s", type.description().c_str());
270 return nullptr;
271 }
272
cast_expression(const Context & context,Position pos,const Expression & expr,const Type & type)273 static std::unique_ptr<Expression> cast_expression(const Context& context,
274 Position pos,
275 const Expression& expr,
276 const Type& type) {
277 SkASSERT(type.componentType().matches(expr.type().componentType()));
278 if (expr.type().isScalar()) {
279 if (type.isMatrix()) {
280 return ConstructorDiagonalMatrix::Make(context, pos, type, expr.clone());
281 }
282 if (type.isVector()) {
283 return ConstructorSplat::Make(context, pos, type, expr.clone());
284 }
285 }
286 if (type.matches(expr.type())) {
287 return expr.clone(pos);
288 }
289 // We can't cast matrices into vectors or vice-versa.
290 return nullptr;
291 }
292
zero_expression(const Context & context,Position pos,const Type & type)293 static std::unique_ptr<Expression> zero_expression(const Context& context,
294 Position pos,
295 const Type& type) {
296 std::unique_ptr<Expression> zero = Literal::Make(pos, 0.0, &type.componentType());
297 if (type.isScalar()) {
298 return zero;
299 }
300 if (type.isVector()) {
301 return ConstructorSplat::Make(context, pos, type, std::move(zero));
302 }
303 if (type.isMatrix()) {
304 return ConstructorDiagonalMatrix::Make(context, pos, type, std::move(zero));
305 }
306 SkDEBUGFAILF("unsupported type %s", type.description().c_str());
307 return nullptr;
308 }
309
negate_expression(const Context & context,Position pos,const Expression & expr,const Type & type)310 static std::unique_ptr<Expression> negate_expression(const Context& context,
311 Position pos,
312 const Expression& expr,
313 const Type& type) {
314 std::unique_ptr<Expression> ctor = cast_expression(context, pos, expr, type);
315 return ctor ? PrefixExpression::Make(context, pos, Operator::Kind::MINUS, std::move(ctor))
316 : nullptr;
317 }
318
GetConstantInt(const Expression & value,SKSL_INT * out)319 bool ConstantFolder::GetConstantInt(const Expression& value, SKSL_INT* out) {
320 const Expression* expr = GetConstantValueForVariable(value);
321 if (!expr->isIntLiteral()) {
322 return false;
323 }
324 *out = expr->as<Literal>().intValue();
325 return true;
326 }
327
GetConstantValue(const Expression & value,double * out)328 bool ConstantFolder::GetConstantValue(const Expression& value, double* out) {
329 const Expression* expr = GetConstantValueForVariable(value);
330 if (!expr->is<Literal>()) {
331 return false;
332 }
333 *out = expr->as<Literal>().value();
334 return true;
335 }
336
contains_constant_zero(const Expression & expr)337 static bool contains_constant_zero(const Expression& expr) {
338 int numSlots = expr.type().slotCount();
339 for (int index = 0; index < numSlots; ++index) {
340 std::optional<double> slotVal = expr.getConstantValue(index);
341 if (slotVal.has_value() && *slotVal == 0.0) {
342 return true;
343 }
344 }
345 return false;
346 }
347
348 // Returns true if the expression contains `value` in every slot.
is_constant_splat(const Expression & expr,double value)349 static bool is_constant_splat(const Expression& expr, double value) {
350 int numSlots = expr.type().slotCount();
351 for (int index = 0; index < numSlots; ++index) {
352 std::optional<double> slotVal = expr.getConstantValue(index);
353 if (!slotVal.has_value() || *slotVal != value) {
354 return false;
355 }
356 }
357 return true;
358 }
359
360 // Returns true if the expression is a square diagonal matrix containing `value`.
is_constant_diagonal(const Expression & expr,double value)361 static bool is_constant_diagonal(const Expression& expr, double value) {
362 SkASSERT(expr.type().isMatrix());
363 int columns = expr.type().columns();
364 int rows = expr.type().rows();
365 if (columns != rows) {
366 return false;
367 }
368 int slotIdx = 0;
369 for (int c = 0; c < columns; ++c) {
370 for (int r = 0; r < rows; ++r) {
371 double expectation = (c == r) ? value : 0;
372 std::optional<double> slotVal = expr.getConstantValue(slotIdx++);
373 if (!slotVal.has_value() || *slotVal != expectation) {
374 return false;
375 }
376 }
377 }
378 return true;
379 }
380
381 // Returns true if the expression is a scalar, vector, or diagonal matrix containing `value`.
is_constant_value(const Expression & expr,double value)382 static bool is_constant_value(const Expression& expr, double value) {
383 return expr.type().isMatrix() ? is_constant_diagonal(expr, value)
384 : is_constant_splat(expr, value);
385 }
386
387 // The expression represents the right-hand side of a division op. If the division can be
388 // strength-reduced into multiplication by a reciprocal, returns that reciprocal as an expression.
389 // Note that this only supports literal values with safe-to-use reciprocals, and returns null if
390 // Expression contains anything else.
make_reciprocal_expression(const Context & context,const Expression & right)391 static std::unique_ptr<Expression> make_reciprocal_expression(const Context& context,
392 const Expression& right) {
393 if (right.type().isMatrix() || !right.type().componentType().isFloat()) {
394 return nullptr;
395 }
396 // Verify that each slot contains a finite, non-zero literal, take its reciprocal.
397 int nslots = right.type().slotCount();
398 SkSTArray<4, double> values;
399 for (int index = 0; index < nslots; ++index) {
400 std::optional<double> value = right.getConstantValue(index);
401 if (!value) {
402 return nullptr;
403 }
404 *value = sk_ieee_double_divide(1.0, *value);
405 if (*value >= -FLT_MAX && *value <= FLT_MAX && *value != 0.0) {
406 // The reciprocal can be represented safely as a finite 32-bit float.
407 values.push_back(*value);
408 } else {
409 // The value is outside the 32-bit float range, or is NaN; do not optimize.
410 return nullptr;
411 }
412 }
413 // Convert our reciprocal values to Literals.
414 ExpressionArray exprs;
415 exprs.reserve_back(nslots);
416 for (double value : values) {
417 exprs.push_back(Literal::Make(right.fPosition, value, &right.type().componentType()));
418 }
419 // Turn the expression array into a compound constructor. (If this is a single-slot expression,
420 // this will return the literal as-is.)
421 return ConstructorCompound::Make(context, right.fPosition, right.type(), std::move(exprs));
422 }
423
error_on_divide_by_zero(const Context & context,Position pos,Operator op,const Expression & right)424 static bool error_on_divide_by_zero(const Context& context, Position pos, Operator op,
425 const Expression& right) {
426 switch (op.kind()) {
427 case Operator::Kind::SLASH:
428 case Operator::Kind::SLASHEQ:
429 case Operator::Kind::PERCENT:
430 case Operator::Kind::PERCENTEQ:
431 if (contains_constant_zero(right)) {
432 context.fErrors->error(pos, "division by zero");
433 return true;
434 }
435 return false;
436 default:
437 return false;
438 }
439 }
440
// Follows a chain of read-only references to `const` variables, starting at `inExpr`, and returns
// the first initial value that is a compile-time constant. Returns null if the chain ends (not a
// variable reference, not a plain read, not const, or no initial value) before one is found.
const Expression* ConstantFolder::GetConstantValueOrNullForVariable(const Expression& inExpr) {
    const Expression* current = &inExpr;
    while (current->is<VariableReference>()) {
        const VariableReference& ref = current->as<VariableReference>();
        if (ref.refKind() != VariableRefKind::kRead) {
            break;
        }
        const Variable& variable = *ref.variable();
        if (!(variable.modifiers().fFlags & Modifiers::kConst_Flag)) {
            break;
        }
        current = variable.initialValue();
        if (!current) {
            // Function parameters can be const but won't have an initial value.
            break;
        }
        if (Analysis::IsCompileTimeConstant(*current)) {
            return current;
        }
        // Otherwise the initial value may itself be a const variable reference; keep following.
    }
    // We didn't find a compile-time constant at the end.
    return nullptr;
}
466
// Like GetConstantValueOrNullForVariable, but falls back to `inExpr` itself instead of null, so
// the result is never null.
const Expression* ConstantFolder::GetConstantValueForVariable(const Expression& inExpr) {
    if (const Expression* constant = GetConstantValueOrNullForVariable(inExpr)) {
        return constant;
    }
    return &inExpr;
}
471
MakeConstantValueForVariable(Position pos,std::unique_ptr<Expression> inExpr)472 std::unique_ptr<Expression> ConstantFolder::MakeConstantValueForVariable(
473 Position pos, std::unique_ptr<Expression> inExpr) {
474 const Expression* expr = GetConstantValueOrNullForVariable(*inExpr);
475 return expr ? expr->clone(pos) : std::move(inExpr);
476 }
477
is_scalar_op_matrix(const Expression & left,const Expression & right)478 static bool is_scalar_op_matrix(const Expression& left, const Expression& right) {
479 return left.type().isScalar() && right.type().isMatrix();
480 }
481
is_matrix_op_scalar(const Expression & left,const Expression & right)482 static bool is_matrix_op_scalar(const Expression& left, const Expression& right) {
483 return is_scalar_op_matrix(right, left);
484 }
485
simplify_arithmetic(const Context & context,Position pos,const Expression & left,Operator op,const Expression & right,const Type & resultType)486 static std::unique_ptr<Expression> simplify_arithmetic(const Context& context,
487 Position pos,
488 const Expression& left,
489 Operator op,
490 const Expression& right,
491 const Type& resultType) {
492 switch (op.kind()) {
493 case Operator::Kind::PLUS:
494 if (!is_scalar_op_matrix(left, right) && is_constant_splat(right, 0.0)) { // x + 0
495 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
496 resultType)) {
497 return expr;
498 }
499 }
500 if (!is_matrix_op_scalar(left, right) && is_constant_splat(left, 0.0)) { // 0 + x
501 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, right,
502 resultType)) {
503 return expr;
504 }
505 }
506 break;
507
508 case Operator::Kind::STAR:
509 if (is_constant_value(right, 1.0)) { // x * 1
510 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
511 resultType)) {
512 return expr;
513 }
514 }
515 if (is_constant_value(left, 1.0)) { // 1 * x
516 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, right,
517 resultType)) {
518 return expr;
519 }
520 }
521 if (is_constant_value(right, 0.0) && !Analysis::HasSideEffects(left)) { // x * 0
522 return zero_expression(context, pos, resultType);
523 }
524 if (is_constant_value(left, 0.0) && !Analysis::HasSideEffects(right)) { // 0 * x
525 return zero_expression(context, pos, resultType);
526 }
527 if (is_constant_value(right, -1.0)) { // x * -1 (to `-x`)
528 if (std::unique_ptr<Expression> expr = negate_expression(context, pos, left,
529 resultType)) {
530 return expr;
531 }
532 }
533 if (is_constant_value(left, -1.0)) { // -1 * x (to `-x`)
534 if (std::unique_ptr<Expression> expr = negate_expression(context, pos, right,
535 resultType)) {
536 return expr;
537 }
538 }
539 break;
540
541 case Operator::Kind::MINUS:
542 if (!is_scalar_op_matrix(left, right) && is_constant_splat(right, 0.0)) { // x - 0
543 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
544 resultType)) {
545 return expr;
546 }
547 }
548 if (!is_matrix_op_scalar(left, right) && is_constant_splat(left, 0.0)) { // 0 - x
549 if (std::unique_ptr<Expression> expr = negate_expression(context, pos, right,
550 resultType)) {
551 return expr;
552 }
553 }
554 break;
555
556 case Operator::Kind::SLASH:
557 if (!is_scalar_op_matrix(left, right) && is_constant_splat(right, 1.0)) { // x / 1
558 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
559 resultType)) {
560 return expr;
561 }
562 }
563 if (!left.type().isMatrix()) { // convert `x / 2` into `x * 0.5`
564 if (std::unique_ptr<Expression> expr = make_reciprocal_expression(context, right)) {
565 return BinaryExpression::Make(context, pos, left.clone(), Operator::Kind::STAR,
566 std::move(expr));
567 }
568 }
569 break;
570
571 case Operator::Kind::PLUSEQ:
572 case Operator::Kind::MINUSEQ:
573 if (is_constant_splat(right, 0.0)) { // x += 0, x -= 0
574 if (std::unique_ptr<Expression> var = cast_expression(context, pos, left,
575 resultType)) {
576 Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
577 return var;
578 }
579 }
580 break;
581
582 case Operator::Kind::STAREQ:
583 if (is_constant_value(right, 1.0)) { // x *= 1
584 if (std::unique_ptr<Expression> var = cast_expression(context, pos, left,
585 resultType)) {
586 Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
587 return var;
588 }
589 }
590 break;
591
592 case Operator::Kind::SLASHEQ:
593 if (is_constant_splat(right, 1.0)) { // x /= 1
594 if (std::unique_ptr<Expression> var = cast_expression(context, pos, left,
595 resultType)) {
596 Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
597 return var;
598 }
599 }
600 if (std::unique_ptr<Expression> expr = make_reciprocal_expression(context, right)) {
601 return BinaryExpression::Make(context, pos, left.clone(), Operator::Kind::STAREQ,
602 std::move(expr));
603 }
604 break;
605
606 default:
607 break;
608 }
609
610 return nullptr;
611 }
612
613 // The expression must be scalar, and represents the right-hand side of a division op. It can
614 // contain anything, not just literal values. This returns the binary expression `1.0 / expr`. The
615 // expression might be further simplified by the constant folding, if possible.
one_over_scalar(const Context & context,const Expression & right)616 static std::unique_ptr<Expression> one_over_scalar(const Context& context,
617 const Expression& right) {
618 SkASSERT(right.type().isScalar());
619 Position pos = right.fPosition;
620 return BinaryExpression::Make(context, pos,
621 Literal::Make(pos, 1.0, &right.type()),
622 Operator::Kind::SLASH,
623 right.clone());
624 }
625
simplify_matrix_division(const Context & context,Position pos,const Expression & left,Operator op,const Expression & right,const Type & resultType)626 static std::unique_ptr<Expression> simplify_matrix_division(const Context& context,
627 Position pos,
628 const Expression& left,
629 Operator op,
630 const Expression& right,
631 const Type& resultType) {
632 // Convert matrix-over-scalar `x /= y` into `x *= (1.0 / y)`. This generates better
633 // code in SPIR-V and Metal, and should be roughly equivalent elsewhere.
634 switch (op.kind()) {
635 case OperatorKind::SLASH:
636 case OperatorKind::SLASHEQ:
637 if (left.type().isMatrix() && right.type().isScalar()) {
638 Operator multiplyOp = op.isAssignment() ? OperatorKind::STAREQ
639 : OperatorKind::STAR;
640 return BinaryExpression::Make(context, pos,
641 left.clone(),
642 multiplyOp,
643 one_over_scalar(context, right));
644 }
645 break;
646
647 default:
648 break;
649 }
650
651 return nullptr;
652 }
653
fold_expression(Position pos,double result,const Type * resultType)654 static std::unique_ptr<Expression> fold_expression(Position pos,
655 double result,
656 const Type* resultType) {
657 if (resultType->isNumber()) {
658 if (result >= resultType->minimumValue() && result <= resultType->maximumValue()) {
659 // This result will fit inside its type.
660 } else {
661 // The value is outside the range or is NaN (all if-checks fail); do not optimize.
662 return nullptr;
663 }
664 }
665
666 return Literal::Make(pos, result, resultType);
667 }
668
Simplify(const Context & context,Position pos,const Expression & leftExpr,Operator op,const Expression & rightExpr,const Type & resultType)669 std::unique_ptr<Expression> ConstantFolder::Simplify(const Context& context,
670 Position pos,
671 const Expression& leftExpr,
672 Operator op,
673 const Expression& rightExpr,
674 const Type& resultType) {
675 // Replace constant variables with their literal values.
676 const Expression* left = GetConstantValueForVariable(leftExpr);
677 const Expression* right = GetConstantValueForVariable(rightExpr);
678
679 // If this is the assignment operator, and both sides are the same trivial expression, this is
680 // self-assignment (i.e., `var = var`) and can be reduced to just a variable reference (`var`).
681 // This can happen when other parts of the assignment are optimized away.
682 if (op.kind() == Operator::Kind::EQ && Analysis::IsSameExpressionTree(*left, *right)) {
683 return right->clone(pos);
684 }
685
686 // Simplify the expression when both sides are constant Boolean literals.
687 if (left->isBoolLiteral() && right->isBoolLiteral()) {
688 bool leftVal = left->as<Literal>().boolValue();
689 bool rightVal = right->as<Literal>().boolValue();
690 bool result;
691 switch (op.kind()) {
692 case Operator::Kind::LOGICALAND: result = leftVal && rightVal; break;
693 case Operator::Kind::LOGICALOR: result = leftVal || rightVal; break;
694 case Operator::Kind::LOGICALXOR: result = leftVal ^ rightVal; break;
695 case Operator::Kind::EQEQ: result = leftVal == rightVal; break;
696 case Operator::Kind::NEQ: result = leftVal != rightVal; break;
697 default: return nullptr;
698 }
699 return Literal::MakeBool(context, pos, result);
700 }
701
702 // If the left side is a Boolean literal, apply short-circuit optimizations.
703 if (left->isBoolLiteral()) {
704 return short_circuit_boolean(pos, *left, op, *right);
705 }
706
707 // If the right side is a Boolean literal...
708 if (right->isBoolLiteral()) {
709 // ... and the left side has no side effects...
710 if (!Analysis::HasSideEffects(*left)) {
711 // We can reverse the expressions and short-circuit optimizations are still valid.
712 return short_circuit_boolean(pos, *right, op, *left);
713 }
714
715 // We can't use short-circuiting, but we can still optimize away no-op Boolean expressions.
716 return eliminate_no_op_boolean(pos, *left, op, *right);
717 }
718
719 if (op.kind() == Operator::Kind::EQEQ && Analysis::IsSameExpressionTree(*left, *right)) {
720 // With == comparison, if both sides are the same trivial expression, this is self-
721 // comparison and is always true. (We are not concerned with NaN.)
722 return Literal::MakeBool(context, pos, /*value=*/true);
723 }
724
725 if (op.kind() == Operator::Kind::NEQ && Analysis::IsSameExpressionTree(*left, *right)) {
726 // With != comparison, if both sides are the same trivial expression, this is self-
727 // comparison and is always false. (We are not concerned with NaN.)
728 return Literal::MakeBool(context, pos, /*value=*/false);
729 }
730
731 if (error_on_divide_by_zero(context, pos, op, *right)) {
732 return nullptr;
733 }
734
735 // Perform full constant folding when both sides are compile-time constants.
736 const Type& leftType = left->type();
737 const Type& rightType = right->type();
738 bool leftSideIsConstant = Analysis::IsCompileTimeConstant(*left);
739 bool rightSideIsConstant = Analysis::IsCompileTimeConstant(*right);
740
741 if (leftSideIsConstant && rightSideIsConstant) {
742 // Handle pairs of integer literals.
743 if (left->isIntLiteral() && right->isIntLiteral()) {
744 using SKSL_UINT = uint64_t;
745 SKSL_INT leftVal = left->as<Literal>().intValue();
746 SKSL_INT rightVal = right->as<Literal>().intValue();
747
748 // Note that fold_expression returns null if the result would overflow its type.
749 #define RESULT(Op) fold_expression(pos, (SKSL_INT)(leftVal) Op \
750 (SKSL_INT)(rightVal), &resultType)
751 #define URESULT(Op) fold_expression(pos, (SKSL_INT)((SKSL_UINT)(leftVal) Op \
752 (SKSL_UINT)(rightVal)), &resultType)
753 switch (op.kind()) {
754 case Operator::Kind::PLUS: return URESULT(+);
755 case Operator::Kind::MINUS: return URESULT(-);
756 case Operator::Kind::STAR: return URESULT(*);
757 case Operator::Kind::SLASH:
758 if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
759 context.fErrors->error(pos, "arithmetic overflow");
760 return nullptr;
761 }
762 return RESULT(/);
763 case Operator::Kind::PERCENT:
764 if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
765 context.fErrors->error(pos, "arithmetic overflow");
766 return nullptr;
767 }
768 return RESULT(%);
769 case Operator::Kind::BITWISEAND: return RESULT(&);
770 case Operator::Kind::BITWISEOR: return RESULT(|);
771 case Operator::Kind::BITWISEXOR: return RESULT(^);
772 case Operator::Kind::EQEQ: return RESULT(==);
773 case Operator::Kind::NEQ: return RESULT(!=);
774 case Operator::Kind::GT: return RESULT(>);
775 case Operator::Kind::GTEQ: return RESULT(>=);
776 case Operator::Kind::LT: return RESULT(<);
777 case Operator::Kind::LTEQ: return RESULT(<=);
778 case Operator::Kind::SHL:
779 if (rightVal >= 0 && rightVal <= 31) {
780 // Left-shifting a negative (or really, any signed) value is undefined
781 // behavior in C++, but not in GLSL. Do the shift on unsigned values to avoid
782 // triggering an UBSAN error.
783 return URESULT(<<);
784 }
785 context.fErrors->error(pos, "shift value out of range");
786 return nullptr;
787 case Operator::Kind::SHR:
788 if (rightVal >= 0 && rightVal <= 31) {
789 return RESULT(>>);
790 }
791 context.fErrors->error(pos, "shift value out of range");
792 return nullptr;
793
794 default:
795 return nullptr;
796 }
797 #undef RESULT
798 #undef URESULT
799 }
800
801 // Handle pairs of floating-point literals.
802 if (left->isFloatLiteral() && right->isFloatLiteral()) {
803 SKSL_FLOAT leftVal = left->as<Literal>().floatValue();
804 SKSL_FLOAT rightVal = right->as<Literal>().floatValue();
805
806 #define RESULT(Op) fold_expression(pos, leftVal Op rightVal, &resultType)
807 switch (op.kind()) {
808 case Operator::Kind::PLUS: return RESULT(+);
809 case Operator::Kind::MINUS: return RESULT(-);
810 case Operator::Kind::STAR: return RESULT(*);
811 case Operator::Kind::SLASH: return RESULT(/);
812 case Operator::Kind::EQEQ: return RESULT(==);
813 case Operator::Kind::NEQ: return RESULT(!=);
814 case Operator::Kind::GT: return RESULT(>);
815 case Operator::Kind::GTEQ: return RESULT(>=);
816 case Operator::Kind::LT: return RESULT(<);
817 case Operator::Kind::LTEQ: return RESULT(<=);
818 default: return nullptr;
819 }
820 #undef RESULT
821 }
822
823 // Perform matrix multiplication.
824 if (op.kind() == Operator::Kind::STAR) {
825 if (leftType.isMatrix() && rightType.isMatrix()) {
826 return simplify_matrix_times_matrix(context, pos, *left, *right);
827 }
828 if (leftType.isVector() && rightType.isMatrix()) {
829 return simplify_vector_times_matrix(context, pos, *left, *right);
830 }
831 if (leftType.isMatrix() && rightType.isVector()) {
832 return simplify_matrix_times_vector(context, pos, *left, *right);
833 }
834 }
835
836 // Perform constant folding on pairs of vectors/matrices.
837 if (is_vec_or_mat(leftType) && leftType.matches(rightType)) {
838 return simplify_componentwise(context, pos, *left, op, *right);
839 }
840
841 // Perform constant folding on vectors/matrices against scalars, e.g.: half4(2) + 2
842 if (rightType.isScalar() && is_vec_or_mat(leftType) &&
843 leftType.componentType().matches(rightType)) {
844 return simplify_componentwise(context, pos,
845 *left, op, *splat_scalar(context, *right, left->type()));
846 }
847
848 // Perform constant folding on scalars against vectors/matrices, e.g.: 2 + half4(2)
849 if (leftType.isScalar() && is_vec_or_mat(rightType) &&
850 rightType.componentType().matches(leftType)) {
851 return simplify_componentwise(context, pos,
852 *splat_scalar(context, *left, right->type()), op, *right);
853 }
854
855 // Perform constant folding on pairs of matrices, arrays or structs.
856 if ((leftType.isMatrix() && rightType.isMatrix()) ||
857 (leftType.isArray() && rightType.isArray()) ||
858 (leftType.isStruct() && rightType.isStruct())) {
859 return simplify_constant_equality(context, pos, *left, op, *right);
860 }
861 }
862
863 if (context.fConfig->fSettings.fOptimize) {
864 // If just one side is constant, we might still be able to simplify arithmetic expressions
865 // like `x * 1`, `x *= 1`, `x + 0`, `x * 0`, `0 / x`, etc.
866 if (leftSideIsConstant || rightSideIsConstant) {
867 if (std::unique_ptr<Expression> expr = simplify_arithmetic(context, pos, *left, op,
868 *right, resultType)) {
869 return expr;
870 }
871 }
872
873 // We can simplify some forms of matrix division even when neither side is constant.
874 if (std::unique_ptr<Expression> expr = simplify_matrix_division(context, pos, *left, op,
875 *right, resultType)) {
876 return expr;
877 }
878 }
879
880 // We aren't able to constant-fold.
881 return nullptr;
882 }
883
884 } // namespace SkSL
885