• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "include/sksl/DSLType.h"
9 
10 #include "src/sksl/SkSLThreadContext.h"
11 #include "src/sksl/ir/SkSLConstructor.h"
12 #include "src/sksl/ir/SkSLStructDefinition.h"
13 
14 namespace SkSL {
15 
16 namespace dsl {
17 
// Decides whether `type` may be referenced from the program being compiled.
// Built-in (module) code is trusted and always gets `type` back unchanged. For user code, a
// private type is rejected unless `allowPrivateTypes` is true. A type that fails
// `isAllowedInES2(context)` is rejected as unsupported. A rejected type is reported as an error
// at `pos`, and the poison type is returned in its place.
static const SkSL::Type* verify_type(const Context& context,
                                     const SkSL::Type* type,
                                     bool allowPrivateTypes,
                                     PositionInfo pos) {
    if (context.fConfig->fIsBuiltinCode) {
        return type;
    }
    if (!allowPrivateTypes && type->isPrivate()) {
        context.fErrors->error("type '" + String(type->name()) + "' is private", pos);
        return context.fTypes.fPoison.get();
    }
    if (!type->isAllowedInES2(context)) {
        context.fErrors->error("type '" + String(type->name()) + "' is not supported", pos);
        return context.fTypes.fPoison.get();
    }
    return type;
}
34 
find_type(const Context & context,skstd::string_view name,PositionInfo pos)35 static const SkSL::Type* find_type(const Context& context,
36                                    skstd::string_view name,
37                                    PositionInfo pos) {
38     const Symbol* symbol = (*ThreadContext::SymbolTable())[name];
39     if (!symbol) {
40         context.fErrors->error(String::printf("no symbol named '%.*s'",
41                                               (int)name.length(), name.data()), pos);
42         return context.fTypes.fPoison.get();
43     }
44     if (!symbol->is<SkSL::Type>()) {
45         context.fErrors->error(String::printf("symbol '%.*s' is not a type",
46                                               (int)name.length(), name.data()), pos);
47         return context.fTypes.fPoison.get();
48     }
49     const SkSL::Type* type = &symbol->as<SkSL::Type>();
50     return verify_type(context, type, /*allowPrivateTypes=*/false, pos);
51 }
52 
find_type(const Context & context,skstd::string_view name,Modifiers * modifiers,PositionInfo pos)53 static const SkSL::Type* find_type(const Context& context,
54                                    skstd::string_view name,
55                                    Modifiers* modifiers,
56                                    PositionInfo pos) {
57     const Type* type = find_type(context, name, pos);
58     type = type->applyPrecisionQualifiers(context, modifiers, ThreadContext::SymbolTable().get(),
59                                           pos.line());
60     ThreadContext::ReportErrors(pos);
61     return type;
62 }
63 
get_type_from_type_constant(const Context & context,TypeConstant tc)64 static const SkSL::Type* get_type_from_type_constant(const Context& context, TypeConstant tc) {
65     switch (tc) {
66         case kBool_Type:
67             return context.fTypes.fBool.get();
68         case kBool2_Type:
69             return context.fTypes.fBool2.get();
70         case kBool3_Type:
71             return context.fTypes.fBool3.get();
72         case kBool4_Type:
73             return context.fTypes.fBool4.get();
74         case kHalf_Type:
75             return context.fTypes.fHalf.get();
76         case kHalf2_Type:
77             return context.fTypes.fHalf2.get();
78         case kHalf3_Type:
79             return context.fTypes.fHalf3.get();
80         case kHalf4_Type:
81             return context.fTypes.fHalf4.get();
82         case kHalf2x2_Type:
83             return context.fTypes.fHalf2x2.get();
84         case kHalf3x2_Type:
85             return context.fTypes.fHalf3x2.get();
86         case kHalf4x2_Type:
87             return context.fTypes.fHalf4x2.get();
88         case kHalf2x3_Type:
89             return context.fTypes.fHalf2x3.get();
90         case kHalf3x3_Type:
91             return context.fTypes.fHalf3x3.get();
92         case kHalf4x3_Type:
93             return context.fTypes.fHalf4x3.get();
94         case kHalf2x4_Type:
95             return context.fTypes.fHalf2x4.get();
96         case kHalf3x4_Type:
97             return context.fTypes.fHalf3x4.get();
98         case kHalf4x4_Type:
99             return context.fTypes.fHalf4x4.get();
100         case kFloat_Type:
101             return context.fTypes.fFloat.get();
102         case kFloat2_Type:
103             return context.fTypes.fFloat2.get();
104         case kFloat3_Type:
105             return context.fTypes.fFloat3.get();
106         case kFloat4_Type:
107             return context.fTypes.fFloat4.get();
108         case kFloat2x2_Type:
109             return context.fTypes.fFloat2x2.get();
110         case kFloat3x2_Type:
111             return context.fTypes.fFloat3x2.get();
112         case kFloat4x2_Type:
113             return context.fTypes.fFloat4x2.get();
114         case kFloat2x3_Type:
115             return context.fTypes.fFloat2x3.get();
116         case kFloat3x3_Type:
117             return context.fTypes.fFloat3x3.get();
118         case kFloat4x3_Type:
119             return context.fTypes.fFloat4x3.get();
120         case kFloat2x4_Type:
121             return context.fTypes.fFloat2x4.get();
122         case kFloat3x4_Type:
123             return context.fTypes.fFloat3x4.get();
124         case kFloat4x4_Type:
125             return context.fTypes.fFloat4x4.get();
126         case kInt_Type:
127             return context.fTypes.fInt.get();
128         case kInt2_Type:
129             return context.fTypes.fInt2.get();
130         case kInt3_Type:
131             return context.fTypes.fInt3.get();
132         case kInt4_Type:
133             return context.fTypes.fInt4.get();
134         case kShader_Type:
135             return context.fTypes.fShader.get();
136         case kShort_Type:
137             return context.fTypes.fShort.get();
138         case kShort2_Type:
139             return context.fTypes.fShort2.get();
140         case kShort3_Type:
141             return context.fTypes.fShort3.get();
142         case kShort4_Type:
143             return context.fTypes.fShort4.get();
144         case kUInt_Type:
145             return context.fTypes.fUInt.get();
146         case kUInt2_Type:
147             return context.fTypes.fUInt2.get();
148         case kUInt3_Type:
149             return context.fTypes.fUInt3.get();
150         case kUInt4_Type:
151             return context.fTypes.fUInt4.get();
152         case kUShort_Type:
153             return context.fTypes.fUShort.get();
154         case kUShort2_Type:
155             return context.fTypes.fUShort2.get();
156         case kUShort3_Type:
157             return context.fTypes.fUShort3.get();
158         case kUShort4_Type:
159             return context.fTypes.fUShort4.get();
160         case kVoid_Type:
161             return context.fTypes.fVoid.get();
162         case kPoison_Type:
163             return context.fTypes.fPoison.get();
164         default:
165             SkUNREACHABLE;
166     }
167 }
168 
DSLType(skstd::string_view name)169 DSLType::DSLType(skstd::string_view name)
170         : fSkSLType(find_type(ThreadContext::Context(), name, PositionInfo())) {}
171 
DSLType(skstd::string_view name,DSLModifiers * modifiers,PositionInfo position)172 DSLType::DSLType(skstd::string_view name, DSLModifiers* modifiers, PositionInfo position)
173         : fSkSLType(find_type(ThreadContext::Context(), name, &modifiers->fModifiers, position)) {}
174 
DSLType(const SkSL::Type * type)175 DSLType::DSLType(const SkSL::Type* type)
176         : fSkSLType(verify_type(ThreadContext::Context(), type, /*allowPrivateTypes=*/true,
177                                 PositionInfo())) {}
178 
isBoolean() const179 bool DSLType::isBoolean() const {
180     return this->skslType().isBoolean();
181 }
182 
isNumber() const183 bool DSLType::isNumber() const {
184     return this->skslType().isNumber();
185 }
186 
isFloat() const187 bool DSLType::isFloat() const {
188     return this->skslType().isFloat();
189 }
190 
isSigned() const191 bool DSLType::isSigned() const {
192     return this->skslType().isSigned();
193 }
194 
isUnsigned() const195 bool DSLType::isUnsigned() const {
196     return this->skslType().isUnsigned();
197 }
198 
isInteger() const199 bool DSLType::isInteger() const {
200     return this->skslType().isInteger();
201 }
202 
isScalar() const203 bool DSLType::isScalar() const {
204     return this->skslType().isScalar();
205 }
206 
isVector() const207 bool DSLType::isVector() const {
208     return this->skslType().isVector();
209 }
210 
isMatrix() const211 bool DSLType::isMatrix() const {
212     return this->skslType().isMatrix();
213 }
214 
isArray() const215 bool DSLType::isArray() const {
216     return this->skslType().isArray();
217 }
218 
isStruct() const219 bool DSLType::isStruct() const {
220     return this->skslType().isStruct();
221 }
222 
isEffectChild() const223 bool DSLType::isEffectChild() const {
224     return this->skslType().isEffectChild();
225 }
226 
skslType() const227 const SkSL::Type& DSLType::skslType() const {
228     if (fSkSLType) {
229         return *fSkSLType;
230     }
231     const Context& context = ThreadContext::Context();
232     return *verify_type(context,
233                         get_type_from_type_constant(context, fTypeConstant),
234                         /*allowPrivateTypes=*/true,
235                         PositionInfo());
236 }
237 
Construct(DSLType type,SkSpan<DSLExpression> argArray)238 DSLPossibleExpression DSLType::Construct(DSLType type, SkSpan<DSLExpression> argArray) {
239     SkSL::ExpressionArray skslArgs;
240     skslArgs.reserve_back(argArray.size());
241 
242     for (DSLExpression& arg : argArray) {
243         if (!arg.hasValue()) {
244             return DSLPossibleExpression(nullptr);
245         }
246         skslArgs.push_back(arg.release());
247     }
248     return SkSL::Constructor::Convert(ThreadContext::Context(), /*line=*/-1, type.skslType(),
249             std::move(skslArgs));
250 }
251 
// Returns an array type with `count` elements of `base`.
// The size is validated by Type::convertArraySize(), and any pending errors are flushed at
// `pos`. When validation yields a size of 0, the poison type is returned.
// `base.skslType()` is resolved only once. For a TypeConstant-backed DSLType, every call
// re-runs verify_type(), so calling it twice (as the previous version did) could report the same
// "private"/"not supported" error twice. Resolving once also guarantees that the size
// conversion and the array dimension use the same Type.
DSLType Array(const DSLType& base, int count, PositionInfo pos) {
    const SkSL::Type& baseType = base.skslType();
    count = baseType.convertArraySize(ThreadContext::Context(),
                                      DSLExpression(count, pos).release());
    ThreadContext::ReportErrors(pos);
    if (!count) {
        return DSLType(kPoison_Type);
    }
    return ThreadContext::SymbolTable()->addArrayDimension(&baseType, count);
}
261 
Struct(skstd::string_view name,SkSpan<DSLField> fields,PositionInfo pos)262 DSLType Struct(skstd::string_view name, SkSpan<DSLField> fields, PositionInfo pos) {
263     std::vector<SkSL::Type::Field> skslFields;
264     skslFields.reserve(fields.size());
265     for (const DSLField& field : fields) {
266         if (field.fModifiers.fModifiers.fFlags != Modifiers::kNo_Flag) {
267             String desc = field.fModifiers.fModifiers.description();
268             desc.pop_back();  // remove trailing space
269             ThreadContext::ReportError("modifier '" + desc + "' is not permitted on a struct field",
270                     field.fPosition);
271         }
272 
273         const SkSL::Type& type = field.fType.skslType();
274         if (type.isOpaque()) {
275             ThreadContext::ReportError("opaque type '" + type.displayName() +
276                     "' is not permitted in a struct", field.fPosition);
277         }
278         skslFields.emplace_back(field.fModifiers.fModifiers, field.fName, &type);
279     }
280     const SkSL::Type* result = ThreadContext::SymbolTable()->add(Type::MakeStructType(pos.line(),
281             name, skslFields));
282     if (result->isTooDeeplyNested()) {
283         ThreadContext::ReportError("struct '" + String(name) + "' is too deeply nested", pos);
284     }
285     ThreadContext::ProgramElements().push_back(std::make_unique<SkSL::StructDefinition>(/*line=*/-1,
286             *result));
287     return result;
288 }
289 
290 } // namespace dsl
291 
292 } // namespace SkSL
293