• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "include/sksl/DSLType.h"
9 
10 #include "src/sksl/SkSLThreadContext.h"
11 #include "src/sksl/ir/SkSLConstructor.h"
12 #include "src/sksl/ir/SkSLStructDefinition.h"
13 
14 namespace SkSL {
15 
16 namespace dsl {
17 
/**
 * Checks whether `type` may be used by the program currently being compiled.
 *
 * Builtin code may use any type, so in that case `type` is returned unchanged. Otherwise:
 *  - a private type is rejected unless `allowPrivateTypes` is set;
 *  - a type for which isAllowedInES2() is false is rejected.
 * A rejected type reports an error at `pos` and is replaced by the poison type.
 *
 * @return `type` when acceptable, otherwise context.fTypes.fPoison.
 */
static const SkSL::Type* verify_type(const Context& context,
                                     const SkSL::Type* type,
                                     bool allowPrivateTypes,
                                     PositionInfo pos) {
    if (context.fConfig->fIsBuiltinCode) {
        // No restrictions apply while compiling builtin modules.
        return type;
    }
    if (!allowPrivateTypes && type->isPrivate()) {
        context.fErrors->error("type '" + String(type->name()) + "' is private", pos);
        return context.fTypes.fPoison.get();
    }
    if (!type->isAllowedInES2(context)) {
        context.fErrors->error("type '" + String(type->name()) + "' is not supported", pos);
        return context.fTypes.fPoison.get();
    }
    return type;
}
34 
find_type(const Context & context,skstd::string_view name,PositionInfo pos)35 static const SkSL::Type* find_type(const Context& context,
36                                    skstd::string_view name,
37                                    PositionInfo pos) {
38     const Symbol* symbol = (*ThreadContext::SymbolTable())[name];
39     if (!symbol) {
40         context.fErrors->error(String::printf("no symbol named '%.*s'",
41                                               (int)name.length(), name.data()), pos);
42         return context.fTypes.fPoison.get();
43     }
44     if (!symbol->is<SkSL::Type>()) {
45         context.fErrors->error(String::printf("symbol '%.*s' is not a type",
46                                               (int)name.length(), name.data()), pos);
47         return context.fTypes.fPoison.get();
48     }
49     const SkSL::Type* type = &symbol->as<SkSL::Type>();
50     return verify_type(context, type, /*allowPrivateTypes=*/false, pos);
51 }
52 
find_type(const Context & context,skstd::string_view name,Modifiers * modifiers,PositionInfo pos)53 static const SkSL::Type* find_type(const Context& context,
54                                    skstd::string_view name,
55                                    Modifiers* modifiers,
56                                    PositionInfo pos) {
57     const Type* type = find_type(context, name, pos);
58     type = type->applyPrecisionQualifiers(context, modifiers, ThreadContext::SymbolTable().get(),
59                                           pos.line());
60     ThreadContext::ReportErrors(pos);
61     return type;
62 }
63 
get_type_from_type_constant(const Context & context,TypeConstant tc)64 static const SkSL::Type* get_type_from_type_constant(const Context& context, TypeConstant tc) {
65     switch (tc) {
66         case kBool_Type:
67             return context.fTypes.fBool.get();
68         case kBool2_Type:
69             return context.fTypes.fBool2.get();
70         case kBool3_Type:
71             return context.fTypes.fBool3.get();
72         case kBool4_Type:
73             return context.fTypes.fBool4.get();
74         case kHalf_Type:
75             return context.fTypes.fHalf.get();
76         case kHalf2_Type:
77             return context.fTypes.fHalf2.get();
78         case kHalf3_Type:
79             return context.fTypes.fHalf3.get();
80         case kHalf4_Type:
81             return context.fTypes.fHalf4.get();
82         case kHalf2x2_Type:
83             return context.fTypes.fHalf2x2.get();
84         case kHalf3x2_Type:
85             return context.fTypes.fHalf3x2.get();
86         case kHalf4x2_Type:
87             return context.fTypes.fHalf4x2.get();
88         case kHalf2x3_Type:
89             return context.fTypes.fHalf2x3.get();
90         case kHalf3x3_Type:
91             return context.fTypes.fHalf3x3.get();
92         case kHalf4x3_Type:
93             return context.fTypes.fHalf4x3.get();
94         case kHalf2x4_Type:
95             return context.fTypes.fHalf2x4.get();
96         case kHalf3x4_Type:
97             return context.fTypes.fHalf3x4.get();
98         case kHalf4x4_Type:
99             return context.fTypes.fHalf4x4.get();
100         case kFloat_Type:
101             return context.fTypes.fFloat.get();
102         case kFloat2_Type:
103             return context.fTypes.fFloat2.get();
104         case kFloat3_Type:
105             return context.fTypes.fFloat3.get();
106         case kFloat4_Type:
107             return context.fTypes.fFloat4.get();
108         case kFloat2x2_Type:
109             return context.fTypes.fFloat2x2.get();
110         case kFloat3x2_Type:
111             return context.fTypes.fFloat3x2.get();
112         case kFloat4x2_Type:
113             return context.fTypes.fFloat4x2.get();
114         case kFloat2x3_Type:
115             return context.fTypes.fFloat2x3.get();
116         case kFloat3x3_Type:
117             return context.fTypes.fFloat3x3.get();
118         case kFloat4x3_Type:
119             return context.fTypes.fFloat4x3.get();
120         case kFloat2x4_Type:
121             return context.fTypes.fFloat2x4.get();
122         case kFloat3x4_Type:
123             return context.fTypes.fFloat3x4.get();
124         case kFloat4x4_Type:
125             return context.fTypes.fFloat4x4.get();
126         case kInt_Type:
127             return context.fTypes.fInt.get();
128         case kInt2_Type:
129             return context.fTypes.fInt2.get();
130         case kInt3_Type:
131             return context.fTypes.fInt3.get();
132         case kInt4_Type:
133             return context.fTypes.fInt4.get();
134         case kShader_Type:
135             return context.fTypes.fShader.get();
136         case kShort_Type:
137             return context.fTypes.fShort.get();
138         case kShort2_Type:
139             return context.fTypes.fShort2.get();
140         case kShort3_Type:
141             return context.fTypes.fShort3.get();
142         case kShort4_Type:
143             return context.fTypes.fShort4.get();
144         case kUInt_Type:
145             return context.fTypes.fUInt.get();
146         case kUInt2_Type:
147             return context.fTypes.fUInt2.get();
148         case kUInt3_Type:
149             return context.fTypes.fUInt3.get();
150         case kUInt4_Type:
151             return context.fTypes.fUInt4.get();
152         case kUShort_Type:
153             return context.fTypes.fUShort.get();
154         case kUShort2_Type:
155             return context.fTypes.fUShort2.get();
156         case kUShort3_Type:
157             return context.fTypes.fUShort3.get();
158         case kUShort4_Type:
159             return context.fTypes.fUShort4.get();
160         case kVoid_Type:
161             return context.fTypes.fVoid.get();
162         case kPoison_Type:
163             return context.fTypes.fPoison.get();
164         default:
165             SkUNREACHABLE;
166     }
167 }
168 
DSLType(skstd::string_view name)169 DSLType::DSLType(skstd::string_view name)
170         : fSkSLType(find_type(ThreadContext::Context(), name, PositionInfo())) {}
171 
DSLType(skstd::string_view name,DSLModifiers * modifiers,PositionInfo position)172 DSLType::DSLType(skstd::string_view name, DSLModifiers* modifiers, PositionInfo position)
173         : fSkSLType(find_type(ThreadContext::Context(), name, &modifiers->fModifiers, position)) {}
174 
DSLType(const SkSL::Type * type)175 DSLType::DSLType(const SkSL::Type* type)
176         : fSkSLType(verify_type(ThreadContext::Context(), type, /*allowPrivateTypes=*/true,
177                                 PositionInfo())) {}
178 
isBoolean() const179 bool DSLType::isBoolean() const {
180     return this->skslType().isBoolean();
181 }
182 
isNumber() const183 bool DSLType::isNumber() const {
184     return this->skslType().isNumber();
185 }
186 
isFloat() const187 bool DSLType::isFloat() const {
188     return this->skslType().isFloat();
189 }
190 
isSigned() const191 bool DSLType::isSigned() const {
192     return this->skslType().isSigned();
193 }
194 
isUnsigned() const195 bool DSLType::isUnsigned() const {
196     return this->skslType().isUnsigned();
197 }
198 
isInteger() const199 bool DSLType::isInteger() const {
200     return this->skslType().isInteger();
201 }
202 
isScalar() const203 bool DSLType::isScalar() const {
204     return this->skslType().isScalar();
205 }
206 
isVector() const207 bool DSLType::isVector() const {
208     return this->skslType().isVector();
209 }
210 
isMatrix() const211 bool DSLType::isMatrix() const {
212     return this->skslType().isMatrix();
213 }
214 
isArray() const215 bool DSLType::isArray() const {
216     return this->skslType().isArray();
217 }
218 
219 #ifdef SKSL_EXT
isUnsizedArray() const220 bool DSLType::isUnsizedArray() const {
221     return this->skslType().isUnsizedArray();
222 }
223 #endif
224 
isStruct() const225 bool DSLType::isStruct() const {
226     return this->skslType().isStruct();
227 }
228 
isEffectChild() const229 bool DSLType::isEffectChild() const {
230     return this->skslType().isEffectChild();
231 }
232 
skslType() const233 const SkSL::Type& DSLType::skslType() const {
234     if (fSkSLType) {
235         return *fSkSLType;
236     }
237     const Context& context = ThreadContext::Context();
238     return *verify_type(context,
239                         get_type_from_type_constant(context, fTypeConstant),
240                         /*allowPrivateTypes=*/true,
241                         PositionInfo());
242 }
243 
Construct(DSLType type,SkSpan<DSLExpression> argArray)244 DSLPossibleExpression DSLType::Construct(DSLType type, SkSpan<DSLExpression> argArray) {
245     SkSL::ExpressionArray skslArgs;
246     skslArgs.reserve_back(argArray.size());
247 
248     for (DSLExpression& arg : argArray) {
249         if (!arg.hasValue()) {
250             return DSLPossibleExpression(nullptr);
251         }
252         skslArgs.push_back(arg.release());
253     }
254     return SkSL::Constructor::Convert(ThreadContext::Context(), /*line=*/-1, type.skslType(),
255             std::move(skslArgs));
256 }
257 
/**
 * Returns the array type `base[count]`.
 *
 * `count` is wrapped in a DSLExpression and validated by the base type's convertArraySize(),
 * after which pending errors are reported at `pos`. If validation yields zero, the result is
 * the poison type. Otherwise the array dimension is added to the current symbol table and the
 * resulting type is returned.
 */
DSLType Array(const DSLType& base, int count, PositionInfo pos) {
    const int arraySize = base.skslType().convertArraySize(ThreadContext::Context(),
                                                           DSLExpression(count, pos).release());
    ThreadContext::ReportErrors(pos);
    if (arraySize == 0) {
        return DSLType(kPoison_Type);
    }
    return ThreadContext::SymbolTable()->addArrayDimension(&base.skslType(), arraySize);
}
267 
#ifdef SKSL_EXT
/**
 * Returns the unsized array type `base[]`, registered in the current symbol table.
 *
 * The base type is resolved before pending errors are flushed at `pos`. Any error queued while
 * resolving it (e.g. by verify_type(), which is called with an empty PositionInfo) is therefore
 * attributed to this declaration's position. This mirrors Array(), which reports errors after
 * doing its validation work. Previously ReportErrors() ran before anything could queue an
 * error, so such errors were left for some later, unrelated report site.
 */
DSLType UnsizedArray(const DSLType& base, PositionInfo pos) {
    const SkSL::Type& baseType = base.skslType();
    ThreadContext::ReportErrors(pos);
    return ThreadContext::SymbolTable()->addArrayDimension(&baseType, Type::kUnsizedArray);
}
#endif
274 
Struct(skstd::string_view name,SkSpan<DSLField> fields,PositionInfo pos)275 DSLType Struct(skstd::string_view name, SkSpan<DSLField> fields, PositionInfo pos) {
276     std::vector<SkSL::Type::Field> skslFields;
277     skslFields.reserve(fields.size());
278     for (const DSLField& field : fields) {
279         if (field.fModifiers.fModifiers.fFlags != Modifiers::kNo_Flag) {
280             String desc = field.fModifiers.fModifiers.description();
281             desc.pop_back();  // remove trailing space
282             ThreadContext::ReportError("modifier '" + desc + "' is not permitted on a struct field",
283                     field.fPosition);
284         }
285 
286         const SkSL::Type& type = field.fType.skslType();
287         if (type.isOpaque()) {
288             ThreadContext::ReportError("opaque type '" + type.displayName() +
289                     "' is not permitted in a struct", field.fPosition);
290         }
291         skslFields.emplace_back(field.fModifiers.fModifiers, field.fName, &type);
292     }
293     const SkSL::Type* result = ThreadContext::SymbolTable()->add(Type::MakeStructType(pos.line(),
294             name, skslFields));
295     if (result->isTooDeeplyNested()) {
296         ThreadContext::ReportError("struct '" + String(name) + "' is too deeply nested", pos);
297     }
298     ThreadContext::ProgramElements().push_back(std::make_unique<SkSL::StructDefinition>(/*line=*/-1,
299             *result));
300     return result;
301 }
302 
303 } // namespace dsl
304 
305 } // namespace SkSL
306