• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "include/sksl/SkSLErrorReporter.h"
9 #include "src/sksl/SkSLAnalysis.h"
10 #include "src/sksl/SkSLConstantFolder.h"
11 #include "src/sksl/SkSLProgramSettings.h"
12 #include "src/sksl/ir/SkSLBinaryExpression.h"
13 #include "src/sksl/ir/SkSLIndexExpression.h"
14 #include "src/sksl/ir/SkSLLiteral.h"
15 #include "src/sksl/ir/SkSLSetting.h"
16 #include "src/sksl/ir/SkSLSwizzle.h"
17 #include "src/sksl/ir/SkSLTernaryExpression.h"
18 #include "src/sksl/ir/SkSLType.h"
19 #include "src/sksl/ir/SkSLVariableReference.h"
20 
21 namespace SkSL {
22 
is_low_precision_matrix_vector_multiply(const Expression & left,const Operator & op,const Expression & right,const Type & resultType)23 static bool is_low_precision_matrix_vector_multiply(const Expression& left,
24                                                     const Operator& op,
25                                                     const Expression& right,
26                                                     const Type& resultType) {
27     return !resultType.highPrecision() &&
28            op.kind() == Token::Kind::TK_STAR &&
29            left.type().isMatrix() &&
30            right.type().isVector() &&
31            left.type().rows() == right.type().columns() &&
32            Analysis::IsTrivialExpression(left) &&
33            Analysis::IsTrivialExpression(right);
34 }
35 
rewrite_matrix_vector_multiply(const Context & context,const Expression & left,const Operator & op,const Expression & right,const Type & resultType)36 static std::unique_ptr<Expression> rewrite_matrix_vector_multiply(const Context& context,
37                                                                   const Expression& left,
38                                                                   const Operator& op,
39                                                                   const Expression& right,
40                                                                   const Type& resultType) {
41     // Rewrite m33 * v3 as (m[0] * v[0] + m[1] * v[1] + m[2] * v[2])
42     std::unique_ptr<Expression> sum;
43     for (int n = 0; n < left.type().rows(); ++n) {
44         // Get mat[N] with an index expression.
45         std::unique_ptr<Expression> matN = IndexExpression::Make(
46                 context, left.clone(), Literal::MakeInt(context, left.fLine, n));
47         // Get vec[N] with a swizzle expression.
48         std::unique_ptr<Expression> vecN = Swizzle::Make(
49                 context, right.clone(), ComponentArray{(SkSL::SwizzleComponent::Type)n});
50         // Multiply them together.
51         const Type* matNType = &matN->type();
52         std::unique_ptr<Expression> product =
53                 BinaryExpression::Make(context, std::move(matN), op, std::move(vecN), matNType);
54         // Sum all the components together.
55         if (!sum) {
56             sum = std::move(product);
57         } else {
58             sum = BinaryExpression::Make(context,
59                                          std::move(sum),
60                                          Operator(Token::Kind::TK_PLUS),
61                                          std::move(product),
62                                          matNType);
63         }
64     }
65 
66     return sum;
67 }
68 
Convert(const Context & context,std::unique_ptr<Expression> left,Operator op,std::unique_ptr<Expression> right)69 std::unique_ptr<Expression> BinaryExpression::Convert(const Context& context,
70                                                       std::unique_ptr<Expression> left,
71                                                       Operator op,
72                                                       std::unique_ptr<Expression> right) {
73     if (!left || !right) {
74         return nullptr;
75     }
76     const int line = left->fLine;
77 
78     const Type* rawLeftType = (left->isIntLiteral() && right->type().isInteger())
79             ? &right->type()
80             : &left->type();
81     const Type* rawRightType = (right->isIntLiteral() && left->type().isInteger())
82             ? &left->type()
83             : &right->type();
84 
85     bool isAssignment = op.isAssignment();
86     if (isAssignment &&
87         !Analysis::UpdateVariableRefKind(left.get(),
88                                          op.kind() != Token::Kind::TK_EQ
89                                                  ? VariableReference::RefKind::kReadWrite
90                                                  : VariableReference::RefKind::kWrite,
91                                          context.fErrors)) {
92         return nullptr;
93     }
94 
95     const Type* leftType;
96     const Type* rightType;
97     const Type* resultType;
98     if (!op.determineBinaryType(context, *rawLeftType, *rawRightType,
99                                 &leftType, &rightType, &resultType)) {
100         context.fErrors->error(line, "type mismatch: '" + std::string(op.tightOperatorName()) +
101                                      "' cannot operate on '" + left->type().displayName() +
102                                      "', '" + right->type().displayName() + "'");
103         return nullptr;
104     }
105 
106     if (isAssignment && leftType->componentType().isOpaque()) {
107         context.fErrors->error(line, "assignments to opaque type '" + left->type().displayName() +
108                                      "' are not permitted");
109         return nullptr;
110     }
111     if (context.fConfig->strictES2Mode()) {
112         if (!op.isAllowedInStrictES2Mode()) {
113             context.fErrors->error(line, "operator '" + std::string(op.tightOperatorName()) +
114                                          "' is not allowed");
115             return nullptr;
116         }
117         if (leftType->isOrContainsArray()) {
118             // Most operators are already rejected on arrays, but GLSL ES 1.0 is very explicit that
119             // the *only* operator allowed on arrays is subscripting (and the rules against
120             // assignment, comparison, and even sequence apply to structs containing arrays as well)
121             context.fErrors->error(line,
122                                    "operator '" + std::string(op.tightOperatorName()) +
123                                    "' can not operate on arrays (or structs containing arrays)");
124             return nullptr;
125         }
126     }
127 
128     left = leftType->coerceExpression(std::move(left), context);
129     right = rightType->coerceExpression(std::move(right), context);
130     if (!left || !right) {
131         return nullptr;
132     }
133 
134     return BinaryExpression::Make(context, std::move(left), op, std::move(right), resultType);
135 }
136 
Make(const Context & context,std::unique_ptr<Expression> left,Operator op,std::unique_ptr<Expression> right)137 std::unique_ptr<Expression> BinaryExpression::Make(const Context& context,
138                                                    std::unique_ptr<Expression> left,
139                                                    Operator op,
140                                                    std::unique_ptr<Expression> right) {
141     // Determine the result type of the binary expression.
142     const Type* leftType;
143     const Type* rightType;
144     const Type* resultType;
145     SkAssertResult(op.determineBinaryType(context, left->type(), right->type(),
146                                           &leftType, &rightType, &resultType));
147 
148     return BinaryExpression::Make(context, std::move(left), op, std::move(right), resultType);
149 }
150 
Make(const Context & context,std::unique_ptr<Expression> left,Operator op,std::unique_ptr<Expression> right,const Type * resultType)151 std::unique_ptr<Expression> BinaryExpression::Make(const Context& context,
152                                                    std::unique_ptr<Expression> left,
153                                                    Operator op,
154                                                    std::unique_ptr<Expression> right,
155                                                    const Type* resultType) {
156     // We should have detected non-ES2 compliant behavior in Convert.
157     SkASSERT(!context.fConfig->strictES2Mode() || op.isAllowedInStrictES2Mode());
158     SkASSERT(!context.fConfig->strictES2Mode() || !left->type().isOrContainsArray());
159 
160     // We should have detected non-assignable assignment expressions in Convert.
161     SkASSERT(!op.isAssignment() || Analysis::IsAssignable(*left));
162     SkASSERT(!op.isAssignment() || !left->type().componentType().isOpaque());
163 
164     // For simple assignments, detect and report out-of-range literal values.
165     if (op.kind() == Token::Kind::TK_EQ) {
166         left->type().checkForOutOfRangeLiteral(context, *right);
167     }
168 
169     // Perform constant-folding on the expression.
170     const int line = left->fLine;
171     if (std::unique_ptr<Expression> result = ConstantFolder::Simplify(context, line, *left,
172                                                                       op, *right, *resultType)) {
173         return result;
174     }
175 
176     if (context.fConfig->fSettings.fOptimize) {
177         // When sk_Caps.rewriteMatrixVectorMultiply is set, we rewrite medium-precision
178         // matrix * vector multiplication as:
179         //   (sk_Caps.rewriteMatrixVectorMultiply ? (mat[0]*vec[0] + ... + mat[N]*vec[N])
180         //                                        : mat * vec)
181         if (is_low_precision_matrix_vector_multiply(*left, op, *right, *resultType)) {
182             // Look up `sk_Caps.rewriteMatrixVectorMultiply`.
183             auto caps = Setting::Convert(context, line, "rewriteMatrixVectorMultiply");
184 
185             bool capsBitIsTrue = caps->isBoolLiteral() && caps->as<Literal>().boolValue();
186             if (capsBitIsTrue || !caps->isBoolLiteral()) {
187                 // Rewrite the multiplication as a sum of vector-scalar products.
188                 std::unique_ptr<Expression> rewrite =
189                         rewrite_matrix_vector_multiply(context, *left, op, *right, *resultType);
190 
191                 // If we know the caps bit is true, return the rewritten expression directly.
192                 if (capsBitIsTrue) {
193                     return rewrite;
194                 }
195 
196                 // Return a ternary expression:
197                 //     sk_Caps.rewriteMatrixVectorMultiply ? (rewrite) : (mat * vec)
198                 return TernaryExpression::Make(
199                         context,
200                         std::move(caps),
201                         std::move(rewrite),
202                         std::make_unique<BinaryExpression>(line, std::move(left), op,
203                                                            std::move(right), resultType));
204             }
205         }
206     }
207 
208     return std::make_unique<BinaryExpression>(line, std::move(left), op,
209                                               std::move(right), resultType);
210 }
211 
CheckRef(const Expression & expr)212 bool BinaryExpression::CheckRef(const Expression& expr) {
213     switch (expr.kind()) {
214         case Expression::Kind::kFieldAccess:
215             return CheckRef(*expr.as<FieldAccess>().base());
216 
217         case Expression::Kind::kIndex:
218             return CheckRef(*expr.as<IndexExpression>().base());
219 
220         case Expression::Kind::kSwizzle:
221             return CheckRef(*expr.as<Swizzle>().base());
222 
223         case Expression::Kind::kTernary: {
224             const TernaryExpression& t = expr.as<TernaryExpression>();
225             return CheckRef(*t.ifTrue()) && CheckRef(*t.ifFalse());
226         }
227         case Expression::Kind::kVariableReference: {
228             const VariableReference& ref = expr.as<VariableReference>();
229             return ref.refKind() == VariableRefKind::kWrite ||
230                    ref.refKind() == VariableRefKind::kReadWrite;
231         }
232         default:
233             return false;
234     }
235 }
236 
clone() const237 std::unique_ptr<Expression> BinaryExpression::clone() const {
238     return std::make_unique<BinaryExpression>(fLine,
239                                               this->left()->clone(),
240                                               this->getOperator(),
241                                               this->right()->clone(),
242                                               &this->type());
243 }
244 
description() const245 std::string BinaryExpression::description() const {
246     return "(" + this->left()->description() +
247                  this->getOperator().operatorName() +
248                  this->right()->description() + ")";
249 }
250 
251 }  // namespace SkSL
252