1 /*
2 * Copyright 2020 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "include/sksl/DSLType.h"
9
10 #include "src/sksl/SkSLThreadContext.h"
11 #include "src/sksl/ir/SkSLConstructor.h"
12 #include "src/sksl/ir/SkSLStructDefinition.h"
13
14 namespace SkSL {
15
16 namespace dsl {
17
// Checks whether `type` may be used by the program being compiled.
//
// Builtin code skips all checks. Otherwise an error is reported at `pos` and the poison type is
// returned when:
//   - the type is private and `allowPrivateTypes` is false, or
//   - type->isAllowedInES2(context) is false.
// On success, `type` itself is returned unchanged.
static const SkSL::Type* verify_type(const Context& context,
                                     const SkSL::Type* type,
                                     bool allowPrivateTypes,
                                     PositionInfo pos) {
    if (context.fConfig->fIsBuiltinCode) {
        // Builtin modules are trusted; no restrictions apply.
        return type;
    }

    const bool rejectedAsPrivate = !allowPrivateTypes && type->isPrivate();
    if (rejectedAsPrivate) {
        context.fErrors->error("type '" + String(type->name()) + "' is private", pos);
        return context.fTypes.fPoison.get();
    }

    if (type->isAllowedInES2(context)) {
        return type;
    }
    context.fErrors->error("type '" + String(type->name()) + "' is not supported", pos);
    return context.fTypes.fPoison.get();
}
34
find_type(const Context & context,skstd::string_view name,PositionInfo pos)35 static const SkSL::Type* find_type(const Context& context,
36 skstd::string_view name,
37 PositionInfo pos) {
38 const Symbol* symbol = (*ThreadContext::SymbolTable())[name];
39 if (!symbol) {
40 context.fErrors->error(String::printf("no symbol named '%.*s'",
41 (int)name.length(), name.data()), pos);
42 return context.fTypes.fPoison.get();
43 }
44 if (!symbol->is<SkSL::Type>()) {
45 context.fErrors->error(String::printf("symbol '%.*s' is not a type",
46 (int)name.length(), name.data()), pos);
47 return context.fTypes.fPoison.get();
48 }
49 const SkSL::Type* type = &symbol->as<SkSL::Type>();
50 return verify_type(context, type, /*allowPrivateTypes=*/false, pos);
51 }
52
find_type(const Context & context,skstd::string_view name,Modifiers * modifiers,PositionInfo pos)53 static const SkSL::Type* find_type(const Context& context,
54 skstd::string_view name,
55 Modifiers* modifiers,
56 PositionInfo pos) {
57 const Type* type = find_type(context, name, pos);
58 type = type->applyPrecisionQualifiers(context, modifiers, ThreadContext::SymbolTable().get(),
59 pos.line());
60 ThreadContext::ReportErrors(pos);
61 return type;
62 }
63
get_type_from_type_constant(const Context & context,TypeConstant tc)64 static const SkSL::Type* get_type_from_type_constant(const Context& context, TypeConstant tc) {
65 switch (tc) {
66 case kBool_Type:
67 return context.fTypes.fBool.get();
68 case kBool2_Type:
69 return context.fTypes.fBool2.get();
70 case kBool3_Type:
71 return context.fTypes.fBool3.get();
72 case kBool4_Type:
73 return context.fTypes.fBool4.get();
74 case kHalf_Type:
75 return context.fTypes.fHalf.get();
76 case kHalf2_Type:
77 return context.fTypes.fHalf2.get();
78 case kHalf3_Type:
79 return context.fTypes.fHalf3.get();
80 case kHalf4_Type:
81 return context.fTypes.fHalf4.get();
82 case kHalf2x2_Type:
83 return context.fTypes.fHalf2x2.get();
84 case kHalf3x2_Type:
85 return context.fTypes.fHalf3x2.get();
86 case kHalf4x2_Type:
87 return context.fTypes.fHalf4x2.get();
88 case kHalf2x3_Type:
89 return context.fTypes.fHalf2x3.get();
90 case kHalf3x3_Type:
91 return context.fTypes.fHalf3x3.get();
92 case kHalf4x3_Type:
93 return context.fTypes.fHalf4x3.get();
94 case kHalf2x4_Type:
95 return context.fTypes.fHalf2x4.get();
96 case kHalf3x4_Type:
97 return context.fTypes.fHalf3x4.get();
98 case kHalf4x4_Type:
99 return context.fTypes.fHalf4x4.get();
100 case kFloat_Type:
101 return context.fTypes.fFloat.get();
102 case kFloat2_Type:
103 return context.fTypes.fFloat2.get();
104 case kFloat3_Type:
105 return context.fTypes.fFloat3.get();
106 case kFloat4_Type:
107 return context.fTypes.fFloat4.get();
108 case kFloat2x2_Type:
109 return context.fTypes.fFloat2x2.get();
110 case kFloat3x2_Type:
111 return context.fTypes.fFloat3x2.get();
112 case kFloat4x2_Type:
113 return context.fTypes.fFloat4x2.get();
114 case kFloat2x3_Type:
115 return context.fTypes.fFloat2x3.get();
116 case kFloat3x3_Type:
117 return context.fTypes.fFloat3x3.get();
118 case kFloat4x3_Type:
119 return context.fTypes.fFloat4x3.get();
120 case kFloat2x4_Type:
121 return context.fTypes.fFloat2x4.get();
122 case kFloat3x4_Type:
123 return context.fTypes.fFloat3x4.get();
124 case kFloat4x4_Type:
125 return context.fTypes.fFloat4x4.get();
126 case kInt_Type:
127 return context.fTypes.fInt.get();
128 case kInt2_Type:
129 return context.fTypes.fInt2.get();
130 case kInt3_Type:
131 return context.fTypes.fInt3.get();
132 case kInt4_Type:
133 return context.fTypes.fInt4.get();
134 case kShader_Type:
135 return context.fTypes.fShader.get();
136 case kShort_Type:
137 return context.fTypes.fShort.get();
138 case kShort2_Type:
139 return context.fTypes.fShort2.get();
140 case kShort3_Type:
141 return context.fTypes.fShort3.get();
142 case kShort4_Type:
143 return context.fTypes.fShort4.get();
144 case kUInt_Type:
145 return context.fTypes.fUInt.get();
146 case kUInt2_Type:
147 return context.fTypes.fUInt2.get();
148 case kUInt3_Type:
149 return context.fTypes.fUInt3.get();
150 case kUInt4_Type:
151 return context.fTypes.fUInt4.get();
152 case kUShort_Type:
153 return context.fTypes.fUShort.get();
154 case kUShort2_Type:
155 return context.fTypes.fUShort2.get();
156 case kUShort3_Type:
157 return context.fTypes.fUShort3.get();
158 case kUShort4_Type:
159 return context.fTypes.fUShort4.get();
160 case kVoid_Type:
161 return context.fTypes.fVoid.get();
162 case kPoison_Type:
163 return context.fTypes.fPoison.get();
164 default:
165 SkUNREACHABLE;
166 }
167 }
168
DSLType(skstd::string_view name)169 DSLType::DSLType(skstd::string_view name)
170 : fSkSLType(find_type(ThreadContext::Context(), name, PositionInfo())) {}
171
DSLType(skstd::string_view name,DSLModifiers * modifiers,PositionInfo position)172 DSLType::DSLType(skstd::string_view name, DSLModifiers* modifiers, PositionInfo position)
173 : fSkSLType(find_type(ThreadContext::Context(), name, &modifiers->fModifiers, position)) {}
174
DSLType(const SkSL::Type * type)175 DSLType::DSLType(const SkSL::Type* type)
176 : fSkSLType(verify_type(ThreadContext::Context(), type, /*allowPrivateTypes=*/true,
177 PositionInfo())) {}
178
isBoolean() const179 bool DSLType::isBoolean() const {
180 return this->skslType().isBoolean();
181 }
182
isNumber() const183 bool DSLType::isNumber() const {
184 return this->skslType().isNumber();
185 }
186
isFloat() const187 bool DSLType::isFloat() const {
188 return this->skslType().isFloat();
189 }
190
isSigned() const191 bool DSLType::isSigned() const {
192 return this->skslType().isSigned();
193 }
194
isUnsigned() const195 bool DSLType::isUnsigned() const {
196 return this->skslType().isUnsigned();
197 }
198
isInteger() const199 bool DSLType::isInteger() const {
200 return this->skslType().isInteger();
201 }
202
isScalar() const203 bool DSLType::isScalar() const {
204 return this->skslType().isScalar();
205 }
206
isVector() const207 bool DSLType::isVector() const {
208 return this->skslType().isVector();
209 }
210
isMatrix() const211 bool DSLType::isMatrix() const {
212 return this->skslType().isMatrix();
213 }
214
isArray() const215 bool DSLType::isArray() const {
216 return this->skslType().isArray();
217 }
218
isStruct() const219 bool DSLType::isStruct() const {
220 return this->skslType().isStruct();
221 }
222
isEffectChild() const223 bool DSLType::isEffectChild() const {
224 return this->skslType().isEffectChild();
225 }
226
skslType() const227 const SkSL::Type& DSLType::skslType() const {
228 if (fSkSLType) {
229 return *fSkSLType;
230 }
231 const Context& context = ThreadContext::Context();
232 return *verify_type(context,
233 get_type_from_type_constant(context, fTypeConstant),
234 /*allowPrivateTypes=*/true,
235 PositionInfo());
236 }
237
Construct(DSLType type,SkSpan<DSLExpression> argArray)238 DSLPossibleExpression DSLType::Construct(DSLType type, SkSpan<DSLExpression> argArray) {
239 SkSL::ExpressionArray skslArgs;
240 skslArgs.reserve_back(argArray.size());
241
242 for (DSLExpression& arg : argArray) {
243 if (!arg.hasValue()) {
244 return DSLPossibleExpression(nullptr);
245 }
246 skslArgs.push_back(arg.release());
247 }
248 return SkSL::Constructor::Convert(ThreadContext::Context(), /*line=*/-1, type.skslType(),
249 std::move(skslArgs));
250 }
251
// Creates an array type of `count` elements of `base`.
//
// The element type is resolved with base.skslType() exactly once and reused. For DSLTypes built
// from a TypeConstant, skslType() re-runs verify_type() on every call, and verify_type() may
// report an error. Resolving it twice therefore reported the same diagnostic twice.
//
// `count` is validated through Type::convertArraySize(). Any errors collected in the thread
// context are flushed at `pos`. If convertArraySize() returns 0, the poison type is returned;
// otherwise the result of the current symbol table's addArrayDimension() is returned.
DSLType Array(const DSLType& base, int count, PositionInfo pos) {
    const SkSL::Type& baseType = base.skslType();
    count = baseType.convertArraySize(ThreadContext::Context(),
                                      DSLExpression(count, pos).release());
    ThreadContext::ReportErrors(pos);
    if (!count) {
        return DSLType(kPoison_Type);
    }
    return ThreadContext::SymbolTable()->addArrayDimension(&baseType, count);
}
261
Struct(skstd::string_view name,SkSpan<DSLField> fields,PositionInfo pos)262 DSLType Struct(skstd::string_view name, SkSpan<DSLField> fields, PositionInfo pos) {
263 std::vector<SkSL::Type::Field> skslFields;
264 skslFields.reserve(fields.size());
265 for (const DSLField& field : fields) {
266 if (field.fModifiers.fModifiers.fFlags != Modifiers::kNo_Flag) {
267 String desc = field.fModifiers.fModifiers.description();
268 desc.pop_back(); // remove trailing space
269 ThreadContext::ReportError("modifier '" + desc + "' is not permitted on a struct field",
270 field.fPosition);
271 }
272
273 const SkSL::Type& type = field.fType.skslType();
274 if (type.isOpaque()) {
275 ThreadContext::ReportError("opaque type '" + type.displayName() +
276 "' is not permitted in a struct", field.fPosition);
277 }
278 skslFields.emplace_back(field.fModifiers.fModifiers, field.fName, &type);
279 }
280 const SkSL::Type* result = ThreadContext::SymbolTable()->add(Type::MakeStructType(pos.line(),
281 name, skslFields));
282 if (result->isTooDeeplyNested()) {
283 ThreadContext::ReportError("struct '" + String(name) + "' is too deeply nested", pos);
284 }
285 ThreadContext::ProgramElements().push_back(std::make_unique<SkSL::StructDefinition>(/*line=*/-1,
286 *result));
287 return result;
288 }
289
290 } // namespace dsl
291
292 } // namespace SkSL
293