1 /*
2 * Copyright 2021 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/SkSLConstantFolder.h"
9 #include "src/sksl/SkSLProgramSettings.h"
10 #include "src/sksl/ir/SkSLBinaryExpression.h"
11 #include "src/sksl/ir/SkSLConstructorArray.h"
12 #include "src/sksl/ir/SkSLConstructorCompound.h"
13 #include "src/sksl/ir/SkSLIndexExpression.h"
14 #include "src/sksl/ir/SkSLLiteral.h"
15 #include "src/sksl/ir/SkSLSwizzle.h"
16 #include "src/sksl/ir/SkSLSymbolTable.h"
17 #include "src/sksl/ir/SkSLTypeReference.h"
18
19 namespace SkSL {
20
index_out_of_range(const Context & context,SKSL_INT index,const Expression & base)21 static bool index_out_of_range(const Context& context, SKSL_INT index, const Expression& base) {
22 if (index >= 0 && index < base.type().columns()) {
23 return false;
24 }
25 #ifdef SKSL_EXT
26 if (index >= 0 && Type::kUnsizedArray == base.type().columns()) {
27 return false;
28 }
29 #endif
30
31 context.fErrors->error(base.fLine, "index " + to_string(index) + " out of range for '" +
32 base.type().displayName() + "'");
33 return true;
34 }
35
// Returns the type produced by subscripting a value of `type`.
// For float/half matrices this is the column vector whose width equals the matrix's row count
// (e.g. float3x4 -> float4). For every other indexable type it is the component type.
const Type& IndexExpression::IndexType(const Context& context, const Type& type) {
    const Type& scalar = type.componentType();
    if (!type.isMatrix()) {
        return scalar;
    }

    const bool isFloat = (scalar == *context.fTypes.fFloat);
    const bool isHalf = !isFloat && (scalar == *context.fTypes.fHalf);
    if (isFloat || isHalf) {
        switch (type.rows()) {
            case 2: return isFloat ? *context.fTypes.fFloat2 : *context.fTypes.fHalf2;
            case 3: return isFloat ? *context.fTypes.fFloat3 : *context.fTypes.fHalf3;
            case 4: return isFloat ? *context.fTypes.fFloat4 : *context.fTypes.fHalf4;
            default:
                // Matrices are expected to have 2-4 rows; release builds fall back below.
                SkASSERT(false);
                break;
        }
    }
    return scalar;
}
56
Convert(const Context & context,SymbolTable & symbolTable,std::unique_ptr<Expression> base,std::unique_ptr<Expression> index)57 std::unique_ptr<Expression> IndexExpression::Convert(const Context& context,
58 SymbolTable& symbolTable,
59 std::unique_ptr<Expression> base,
60 std::unique_ptr<Expression> index) {
61 // Convert an array type reference: `int[10]`.
62 if (base->is<TypeReference>()) {
63 const Type& baseType = base->as<TypeReference>().value();
64 SKSL_INT arraySize = baseType.convertArraySize(context, std::move(index));
65 if (!arraySize) {
66 return nullptr;
67 }
68 return TypeReference::Convert(context, base->fLine,
69 symbolTable.addArrayDimension(&baseType, arraySize));
70 }
71 // Convert an index expression with an expression inside of it: `arr[a * 3]`.
72 const Type& baseType = base->type();
73 if (!baseType.isArray() && !baseType.isMatrix() && !baseType.isVector()) {
74 context.fErrors->error(base->fLine,
75 "expected array, but found '" + baseType.displayName() + "'");
76 return nullptr;
77 }
78 if (!index->type().isInteger()) {
79 index = context.fTypes.fInt->coerceExpression(std::move(index), context);
80 if (!index) {
81 return nullptr;
82 }
83 }
84 // Perform compile-time bounds checking on constant-expression indices.
85 const Expression* indexExpr = ConstantFolder::GetConstantValueForVariable(*index);
86 if (indexExpr->isIntLiteral()) {
87 SKSL_INT indexValue = indexExpr->as<Literal>().intValue();
88 if (index_out_of_range(context, indexValue, *base)) {
89 return nullptr;
90 }
91 }
92 return IndexExpression::Make(context, std::move(base), std::move(index));
93 }
94
Make(const Context & context,std::unique_ptr<Expression> base,std::unique_ptr<Expression> index)95 std::unique_ptr<Expression> IndexExpression::Make(const Context& context,
96 std::unique_ptr<Expression> base,
97 std::unique_ptr<Expression> index) {
98 const Type& baseType = base->type();
99 SkASSERT(baseType.isArray() || baseType.isMatrix() || baseType.isVector());
100 SkASSERT(index->type().isInteger());
101
102 const Expression* indexExpr = ConstantFolder::GetConstantValueForVariable(*index);
103 if (indexExpr->isIntLiteral()) {
104 SKSL_INT indexValue = indexExpr->as<Literal>().intValue();
105 if (!index_out_of_range(context, indexValue, *base)) {
106 if (baseType.isVector()) {
107 // Constant array indexes on vectors can be converted to swizzles: `v[2]` --> `v.z`.
108 // Swizzling is harmless and can unlock further simplifications for some base types.
109 return Swizzle::Make(context, std::move(base), ComponentArray{(int8_t)indexValue});
110 }
111
112 if (baseType.isArray() && !base->hasSideEffects()) {
113 // Indexing an constant array constructor with a constant index can just pluck out
114 // the requested value from the array.
115 const Expression* baseExpr = ConstantFolder::GetConstantValueForVariable(*base);
116 if (baseExpr->is<ConstructorArray>()) {
117 const ConstructorArray& arrayCtor = baseExpr->as<ConstructorArray>();
118 const ExpressionArray& arguments = arrayCtor.arguments();
119 SkASSERT(arguments.count() == baseType.columns());
120
121 return arguments[indexValue]->clone();
122 }
123 }
124
125 if (baseType.isMatrix() && !base->hasSideEffects()) {
126 // Matrices can be constructed with vectors that don't line up on column boundaries,
127 // so extracting out the values from the constructor can be tricky. Fortunately, we
128 // can reconstruct an equivalent vector using `getConstantValue`. If we
129 // can't extract the data using `getConstantValue`, it wasn't constant and
130 // we're not obligated to simplify anything.
131 const Expression* baseExpr = ConstantFolder::GetConstantValueForVariable(*base);
132 int vecWidth = baseType.rows();
133 const Type& scalarType = baseType.componentType();
134 const Type& vecType = scalarType.toCompound(context, vecWidth, /*rows=*/1);
135 indexValue *= vecWidth;
136
137 ExpressionArray ctorArgs;
138 ctorArgs.reserve_back(vecWidth);
139 for (int slot = 0; slot < vecWidth; ++slot) {
140 skstd::optional<double> slotVal = baseExpr->getConstantValue(indexValue + slot);
141 if (slotVal.has_value()) {
142 ctorArgs.push_back(Literal::Make(baseExpr->fLine, *slotVal, &scalarType));
143 } else {
144 ctorArgs.reset();
145 break;
146 }
147 }
148
149 if (!ctorArgs.empty()) {
150 int line = ctorArgs.front()->fLine;
151 return ConstructorCompound::Make(context, line, vecType, std::move(ctorArgs));
152 }
153 }
154 }
155 }
156
157 return std::make_unique<IndexExpression>(context, std::move(base), std::move(index));
158 }
159
160 } // namespace SkSL
161