1 /*
2 * Copyright 2020 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "include/sksl/DSLType.h"
9
10 #include "src/sksl/SkSLThreadContext.h"
11 #include "src/sksl/ir/SkSLConstructor.h"
12 #include "src/sksl/ir/SkSLStructDefinition.h"
13
14 namespace SkSL {
15
16 namespace dsl {
17
/**
 * Validates that `type` may be referenced by user code.
 *
 * When compiling builtin code (fConfig->fIsBuiltinCode) the type is returned unchanged.
 * Otherwise two checks are applied, in order:
 *   - a private type is rejected unless `allowPrivateTypes` is set;
 *   - a type for which isAllowedInES2() returns false is rejected.
 * A rejected type produces an error at `pos` and the poison type is returned instead.
 *
 * @param context            compiler context supplying config, error reporter and builtins
 * @param type               the type being referenced
 * @param allowPrivateTypes  whether private types are acceptable at this reference
 * @param pos                position used for any reported error
 * @return `type` if acceptable, otherwise the poison type
 */
static const SkSL::Type* verify_type(const Context& context,
                                     const SkSL::Type* type,
                                     bool allowPrivateTypes,
                                     PositionInfo pos) {
    // Builtin code is trusted and skips all visibility/ES2 checks.
    if (context.fConfig->fIsBuiltinCode) {
        return type;
    }

    // Reports `message` at `pos` and substitutes the poison type for the rejected one.
    auto reject = [&](const std::string& message) -> const SkSL::Type* {
        context.fErrors->error(message, pos);
        return context.fTypes.fPoison.get();
    };

    if (!allowPrivateTypes && type->isPrivate()) {
        return reject("type '" + std::string(type->name()) + "' is private");
    }
    if (!type->isAllowedInES2(context)) {
        return reject("type '" + std::string(type->name()) + "' is not supported");
    }
    return type;
}
35
find_type(const Context & context,std::string_view name,PositionInfo pos)36 static const SkSL::Type* find_type(const Context& context,
37 std::string_view name,
38 PositionInfo pos) {
39 const Symbol* symbol = (*ThreadContext::SymbolTable())[name];
40 if (!symbol) {
41 context.fErrors->error(String::printf("no symbol named '%.*s'",
42 (int)name.length(), name.data()), pos);
43 return context.fTypes.fPoison.get();
44 }
45 if (!symbol->is<SkSL::Type>()) {
46 context.fErrors->error(String::printf("symbol '%.*s' is not a type",
47 (int)name.length(), name.data()), pos);
48 return context.fTypes.fPoison.get();
49 }
50 const SkSL::Type* type = &symbol->as<SkSL::Type>();
51 return verify_type(context, type, /*allowPrivateTypes=*/false, pos);
52 }
53
find_type(const Context & context,std::string_view name,Modifiers * modifiers,PositionInfo pos)54 static const SkSL::Type* find_type(const Context& context,
55 std::string_view name,
56 Modifiers* modifiers,
57 PositionInfo pos) {
58 const Type* type = find_type(context, name, pos);
59 type = type->applyPrecisionQualifiers(context, modifiers, ThreadContext::SymbolTable().get(),
60 pos.line());
61 ThreadContext::ReportErrors(pos);
62 return type;
63 }
64
get_type_from_type_constant(const Context & context,TypeConstant tc)65 static const SkSL::Type* get_type_from_type_constant(const Context& context, TypeConstant tc) {
66 switch (tc) {
67 case kBool_Type:
68 return context.fTypes.fBool.get();
69 case kBool2_Type:
70 return context.fTypes.fBool2.get();
71 case kBool3_Type:
72 return context.fTypes.fBool3.get();
73 case kBool4_Type:
74 return context.fTypes.fBool4.get();
75 case kHalf_Type:
76 return context.fTypes.fHalf.get();
77 case kHalf2_Type:
78 return context.fTypes.fHalf2.get();
79 case kHalf3_Type:
80 return context.fTypes.fHalf3.get();
81 case kHalf4_Type:
82 return context.fTypes.fHalf4.get();
83 case kHalf2x2_Type:
84 return context.fTypes.fHalf2x2.get();
85 case kHalf3x2_Type:
86 return context.fTypes.fHalf3x2.get();
87 case kHalf4x2_Type:
88 return context.fTypes.fHalf4x2.get();
89 case kHalf2x3_Type:
90 return context.fTypes.fHalf2x3.get();
91 case kHalf3x3_Type:
92 return context.fTypes.fHalf3x3.get();
93 case kHalf4x3_Type:
94 return context.fTypes.fHalf4x3.get();
95 case kHalf2x4_Type:
96 return context.fTypes.fHalf2x4.get();
97 case kHalf3x4_Type:
98 return context.fTypes.fHalf3x4.get();
99 case kHalf4x4_Type:
100 return context.fTypes.fHalf4x4.get();
101 case kFloat_Type:
102 return context.fTypes.fFloat.get();
103 case kFloat2_Type:
104 return context.fTypes.fFloat2.get();
105 case kFloat3_Type:
106 return context.fTypes.fFloat3.get();
107 case kFloat4_Type:
108 return context.fTypes.fFloat4.get();
109 case kFloat2x2_Type:
110 return context.fTypes.fFloat2x2.get();
111 case kFloat3x2_Type:
112 return context.fTypes.fFloat3x2.get();
113 case kFloat4x2_Type:
114 return context.fTypes.fFloat4x2.get();
115 case kFloat2x3_Type:
116 return context.fTypes.fFloat2x3.get();
117 case kFloat3x3_Type:
118 return context.fTypes.fFloat3x3.get();
119 case kFloat4x3_Type:
120 return context.fTypes.fFloat4x3.get();
121 case kFloat2x4_Type:
122 return context.fTypes.fFloat2x4.get();
123 case kFloat3x4_Type:
124 return context.fTypes.fFloat3x4.get();
125 case kFloat4x4_Type:
126 return context.fTypes.fFloat4x4.get();
127 case kInt_Type:
128 return context.fTypes.fInt.get();
129 case kInt2_Type:
130 return context.fTypes.fInt2.get();
131 case kInt3_Type:
132 return context.fTypes.fInt3.get();
133 case kInt4_Type:
134 return context.fTypes.fInt4.get();
135 case kShader_Type:
136 return context.fTypes.fShader.get();
137 case kShort_Type:
138 return context.fTypes.fShort.get();
139 case kShort2_Type:
140 return context.fTypes.fShort2.get();
141 case kShort3_Type:
142 return context.fTypes.fShort3.get();
143 case kShort4_Type:
144 return context.fTypes.fShort4.get();
145 case kUInt_Type:
146 return context.fTypes.fUInt.get();
147 case kUInt2_Type:
148 return context.fTypes.fUInt2.get();
149 case kUInt3_Type:
150 return context.fTypes.fUInt3.get();
151 case kUInt4_Type:
152 return context.fTypes.fUInt4.get();
153 case kUShort_Type:
154 return context.fTypes.fUShort.get();
155 case kUShort2_Type:
156 return context.fTypes.fUShort2.get();
157 case kUShort3_Type:
158 return context.fTypes.fUShort3.get();
159 case kUShort4_Type:
160 return context.fTypes.fUShort4.get();
161 case kVoid_Type:
162 return context.fTypes.fVoid.get();
163 case kPoison_Type:
164 return context.fTypes.fPoison.get();
165 default:
166 SkUNREACHABLE;
167 }
168 }
169
DSLType(std::string_view name)170 DSLType::DSLType(std::string_view name)
171 : fSkSLType(find_type(ThreadContext::Context(), name, PositionInfo())) {}
172
DSLType(std::string_view name,DSLModifiers * modifiers,PositionInfo position)173 DSLType::DSLType(std::string_view name, DSLModifiers* modifiers, PositionInfo position)
174 : fSkSLType(find_type(ThreadContext::Context(), name, &modifiers->fModifiers, position)) {}
175
DSLType(const SkSL::Type * type)176 DSLType::DSLType(const SkSL::Type* type)
177 : fSkSLType(verify_type(ThreadContext::Context(), type, /*allowPrivateTypes=*/true,
178 PositionInfo())) {}
179
isBoolean() const180 bool DSLType::isBoolean() const {
181 return this->skslType().isBoolean();
182 }
183
isNumber() const184 bool DSLType::isNumber() const {
185 return this->skslType().isNumber();
186 }
187
isFloat() const188 bool DSLType::isFloat() const {
189 return this->skslType().isFloat();
190 }
191
isSigned() const192 bool DSLType::isSigned() const {
193 return this->skslType().isSigned();
194 }
195
isUnsigned() const196 bool DSLType::isUnsigned() const {
197 return this->skslType().isUnsigned();
198 }
199
isInteger() const200 bool DSLType::isInteger() const {
201 return this->skslType().isInteger();
202 }
203
isScalar() const204 bool DSLType::isScalar() const {
205 return this->skslType().isScalar();
206 }
207
isVector() const208 bool DSLType::isVector() const {
209 return this->skslType().isVector();
210 }
211
isMatrix() const212 bool DSLType::isMatrix() const {
213 return this->skslType().isMatrix();
214 }
215
isArray() const216 bool DSLType::isArray() const {
217 return this->skslType().isArray();
218 }
219
isStruct() const220 bool DSLType::isStruct() const {
221 return this->skslType().isStruct();
222 }
223
isEffectChild() const224 bool DSLType::isEffectChild() const {
225 return this->skslType().isEffectChild();
226 }
227
skslType() const228 const SkSL::Type& DSLType::skslType() const {
229 if (fSkSLType) {
230 return *fSkSLType;
231 }
232 const Context& context = ThreadContext::Context();
233 return *verify_type(context,
234 get_type_from_type_constant(context, fTypeConstant),
235 /*allowPrivateTypes=*/true,
236 PositionInfo());
237 }
238
Construct(DSLType type,SkSpan<DSLExpression> argArray)239 DSLPossibleExpression DSLType::Construct(DSLType type, SkSpan<DSLExpression> argArray) {
240 SkSL::ExpressionArray skslArgs;
241 skslArgs.reserve_back(argArray.size());
242
243 for (DSLExpression& arg : argArray) {
244 if (!arg.hasValue()) {
245 return DSLPossibleExpression(nullptr);
246 }
247 skslArgs.push_back(arg.release());
248 }
249 return SkSL::Constructor::Convert(ThreadContext::Context(), /*line=*/-1, type.skslType(),
250 std::move(skslArgs));
251 }
252
/**
 * Returns the array type `base[count]`.
 *
 * The size is validated via Type::convertArraySize(), after which ThreadContext::ReportErrors
 * is called at `pos`. A zero result from convertArraySize yields the poison type. Otherwise
 * the array type is obtained from the current symbol table's addArrayDimension().
 */
DSLType Array(const DSLType& base, int count, PositionInfo pos) {
    const int arraySize = base.skslType().convertArraySize(
            ThreadContext::Context(), DSLExpression(count, pos).release());
    ThreadContext::ReportErrors(pos);

    if (arraySize == 0) {
        return DSLType(kPoison_Type);
    }
    // skslType() is queried again (as before) rather than cached: for a TypeConstant-backed
    // DSLType each call re-runs verify_type.
    return ThreadContext::SymbolTable()->addArrayDimension(&base.skslType(), arraySize);
}
262
Struct(std::string_view name,SkSpan<DSLField> fields,PositionInfo pos)263 DSLType Struct(std::string_view name, SkSpan<DSLField> fields, PositionInfo pos) {
264 std::vector<SkSL::Type::Field> skslFields;
265 skslFields.reserve(fields.size());
266 for (const DSLField& field : fields) {
267 if (field.fModifiers.fModifiers.fFlags != Modifiers::kNo_Flag) {
268 std::string desc = field.fModifiers.fModifiers.description();
269 desc.pop_back(); // remove trailing space
270 ThreadContext::ReportError("modifier '" + desc + "' is not permitted on a struct field",
271 field.fPosition);
272 }
273 if (field.fModifiers.fModifiers.fLayout.fFlags & Layout::kBinding_Flag) {
274 ThreadContext::ReportError(
275 "layout qualifier 'binding' is not permitted on a struct field",
276 field.fPosition);
277 }
278 if (field.fModifiers.fModifiers.fLayout.fFlags & Layout::kSet_Flag) {
279 ThreadContext::ReportError("layout qualifier 'set' is not permitted on a struct field",
280 field.fPosition);
281 }
282
283 const SkSL::Type& type = field.fType.skslType();
284 if (type.isVoid()) {
285 ThreadContext::ReportError("type 'void' is not permitted in a struct", field.fPosition);
286 } else if (type.isOpaque()) {
287 ThreadContext::ReportError("opaque type '" + type.displayName() +
288 "' is not permitted in a struct", field.fPosition);
289 }
290 skslFields.emplace_back(field.fModifiers.fModifiers, field.fName, &type);
291 }
292 const SkSL::Type* result = ThreadContext::SymbolTable()->add(Type::MakeStructType(pos.line(),
293 name, skslFields));
294 if (result->isTooDeeplyNested()) {
295 ThreadContext::ReportError("struct '" + std::string(name) + "' is too deeply nested", pos);
296 }
297 ThreadContext::ProgramElements().push_back(std::make_unique<SkSL::StructDefinition>(/*line=*/-1,
298 *result));
299 return result;
300 }
301
302 } // namespace dsl
303
304 } // namespace SkSL
305