• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "include/sksl/DSLType.h"
9 
10 #include "src/sksl/SkSLThreadContext.h"
11 #include "src/sksl/ir/SkSLConstructor.h"
12 #include "src/sksl/ir/SkSLStructDefinition.h"
13 
14 namespace SkSL {
15 
16 namespace dsl {
17 
// Checks whether `type` may be referenced from the code currently being compiled.
//
// Builtin code (context.fConfig->fIsBuiltinCode) bypasses every check and gets `type` back as-is.
// Otherwise:
//  - a private type is rejected unless `allowPrivateTypes` is true;
//  - a type for which isAllowedInES2(context) returns false is rejected.
// A rejection reports an error at `pos` and returns the poison type instead of `type`.
static const SkSL::Type* verify_type(const Context& context,
                                     const SkSL::Type* type,
                                     bool allowPrivateTypes,
                                     PositionInfo pos) {
    if (context.fConfig->fIsBuiltinCode) {
        return type;
    }

    // isPrivate() is only consulted when private types are disallowed.
    const bool rejectAsPrivate = !allowPrivateTypes && type->isPrivate();
    if (rejectAsPrivate) {
        context.fErrors->error("type '" + std::string(type->name()) + "' is private", pos);
        return context.fTypes.fPoison.get();
    }

    if (type->isAllowedInES2(context)) {
        return type;
    }
    context.fErrors->error("type '" + std::string(type->name()) + "' is not supported", pos);
    return context.fTypes.fPoison.get();
}
35 
find_type(const Context & context,std::string_view name,PositionInfo pos)36 static const SkSL::Type* find_type(const Context& context,
37                                    std::string_view name,
38                                    PositionInfo pos) {
39     const Symbol* symbol = (*ThreadContext::SymbolTable())[name];
40     if (!symbol) {
41         context.fErrors->error(String::printf("no symbol named '%.*s'",
42                                               (int)name.length(), name.data()), pos);
43         return context.fTypes.fPoison.get();
44     }
45     if (!symbol->is<SkSL::Type>()) {
46         context.fErrors->error(String::printf("symbol '%.*s' is not a type",
47                                               (int)name.length(), name.data()), pos);
48         return context.fTypes.fPoison.get();
49     }
50     const SkSL::Type* type = &symbol->as<SkSL::Type>();
51     return verify_type(context, type, /*allowPrivateTypes=*/false, pos);
52 }
53 
find_type(const Context & context,std::string_view name,Modifiers * modifiers,PositionInfo pos)54 static const SkSL::Type* find_type(const Context& context,
55                                    std::string_view name,
56                                    Modifiers* modifiers,
57                                    PositionInfo pos) {
58     const Type* type = find_type(context, name, pos);
59     type = type->applyPrecisionQualifiers(context, modifiers, ThreadContext::SymbolTable().get(),
60                                           pos.line());
61     ThreadContext::ReportErrors(pos);
62     return type;
63 }
64 
get_type_from_type_constant(const Context & context,TypeConstant tc)65 static const SkSL::Type* get_type_from_type_constant(const Context& context, TypeConstant tc) {
66     switch (tc) {
67         case kBool_Type:
68             return context.fTypes.fBool.get();
69         case kBool2_Type:
70             return context.fTypes.fBool2.get();
71         case kBool3_Type:
72             return context.fTypes.fBool3.get();
73         case kBool4_Type:
74             return context.fTypes.fBool4.get();
75         case kHalf_Type:
76             return context.fTypes.fHalf.get();
77         case kHalf2_Type:
78             return context.fTypes.fHalf2.get();
79         case kHalf3_Type:
80             return context.fTypes.fHalf3.get();
81         case kHalf4_Type:
82             return context.fTypes.fHalf4.get();
83         case kHalf2x2_Type:
84             return context.fTypes.fHalf2x2.get();
85         case kHalf3x2_Type:
86             return context.fTypes.fHalf3x2.get();
87         case kHalf4x2_Type:
88             return context.fTypes.fHalf4x2.get();
89         case kHalf2x3_Type:
90             return context.fTypes.fHalf2x3.get();
91         case kHalf3x3_Type:
92             return context.fTypes.fHalf3x3.get();
93         case kHalf4x3_Type:
94             return context.fTypes.fHalf4x3.get();
95         case kHalf2x4_Type:
96             return context.fTypes.fHalf2x4.get();
97         case kHalf3x4_Type:
98             return context.fTypes.fHalf3x4.get();
99         case kHalf4x4_Type:
100             return context.fTypes.fHalf4x4.get();
101         case kFloat_Type:
102             return context.fTypes.fFloat.get();
103         case kFloat2_Type:
104             return context.fTypes.fFloat2.get();
105         case kFloat3_Type:
106             return context.fTypes.fFloat3.get();
107         case kFloat4_Type:
108             return context.fTypes.fFloat4.get();
109         case kFloat2x2_Type:
110             return context.fTypes.fFloat2x2.get();
111         case kFloat3x2_Type:
112             return context.fTypes.fFloat3x2.get();
113         case kFloat4x2_Type:
114             return context.fTypes.fFloat4x2.get();
115         case kFloat2x3_Type:
116             return context.fTypes.fFloat2x3.get();
117         case kFloat3x3_Type:
118             return context.fTypes.fFloat3x3.get();
119         case kFloat4x3_Type:
120             return context.fTypes.fFloat4x3.get();
121         case kFloat2x4_Type:
122             return context.fTypes.fFloat2x4.get();
123         case kFloat3x4_Type:
124             return context.fTypes.fFloat3x4.get();
125         case kFloat4x4_Type:
126             return context.fTypes.fFloat4x4.get();
127         case kInt_Type:
128             return context.fTypes.fInt.get();
129         case kInt2_Type:
130             return context.fTypes.fInt2.get();
131         case kInt3_Type:
132             return context.fTypes.fInt3.get();
133         case kInt4_Type:
134             return context.fTypes.fInt4.get();
135         case kShader_Type:
136             return context.fTypes.fShader.get();
137         case kShort_Type:
138             return context.fTypes.fShort.get();
139         case kShort2_Type:
140             return context.fTypes.fShort2.get();
141         case kShort3_Type:
142             return context.fTypes.fShort3.get();
143         case kShort4_Type:
144             return context.fTypes.fShort4.get();
145         case kUInt_Type:
146             return context.fTypes.fUInt.get();
147         case kUInt2_Type:
148             return context.fTypes.fUInt2.get();
149         case kUInt3_Type:
150             return context.fTypes.fUInt3.get();
151         case kUInt4_Type:
152             return context.fTypes.fUInt4.get();
153         case kUShort_Type:
154             return context.fTypes.fUShort.get();
155         case kUShort2_Type:
156             return context.fTypes.fUShort2.get();
157         case kUShort3_Type:
158             return context.fTypes.fUShort3.get();
159         case kUShort4_Type:
160             return context.fTypes.fUShort4.get();
161         case kVoid_Type:
162             return context.fTypes.fVoid.get();
163         case kPoison_Type:
164             return context.fTypes.fPoison.get();
165         default:
166             SkUNREACHABLE;
167     }
168 }
169 
DSLType(std::string_view name)170 DSLType::DSLType(std::string_view name)
171         : fSkSLType(find_type(ThreadContext::Context(), name, PositionInfo())) {}
172 
DSLType(std::string_view name,DSLModifiers * modifiers,PositionInfo position)173 DSLType::DSLType(std::string_view name, DSLModifiers* modifiers, PositionInfo position)
174         : fSkSLType(find_type(ThreadContext::Context(), name, &modifiers->fModifiers, position)) {}
175 
DSLType(const SkSL::Type * type)176 DSLType::DSLType(const SkSL::Type* type)
177         : fSkSLType(verify_type(ThreadContext::Context(), type, /*allowPrivateTypes=*/true,
178                                 PositionInfo())) {}
179 
isBoolean() const180 bool DSLType::isBoolean() const {
181     return this->skslType().isBoolean();
182 }
183 
isNumber() const184 bool DSLType::isNumber() const {
185     return this->skslType().isNumber();
186 }
187 
isFloat() const188 bool DSLType::isFloat() const {
189     return this->skslType().isFloat();
190 }
191 
isSigned() const192 bool DSLType::isSigned() const {
193     return this->skslType().isSigned();
194 }
195 
isUnsigned() const196 bool DSLType::isUnsigned() const {
197     return this->skslType().isUnsigned();
198 }
199 
isInteger() const200 bool DSLType::isInteger() const {
201     return this->skslType().isInteger();
202 }
203 
isScalar() const204 bool DSLType::isScalar() const {
205     return this->skslType().isScalar();
206 }
207 
isVector() const208 bool DSLType::isVector() const {
209     return this->skslType().isVector();
210 }
211 
isMatrix() const212 bool DSLType::isMatrix() const {
213     return this->skslType().isMatrix();
214 }
215 
isArray() const216 bool DSLType::isArray() const {
217     return this->skslType().isArray();
218 }
219 
isStruct() const220 bool DSLType::isStruct() const {
221     return this->skslType().isStruct();
222 }
223 
isEffectChild() const224 bool DSLType::isEffectChild() const {
225     return this->skslType().isEffectChild();
226 }
227 
skslType() const228 const SkSL::Type& DSLType::skslType() const {
229     if (fSkSLType) {
230         return *fSkSLType;
231     }
232     const Context& context = ThreadContext::Context();
233     return *verify_type(context,
234                         get_type_from_type_constant(context, fTypeConstant),
235                         /*allowPrivateTypes=*/true,
236                         PositionInfo());
237 }
238 
Construct(DSLType type,SkSpan<DSLExpression> argArray)239 DSLPossibleExpression DSLType::Construct(DSLType type, SkSpan<DSLExpression> argArray) {
240     SkSL::ExpressionArray skslArgs;
241     skslArgs.reserve_back(argArray.size());
242 
243     for (DSLExpression& arg : argArray) {
244         if (!arg.hasValue()) {
245             return DSLPossibleExpression(nullptr);
246         }
247         skslArgs.push_back(arg.release());
248     }
249     return SkSL::Constructor::Convert(ThreadContext::Context(), /*line=*/-1, type.skslType(),
250             std::move(skslArgs));
251 }
252 
// Returns the array type `base[count]`, registered in the current symbol table.
// `count` is validated through Type::convertArraySize. After that, ThreadContext::ReportErrors(pos)
// is called. If the validated size is 0, the poison type is returned instead.
// base.skslType() is deliberately called twice, as before, because each call may report errors.
DSLType Array(const DSLType& base, int count, PositionInfo pos) {
    const int arraySize = base.skslType().convertArraySize(ThreadContext::Context(),
                                                           DSLExpression(count, pos).release());
    ThreadContext::ReportErrors(pos);
    if (arraySize == 0) {
        return DSLType(kPoison_Type);
    }
    return ThreadContext::SymbolTable()->addArrayDimension(&base.skslType(), arraySize);
}
262 
Struct(std::string_view name,SkSpan<DSLField> fields,PositionInfo pos)263 DSLType Struct(std::string_view name, SkSpan<DSLField> fields, PositionInfo pos) {
264     std::vector<SkSL::Type::Field> skslFields;
265     skslFields.reserve(fields.size());
266     for (const DSLField& field : fields) {
267         if (field.fModifiers.fModifiers.fFlags != Modifiers::kNo_Flag) {
268             std::string desc = field.fModifiers.fModifiers.description();
269             desc.pop_back();  // remove trailing space
270             ThreadContext::ReportError("modifier '" + desc + "' is not permitted on a struct field",
271                     field.fPosition);
272         }
273         if (field.fModifiers.fModifiers.fLayout.fFlags & Layout::kBinding_Flag) {
274             ThreadContext::ReportError(
275                     "layout qualifier 'binding' is not permitted on a struct field",
276                     field.fPosition);
277         }
278         if (field.fModifiers.fModifiers.fLayout.fFlags & Layout::kSet_Flag) {
279             ThreadContext::ReportError("layout qualifier 'set' is not permitted on a struct field",
280                                        field.fPosition);
281         }
282 
283         const SkSL::Type& type = field.fType.skslType();
284         if (type.isVoid()) {
285             ThreadContext::ReportError("type 'void' is not permitted in a struct", field.fPosition);
286         } else if (type.isOpaque()) {
287             ThreadContext::ReportError("opaque type '" + type.displayName() +
288                                        "' is not permitted in a struct", field.fPosition);
289         }
290         skslFields.emplace_back(field.fModifiers.fModifiers, field.fName, &type);
291     }
292     const SkSL::Type* result = ThreadContext::SymbolTable()->add(Type::MakeStructType(pos.line(),
293             name, skslFields));
294     if (result->isTooDeeplyNested()) {
295         ThreadContext::ReportError("struct '" + std::string(name) + "' is too deeply nested", pos);
296     }
297     ThreadContext::ProgramElements().push_back(std::make_unique<SkSL::StructDefinition>(/*line=*/-1,
298             *result));
299     return result;
300 }
301 
302 } // namespace dsl
303 
304 } // namespace SkSL
305