1 /*
2 * Copyright 2020 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "include/sksl/DSLType.h"
9
10 #include "include/core/SkTypes.h"
11 #include "include/private/SkSLDefines.h"
12 #include "include/private/SkSLProgramElement.h"
13 #include "include/private/SkSLString.h"
14 #include "include/private/SkSLSymbol.h"
15 #include "include/sksl/SkSLErrorReporter.h"
16 #include "src/sksl/SkSLBuiltinTypes.h"
17 #include "src/sksl/SkSLContext.h"
18 #include "src/sksl/SkSLProgramSettings.h"
19 #include "src/sksl/SkSLThreadContext.h"
20 #include "src/sksl/ir/SkSLConstructor.h"
21 #include "src/sksl/ir/SkSLStructDefinition.h"
22 #include "src/sksl/ir/SkSLSymbolTable.h"
23 #include "src/sksl/ir/SkSLType.h"
24
25 #include <memory>
26 #include <string>
27 #include <vector>
28
29 namespace SkSL {
30
31 struct Modifiers;
32
33 namespace dsl {
34
/**
 * Checks that `type` may be exposed through the DSL.
 *
 * Builtin code, and a null `type`, bypass every check and `type` is returned unchanged.
 * Otherwise an error is reported at `pos` and the poison type is returned when either:
 *  - generic types are not allowed and `type` is generic or a literal type, or
 *  - `type->isAllowedInES2(context)` returns false.
 */
static const SkSL::Type* verify_type(const Context& context,
                                     const SkSL::Type* type,
                                     bool allowGenericTypes,
                                     Position pos) {
    if (context.fConfig->fIsBuiltinCode || type == nullptr) {
        return type;
    }

    const bool rejectedAsGeneric =
            !allowGenericTypes && (type->isGeneric() || type->isLiteral());
    if (rejectedAsGeneric) {
        context.fErrors->error(pos, "type '" + std::string(type->name()) + "' is generic");
        return context.fTypes.fPoison.get();
    }

    if (type->isAllowedInES2(context)) {
        return type;
    }
    context.fErrors->error(pos, "type '" + std::string(type->name()) +"' is not supported");
    return context.fTypes.fPoison.get();
}
51
find_type(const Context & context,Position pos,std::string_view name)52 static const SkSL::Type* find_type(const Context& context,
53 Position pos,
54 std::string_view name) {
55 const Symbol* symbol = ThreadContext::SymbolTable()->find(name);
56 if (!symbol) {
57 context.fErrors->error(pos, String::printf("no symbol named '%.*s'",
58 (int)name.length(), name.data()));
59 return context.fTypes.fPoison.get();
60 }
61 if (!symbol->is<SkSL::Type>()) {
62 context.fErrors->error(pos, String::printf("symbol '%.*s' is not a type",
63 (int)name.length(), name.data()));
64 return context.fTypes.fPoison.get();
65 }
66 const SkSL::Type* type = &symbol->as<SkSL::Type>();
67 return verify_type(context, type, /*allowGenericTypes=*/false, pos);
68 }
69
find_type(const Context & context,Position overallPos,std::string_view name,Position modifiersPos,Modifiers * modifiers)70 static const SkSL::Type* find_type(const Context& context,
71 Position overallPos,
72 std::string_view name,
73 Position modifiersPos,
74 Modifiers* modifiers) {
75 const Type* type = find_type(context, overallPos, name);
76 return type->applyQualifiers(context, modifiers, ThreadContext::SymbolTable().get(),
77 modifiersPos);
78 }
79
get_type_from_type_constant(TypeConstant tc)80 static const SkSL::Type* get_type_from_type_constant(TypeConstant tc) {
81 const Context& context = ThreadContext::Context();
82 switch (tc) {
83 case kBool_Type:
84 return context.fTypes.fBool.get();
85 case kBool2_Type:
86 return context.fTypes.fBool2.get();
87 case kBool3_Type:
88 return context.fTypes.fBool3.get();
89 case kBool4_Type:
90 return context.fTypes.fBool4.get();
91 case kHalf_Type:
92 return context.fTypes.fHalf.get();
93 case kHalf2_Type:
94 return context.fTypes.fHalf2.get();
95 case kHalf3_Type:
96 return context.fTypes.fHalf3.get();
97 case kHalf4_Type:
98 return context.fTypes.fHalf4.get();
99 case kHalf2x2_Type:
100 return context.fTypes.fHalf2x2.get();
101 case kHalf3x2_Type:
102 return context.fTypes.fHalf3x2.get();
103 case kHalf4x2_Type:
104 return context.fTypes.fHalf4x2.get();
105 case kHalf2x3_Type:
106 return context.fTypes.fHalf2x3.get();
107 case kHalf3x3_Type:
108 return context.fTypes.fHalf3x3.get();
109 case kHalf4x3_Type:
110 return context.fTypes.fHalf4x3.get();
111 case kHalf2x4_Type:
112 return context.fTypes.fHalf2x4.get();
113 case kHalf3x4_Type:
114 return context.fTypes.fHalf3x4.get();
115 case kHalf4x4_Type:
116 return context.fTypes.fHalf4x4.get();
117 case kFloat_Type:
118 return context.fTypes.fFloat.get();
119 case kFloat2_Type:
120 return context.fTypes.fFloat2.get();
121 case kFloat3_Type:
122 return context.fTypes.fFloat3.get();
123 case kFloat4_Type:
124 return context.fTypes.fFloat4.get();
125 case kFloat2x2_Type:
126 return context.fTypes.fFloat2x2.get();
127 case kFloat3x2_Type:
128 return context.fTypes.fFloat3x2.get();
129 case kFloat4x2_Type:
130 return context.fTypes.fFloat4x2.get();
131 case kFloat2x3_Type:
132 return context.fTypes.fFloat2x3.get();
133 case kFloat3x3_Type:
134 return context.fTypes.fFloat3x3.get();
135 case kFloat4x3_Type:
136 return context.fTypes.fFloat4x3.get();
137 case kFloat2x4_Type:
138 return context.fTypes.fFloat2x4.get();
139 case kFloat3x4_Type:
140 return context.fTypes.fFloat3x4.get();
141 case kFloat4x4_Type:
142 return context.fTypes.fFloat4x4.get();
143 case kInt_Type:
144 return context.fTypes.fInt.get();
145 case kInt2_Type:
146 return context.fTypes.fInt2.get();
147 case kInt3_Type:
148 return context.fTypes.fInt3.get();
149 case kInt4_Type:
150 return context.fTypes.fInt4.get();
151 case kShader_Type:
152 return context.fTypes.fShader.get();
153 case kShort_Type:
154 return context.fTypes.fShort.get();
155 case kShort2_Type:
156 return context.fTypes.fShort2.get();
157 case kShort3_Type:
158 return context.fTypes.fShort3.get();
159 case kShort4_Type:
160 return context.fTypes.fShort4.get();
161 case kUInt_Type:
162 return context.fTypes.fUInt.get();
163 case kUInt2_Type:
164 return context.fTypes.fUInt2.get();
165 case kUInt3_Type:
166 return context.fTypes.fUInt3.get();
167 case kUInt4_Type:
168 return context.fTypes.fUInt4.get();
169 case kUShort_Type:
170 return context.fTypes.fUShort.get();
171 case kUShort2_Type:
172 return context.fTypes.fUShort2.get();
173 case kUShort3_Type:
174 return context.fTypes.fUShort3.get();
175 case kUShort4_Type:
176 return context.fTypes.fUShort4.get();
177 case kVoid_Type:
178 return context.fTypes.fVoid.get();
179 case kPoison_Type:
180 return context.fTypes.fPoison.get();
181 default:
182 SkUNREACHABLE;
183 }
184 }
185
DSLType(TypeConstant tc,Position pos)186 DSLType::DSLType(TypeConstant tc, Position pos)
187 : fSkSLType(verify_type(ThreadContext::Context(),
188 get_type_from_type_constant(tc),
189 /*allowGenericTypes=*/false,
190 pos)) {}
191
DSLType(std::string_view name,Position pos)192 DSLType::DSLType(std::string_view name, Position pos)
193 : fSkSLType(find_type(ThreadContext::Context(), pos, name)) {}
194
DSLType(std::string_view name,DSLModifiers * modifiers,Position pos)195 DSLType::DSLType(std::string_view name, DSLModifiers* modifiers, Position pos)
196 : fSkSLType(find_type(ThreadContext::Context(),
197 pos,
198 name,
199 modifiers->fPosition,
200 &modifiers->fModifiers)) {}
201
DSLType(const SkSL::Type * type,Position pos)202 DSLType::DSLType(const SkSL::Type* type, Position pos)
203 : fSkSLType(verify_type(ThreadContext::Context(), type, /*allowGenericTypes=*/true, pos)) {}
204
Invalid()205 DSLType DSLType::Invalid() {
206 return DSLType(ThreadContext::Context().fTypes.fInvalid.get(), Position());
207 }
208
isBoolean() const209 bool DSLType::isBoolean() const {
210 return this->skslType().isBoolean();
211 }
212
isNumber() const213 bool DSLType::isNumber() const {
214 return this->skslType().isNumber();
215 }
216
isFloat() const217 bool DSLType::isFloat() const {
218 return this->skslType().isFloat();
219 }
220
isSigned() const221 bool DSLType::isSigned() const {
222 return this->skslType().isSigned();
223 }
224
isUnsigned() const225 bool DSLType::isUnsigned() const {
226 return this->skslType().isUnsigned();
227 }
228
isInteger() const229 bool DSLType::isInteger() const {
230 return this->skslType().isInteger();
231 }
232
isScalar() const233 bool DSLType::isScalar() const {
234 return this->skslType().isScalar();
235 }
236
isVector() const237 bool DSLType::isVector() const {
238 return this->skslType().isVector();
239 }
240
isMatrix() const241 bool DSLType::isMatrix() const {
242 return this->skslType().isMatrix();
243 }
244
isArray() const245 bool DSLType::isArray() const {
246 return this->skslType().isArray();
247 }
248
isStruct() const249 bool DSLType::isStruct() const {
250 return this->skslType().isStruct();
251 }
252
isEffectChild() const253 bool DSLType::isEffectChild() const {
254 return this->skslType().isEffectChild();
255 }
256
Construct(DSLType type,SkSpan<DSLExpression> argArray)257 DSLExpression DSLType::Construct(DSLType type, SkSpan<DSLExpression> argArray) {
258 SkSL::ExpressionArray skslArgs;
259 skslArgs.reserve_back(argArray.size());
260
261 for (DSLExpression& arg : argArray) {
262 if (!arg.hasValue()) {
263 return DSLExpression();
264 }
265 skslArgs.push_back(arg.release());
266 }
267 return DSLExpression(SkSL::Constructor::Convert(ThreadContext::Context(), Position(),
268 type.skslType(), std::move(skslArgs)));
269 }
270
/**
 * Returns the array type `base[count]`, registered in the current symbol table.
 *
 * `count` is first validated through Type::convertArraySize(), which receives it as a
 * DSLExpression built at `pos`. A zero result from that call yields the poison type.
 */
DSLType Array(const DSLType& base, int count, Position pos) {
    const SkSL::Type& elementType = base.skslType();
    count = elementType.convertArraySize(ThreadContext::Context(), pos,
                                         DSLExpression(count, pos).release());
    if (count == 0) {
        return DSLType(kPoison_Type);
    }
    const SkSL::Type* arrayType =
            ThreadContext::SymbolTable()->addArrayDimension(&elementType, count);
    return DSLType(arrayType, pos);
}
279
/**
 * Returns the unsized array type `base[]`, registered in the current symbol table.
 *
 * If `base` cannot be used as an array element (checked via Type::checkIfUsableInArray(),
 * which receives `pos`), the poison type is returned.
 *
 * The resulting type is wrapped with `pos`, matching Array(). Previously the raw pointer was
 * implicitly converted to a DSLType with a default Position, so any diagnostic raised while
 * validating the new type pointed nowhere.
 */
DSLType UnsizedArray(const DSLType& base, Position pos) {
    const SkSL::Type& elementType = base.skslType();
    if (!elementType.checkIfUsableInArray(ThreadContext::Context(), pos)) {
        return DSLType(kPoison_Type);
    }
    const SkSL::Type* arrayType = ThreadContext::SymbolTable()->addArrayDimension(
            &elementType, SkSL::Type::kUnsizedArray);
    return DSLType(arrayType, pos);
}
287
StructType(std::string_view name,SkSpan<DSLField> fields,bool interfaceBlock,Position pos)288 DSLType StructType(std::string_view name,
289 SkSpan<DSLField> fields,
290 bool interfaceBlock,
291 Position pos) {
292 std::vector<SkSL::Type::Field> skslFields;
293 skslFields.reserve(fields.size());
294 for (const DSLField& field : fields) {
295 skslFields.emplace_back(field.fPosition, field.fModifiers.fModifiers, field.fName,
296 &field.fType.skslType());
297 }
298 std::unique_ptr<Type> newType = Type::MakeStructType(ThreadContext::Context(), pos, name,
299 std::move(skslFields), interfaceBlock);
300 return DSLType(ThreadContext::SymbolTable()->add(std::move(newType)), pos);
301 }
302
Struct(std::string_view name,SkSpan<DSLField> fields,Position pos)303 DSLType Struct(std::string_view name, SkSpan<DSLField> fields, Position pos) {
304 DSLType result = StructType(name, fields, /*interfaceBlock=*/false, pos);
305 ThreadContext::ProgramElements().push_back(
306 std::make_unique<SkSL::StructDefinition>(pos, result.skslType()));
307 return result;
308 }
309
310 } // namespace dsl
311
312 } // namespace SkSL
313