1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/ir/SkSLBinaryExpression.h"
9 
10 #include "include/private/SkSLDefines.h"
11 #include "include/sksl/SkSLErrorReporter.h"
12 #include "src/sksl/SkSLAnalysis.h"
13 #include "src/sksl/SkSLConstantFolder.h"
14 #include "src/sksl/SkSLContext.h"
15 #include "src/sksl/SkSLProgramSettings.h"
16 #include "src/sksl/SkSLUtil.h"
17 #include "src/sksl/ir/SkSLFieldAccess.h"
18 #include "src/sksl/ir/SkSLIndexExpression.h"
19 #include "src/sksl/ir/SkSLLiteral.h"
20 #include "src/sksl/ir/SkSLSetting.h"
21 #include "src/sksl/ir/SkSLSwizzle.h"
22 #include "src/sksl/ir/SkSLTernaryExpression.h"
23 #include "src/sksl/ir/SkSLType.h"
24 #include "src/sksl/ir/SkSLVariableReference.h"
25 
26 namespace SkSL {
27 
is_low_precision_matrix_vector_multiply(const Expression & left,const Operator & op,const Expression & right,const Type & resultType)28 static bool is_low_precision_matrix_vector_multiply(const Expression& left,
29                                                     const Operator& op,
30                                                     const Expression& right,
31                                                     const Type& resultType) {
32     return !resultType.highPrecision() &&
33            op.kind() == Operator::Kind::STAR &&
34            left.type().isMatrix() &&
35            right.type().isVector() &&
36            left.type().rows() == right.type().columns() &&
37            Analysis::IsTrivialExpression(left) &&
38            Analysis::IsTrivialExpression(right);
39 }
40 
rewrite_matrix_vector_multiply(const Context & context,Position pos,const Expression & left,const Operator & op,const Expression & right,const Type & resultType)41 static std::unique_ptr<Expression> rewrite_matrix_vector_multiply(const Context& context,
42                                                                   Position pos,
43                                                                   const Expression& left,
44                                                                   const Operator& op,
45                                                                   const Expression& right,
46                                                                   const Type& resultType) {
47     // Rewrite m33 * v3 as (m[0] * v[0] + m[1] * v[1] + m[2] * v[2])
48     std::unique_ptr<Expression> sum;
49     for (int n = 0; n < left.type().rows(); ++n) {
50         // Get mat[N] with an index expression.
51         std::unique_ptr<Expression> matN = IndexExpression::Make(
52                 context, pos, left.clone(), Literal::MakeInt(context, left.fPosition, n));
53         // Get vec[N] with a swizzle expression.
54         std::unique_ptr<Expression> vecN = Swizzle::Make(context,
55                 left.fPosition.rangeThrough(right.fPosition), right.clone(),
56                 ComponentArray{(SkSL::SwizzleComponent::Type)n});
57         // Multiply them together.
58         const Type* matNType = &matN->type();
59         std::unique_ptr<Expression> product =
60                 BinaryExpression::Make(context, pos, std::move(matN), op, std::move(vecN),
61                                        matNType);
62         // Sum all the components together.
63         if (!sum) {
64             sum = std::move(product);
65         } else {
66             sum = BinaryExpression::Make(context,
67                                          pos,
68                                          std::move(sum),
69                                          Operator(Operator::Kind::PLUS),
70                                          std::move(product),
71                                          matNType);
72         }
73     }
74 
75     return sum;
76 }
77 
// Type-checks and builds the binary expression `left op right` from operands that may not have
// been validated yet. Returns null on failure:
//   - A null operand returns null without reporting a new error here.
//   - Type mismatches, assignments to opaque/atomic types, and strict-ES2 violations are reported
//     through context.fErrors.
//   - Analysis::UpdateVariableRefKind (which is given context.fErrors) and
//     Type::coerceExpression can also fail. NOTE(review): they presumably report their own
//     errors — confirm.
// On success, both operands are coerced to the types chosen by Operator::determineBinaryType, and
// the expression is built by BinaryExpression::Make with that result type.
std::unique_ptr<Expression> BinaryExpression::Convert(const Context& context,
                                                      Position pos,
                                                      std::unique_ptr<Expression> left,
                                                      Operator op,
                                                      std::unique_ptr<Expression> right) {
    if (!left || !right) {
        return nullptr;
    }
    // An int literal paired with an integer-typed operand adopts that operand's type. The
    // operation is therefore typed by the non-literal side rather than by the literal's default
    // integer type.
    const Type* rawLeftType = (left->isIntLiteral() && right->type().isInteger())
            ? &right->type()
            : &left->type();
    const Type* rawRightType = (right->isIntLiteral() && left->type().isInteger())
            ? &left->type()
            : &right->type();

    // Mark the left-hand side of an assignment as written. Compound assignments (anything other
    // than plain `=`) also read it, hence kReadWrite. Failure here means the target cannot be
    // assigned to.
    bool isAssignment = op.isAssignment();
    if (isAssignment &&
        !Analysis::UpdateVariableRefKind(left.get(),
                                         op.kind() != Operator::Kind::EQ
                                                 ? VariableReference::RefKind::kReadWrite
                                                 : VariableReference::RefKind::kWrite,
                                         context.fErrors)) {
        return nullptr;
    }

    // Ask the operator which type each operand must be coerced to, and what the result type is.
    const Type* leftType;
    const Type* rightType;
    const Type* resultType;
    if (!op.determineBinaryType(context, *rawLeftType, *rawRightType,
                                &leftType, &rightType, &resultType)) {
        context.fErrors->error(pos, "type mismatch: '" + std::string(op.tightOperatorName()) +
                "' cannot operate on '" + left->type().displayName() + "', '" +
                right->type().displayName() + "'");
        return nullptr;
    }

    // Opaque types, and types containing atomics, may not be assigned to. Both cases share the
    // "opaque type" error message.
    if (isAssignment && (leftType->componentType().isOpaque() || leftType->isOrContainsAtomic())) {
        context.fErrors->error(pos, "assignments to opaque type '" + left->type().displayName() +
                "' are not permitted");
        return nullptr;
    }
    if (context.fConfig->strictES2Mode()) {
        if (!op.isAllowedInStrictES2Mode()) {
            context.fErrors->error(pos, "operator '" + std::string(op.tightOperatorName()) +
                    "' is not allowed");
            return nullptr;
        }
        if (leftType->isOrContainsArray()) {
            // Most operators are already rejected on arrays, but GLSL ES 1.0 is very explicit that
            // the *only* operator allowed on arrays is subscripting (and the rules against
            // assignment, comparison, and even sequence apply to structs containing arrays as well)
            context.fErrors->error(pos, "operator '" + std::string(op.tightOperatorName()) +
                    "' can not operate on arrays (or structs containing arrays)");
            return nullptr;
        }
    }

    // Convert each operand to the type the operator requires; a failed coercion yields null.
    left = leftType->coerceExpression(std::move(left), context);
    right = rightType->coerceExpression(std::move(right), context);
    if (!left || !right) {
        return nullptr;
    }

    return BinaryExpression::Make(context, pos, std::move(left), op, std::move(right), resultType);
}
143 
Make(const Context & context,Position pos,std::unique_ptr<Expression> left,Operator op,std::unique_ptr<Expression> right)144 std::unique_ptr<Expression> BinaryExpression::Make(const Context& context,
145                                                    Position pos,
146                                                    std::unique_ptr<Expression> left,
147                                                    Operator op,
148                                                    std::unique_ptr<Expression> right) {
149     // Determine the result type of the binary expression.
150     const Type* leftType;
151     const Type* rightType;
152     const Type* resultType;
153     SkAssertResult(op.determineBinaryType(context, left->type(), right->type(),
154                                           &leftType, &rightType, &resultType));
155 
156     return BinaryExpression::Make(context, pos, std::move(left), op, std::move(right), resultType);
157 }
158 
// Builds `left op right` with a caller-supplied result type. The operands are assumed to be
// validated and coerced already (see Convert); the SkASSERTs below check this in debug builds.
// Steps, in order:
//   1. For plain `=` assignments, check the right-hand side for out-of-range literal values.
//   2. Return the constant-folded expression if ConstantFolder::Simplify produces one.
//   3. When optimizing non-builtin code, possibly rewrite a low-precision matrix * vector multiply
//      (controlled by sk_Caps.rewriteMatrixVectorMultiply).
//   4. Otherwise, return a new BinaryExpression node.
std::unique_ptr<Expression> BinaryExpression::Make(const Context& context,
                                                   Position pos,
                                                   std::unique_ptr<Expression> left,
                                                   Operator op,
                                                   std::unique_ptr<Expression> right,
                                                   const Type* resultType) {
    // We should have detected non-ES2 compliant behavior in Convert.
    SkASSERT(!context.fConfig->strictES2Mode() || op.isAllowedInStrictES2Mode());
    SkASSERT(!context.fConfig->strictES2Mode() || !left->type().isOrContainsArray());

    // We should have detected non-assignable assignment expressions in Convert.
    SkASSERT(!op.isAssignment() || Analysis::IsAssignable(*left));
    SkASSERT(!op.isAssignment() || !left->type().componentType().isOpaque());

    // For simple assignments, detect and report out-of-range literal values.
    if (op.kind() == Operator::Kind::EQ) {
        left->type().checkForOutOfRangeLiteral(context, *right);
    }

    // Perform constant-folding on the expression.
    if (std::unique_ptr<Expression> result = ConstantFolder::Simplify(context, pos, *left,
                                                                      op, *right, *resultType)) {
        return result;
    }

    if (context.fConfig->fSettings.fOptimize && !context.fConfig->fIsBuiltinCode) {
        // When sk_Caps.rewriteMatrixVectorMultiply is set, we rewrite medium-precision
        // matrix * vector multiplication as:
        //   (sk_Caps.rewriteMatrixVectorMultiply ? (mat[0]*vec[0] + ... + mat[N-1]*vec[N-1])
        //                                        : mat * vec)
        if (is_low_precision_matrix_vector_multiply(*left, op, *right, *resultType)) {
            // Look up `sk_Caps.rewriteMatrixVectorMultiply`.
            auto caps = Setting::Make(context, pos, &ShaderCaps::fRewriteMatrixVectorMultiply);

            // There are three possible outcomes from Setting::Make:
            // - If the ShaderCaps aren't known (fCaps in the Context is null), we will get back a
            //   Setting IRNode. In practice, this should happen when compiling a module.
            //   In this case, we generate a ternary expression which will be optimized away when
            //   the module code is actually incorporated into a program.
            // - If `rewriteMatrixVectorMultiply` is true in our shader caps, we will get back a
            //   Literal set to true. When this happens, we always return the rewritten expression.
            // - If `rewriteMatrixVectorMultiply` is false in our shader caps, we will get back a
            //   Literal set to false. When this happens, we return the expression as-is.
            bool capsBitIsTrue = caps->isBoolLiteral() && caps->as<Literal>().boolValue();
            if (capsBitIsTrue || !caps->isBoolLiteral()) {
                // Rewrite the multiplication as a sum of vector-scalar products.
                std::unique_ptr<Expression> rewrite =
                        rewrite_matrix_vector_multiply(context, pos, *left, op, *right,
                                                       *resultType);

                // If we know the caps bit is true, return the rewritten expression directly.
                if (capsBitIsTrue) {
                    return rewrite;
                }

                // Return a ternary expression:
                //     sk_Caps.rewriteMatrixVectorMultiply ? (rewrite) : (mat * vec)
                // The false branch is constructed directly instead of through Make. Going
                // through Make would re-enter this same rewrite for the untouched `mat * vec`.
                return TernaryExpression::Make(
                        context,
                        pos,
                        std::move(caps),
                        std::move(rewrite),
                        std::make_unique<BinaryExpression>(pos, std::move(left), op,
                                                           std::move(right), resultType));
            }
        }
    }

    return std::make_unique<BinaryExpression>(pos, std::move(left), op,
                                              std::move(right), resultType);
}
230 
CheckRef(const Expression & expr)231 bool BinaryExpression::CheckRef(const Expression& expr) {
232     switch (expr.kind()) {
233         case Expression::Kind::kFieldAccess:
234             return CheckRef(*expr.as<FieldAccess>().base());
235 
236         case Expression::Kind::kIndex:
237             return CheckRef(*expr.as<IndexExpression>().base());
238 
239         case Expression::Kind::kSwizzle:
240             return CheckRef(*expr.as<Swizzle>().base());
241 
242         case Expression::Kind::kTernary: {
243             const TernaryExpression& t = expr.as<TernaryExpression>();
244             return CheckRef(*t.ifTrue()) && CheckRef(*t.ifFalse());
245         }
246         case Expression::Kind::kVariableReference: {
247             const VariableReference& ref = expr.as<VariableReference>();
248             return ref.refKind() == VariableRefKind::kWrite ||
249                    ref.refKind() == VariableRefKind::kReadWrite;
250         }
251         default:
252             return false;
253     }
254 }
255 
clone(Position pos) const256 std::unique_ptr<Expression> BinaryExpression::clone(Position pos) const {
257     return std::make_unique<BinaryExpression>(pos,
258                                               this->left()->clone(),
259                                               this->getOperator(),
260                                               this->right()->clone(),
261                                               &this->type());
262 }
263 
description(OperatorPrecedence parentPrecedence) const264 std::string BinaryExpression::description(OperatorPrecedence parentPrecedence) const {
265     OperatorPrecedence operatorPrecedence = this->getOperator().getBinaryPrecedence();
266     bool needsParens = (operatorPrecedence >= parentPrecedence);
267     return std::string(needsParens ? "(" : "") +
268            this->left()->description(operatorPrecedence) +
269            this->getOperator().operatorName() +
270            this->right()->description(operatorPrecedence) +
271            std::string(needsParens ? ")" : "");
272 }
273 
isAssignmentIntoVariable()274 VariableReference* BinaryExpression::isAssignmentIntoVariable() {
275     if (this->getOperator().isAssignment()) {
276         Analysis::AssignmentInfo assignmentInfo;
277         if (Analysis::IsAssignable(*this->left(), &assignmentInfo, /*errors=*/nullptr)) {
278             return assignmentInfo.fAssignedVar;
279         }
280     }
281     return nullptr;
282 }
283 
284 }  // namespace SkSL
285