• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/codegen/SkSLDSLCPPCodeGenerator.h"
9 
10 #include "src/sksl/SkSLAnalysis.h"
11 #include "src/sksl/SkSLCPPUniformCTypes.h"
12 #include "src/sksl/SkSLCompiler.h"
13 #include "src/sksl/codegen/SkSLHCodeGenerator.h"
14 #include "src/sksl/ir/SkSLBlock.h"
15 #include "src/sksl/ir/SkSLEnum.h"
16 #include "src/sksl/ir/SkSLExpressionStatement.h"
17 
18 #include <algorithm>
19 
20 #if defined(SKSL_STANDALONE) || GR_TEST_UTILS
21 
22 namespace SkSL {
23 
needs_uniform_var(const Variable & var)24 static bool needs_uniform_var(const Variable& var) {
25     return var.modifiers().fFlags & Modifiers::kUniform_Flag;
26 }
27 
get_scalar_type_name(const Context & context,const Type & type)28 static const char* get_scalar_type_name(const Context& context, const Type& type) {
29     if (type == *context.fTypes.fHalf) {
30         return "Half";
31     } else if (type == *context.fTypes.fFloat) {
32         return "Float";
33     } else if (type.isSigned()) {
34         return "Int";
35     } else if (type.isUnsigned()) {
36         return "UInt";
37     } else if (type.isBoolean()) {
38         return "Bool";
39     }
40     // TODO: support for unsigned types
41     SkDEBUGFAIL("unsupported scalar type");
42     return "Float";
43 }
44 
DSLCPPCodeGenerator(const Context * context,const Program * program,ErrorReporter * errors,String name,OutputStream * out)45 DSLCPPCodeGenerator::DSLCPPCodeGenerator(const Context* context, const Program* program,
46                                          ErrorReporter* errors, String name, OutputStream* out)
47     : INHERITED(context, program, errors, out)
48     , fName(std::move(name))
49     , fFullName(String::printf("Gr%s", fName.c_str()))
50     , fSectionAndParameterHelper(program, *errors) {
51     fLineEnding = "\n";
52     fTextureFunctionOverride = "sample";
53 }
54 
writef(const char * s,va_list va)55 void DSLCPPCodeGenerator::writef(const char* s, va_list va) {
56     static constexpr int BUFFER_SIZE = 1024;
57     va_list copy;
58     va_copy(copy, va);
59     char buffer[BUFFER_SIZE];
60     int length = std::vsnprintf(buffer, BUFFER_SIZE, s, va);
61     if (length < BUFFER_SIZE) {
62         fOut->write(buffer, length);
63     } else {
64         std::unique_ptr<char[]> heap(new char[length + 1]);
65         vsprintf(heap.get(), s, copy);
66         fOut->write(heap.get(), length);
67     }
68     va_end(copy);
69 }
70 
writef(const char * s,...)71 void DSLCPPCodeGenerator::writef(const char* s, ...) {
72     va_list va;
73     va_start(va, s);
74     this->writef(s, va);
75     va_end(va);
76 }
77 
writeHeader()78 void DSLCPPCodeGenerator::writeHeader() {
79 }
80 
usesPrecisionModifiers() const81 bool DSLCPPCodeGenerator::usesPrecisionModifiers() const {
82     return false;
83 }
84 
default_value(const Type & type)85 static String default_value(const Type& type) {
86     if (type.isBoolean()) {
87         return "false";
88     }
89     switch (type.typeKind()) {
90         case Type::TypeKind::kScalar: return "0";
91         case Type::TypeKind::kVector: return type.name() + "(0)";
92         case Type::TypeKind::kMatrix: return type.name() + "(1)";
93         default: SK_ABORT("unsupported default_value type");
94     }
95 }
96 
default_value(const Variable & var)97 static String default_value(const Variable& var) {
98     if (var.modifiers().fLayout.fCType == SkSL::Layout::CType::kSkPMColor4f) {
99         return "{SK_FloatNaN, SK_FloatNaN, SK_FloatNaN, SK_FloatNaN}";
100     }
101     return default_value(var.type());
102 }
103 
is_private(const Variable & var)104 static bool is_private(const Variable& var) {
105     const Modifiers& modifiers = var.modifiers();
106     return !(modifiers.fFlags & Modifiers::kUniform_Flag) &&
107            !(modifiers.fFlags & Modifiers::kIn_Flag) &&
108            var.storage() == Variable::Storage::kGlobal &&
109            modifiers.fLayout.fBuiltin == -1;
110 }
111 
is_uniform_in(const Variable & var)112 static bool is_uniform_in(const Variable& var) {
113     const Modifiers& modifiers = var.modifiers();
114     return (modifiers.fFlags & Modifiers::kUniform_Flag) &&
115            (modifiers.fFlags & Modifiers::kIn_Flag);
116 }
117 
formatRuntimeValue(const Type & type,const Layout & layout,const String & cppCode,std::vector<String> * formatArgs)118 String DSLCPPCodeGenerator::formatRuntimeValue(const Type& type,
119                                                const Layout& layout,
120                                                const String& cppCode,
121                                                std::vector<String>* formatArgs) {
122     if (type.isArray()) {
123         String result("[");
124         const char* separator = "";
125         for (int i = 0; i < type.columns(); i++) {
126             result += separator + this->formatRuntimeValue(type.componentType(), layout,
127                                                            "(" + cppCode + ")[" + to_string(i) +
128                                                            "]", formatArgs);
129             separator = ",";
130         }
131         result += "]";
132         return result;
133     }
134     if (type.isFloat()) {
135         formatArgs->push_back(cppCode);
136         return "%f";
137     }
138     if (type == *fContext.fTypes.fInt) {
139         formatArgs->push_back(cppCode);
140         return "%d";
141     }
142     if (type == *fContext.fTypes.fBool) {
143         formatArgs->push_back("!!(" + cppCode + ")");
144         return "%d";
145     }
146     if (type == *fContext.fTypes.fFloat2 || type == *fContext.fTypes.fHalf2) {
147         formatArgs->push_back(cppCode + ".fX");
148         formatArgs->push_back(cppCode + ".fY");
149         return type.name() + "(%f, %f)";
150     }
151     if (type == *fContext.fTypes.fFloat3 || type == *fContext.fTypes.fHalf3) {
152         formatArgs->push_back(cppCode + ".fX");
153         formatArgs->push_back(cppCode + ".fY");
154         formatArgs->push_back(cppCode + ".fZ");
155         return type.name() + "(%f, %f, %f)";
156     }
157     if (type == *fContext.fTypes.fFloat4 || type == *fContext.fTypes.fHalf4) {
158         switch (layout.fCType) {
159             case Layout::CType::kSkPMColor:
160                 formatArgs->push_back("SkGetPackedR32(" + cppCode + ") / 255.0");
161                 formatArgs->push_back("SkGetPackedG32(" + cppCode + ") / 255.0");
162                 formatArgs->push_back("SkGetPackedB32(" + cppCode + ") / 255.0");
163                 formatArgs->push_back("SkGetPackedA32(" + cppCode + ") / 255.0");
164                 break;
165             case Layout::CType::kSkPMColor4f:
166                 formatArgs->push_back(cppCode + ".fR");
167                 formatArgs->push_back(cppCode + ".fG");
168                 formatArgs->push_back(cppCode + ".fB");
169                 formatArgs->push_back(cppCode + ".fA");
170                 break;
171             case Layout::CType::kSkV4:
172                 formatArgs->push_back(cppCode + ".x");
173                 formatArgs->push_back(cppCode + ".y");
174                 formatArgs->push_back(cppCode + ".z");
175                 formatArgs->push_back(cppCode + ".w");
176                 break;
177             case Layout::CType::kSkRect:
178             case Layout::CType::kDefault:
179                 formatArgs->push_back(cppCode + ".left()");
180                 formatArgs->push_back(cppCode + ".top()");
181                 formatArgs->push_back(cppCode + ".right()");
182                 formatArgs->push_back(cppCode + ".bottom()");
183                 break;
184             default:
185                 SkASSERT(false);
186         }
187         return type.name() + "(%f, %f, %f, %f)";
188     }
189     if (type.isMatrix()) {
190         SkASSERT(type.componentType() == *fContext.fTypes.fFloat ||
191                  type.componentType() == *fContext.fTypes.fHalf);
192 
193         String format = type.name() + "(";
194         for (int c = 0; c < type.columns(); ++c) {
195             for (int r = 0; r < type.rows(); ++r) {
196                 formatArgs->push_back(String::printf("%s.rc(%d, %d)", cppCode.c_str(), r, c));
197                 format += "%f, ";
198             }
199         }
200 
201         // Replace trailing ", " with ")".
202         format.pop_back();
203         format.back() = ')';
204         return format;
205     }
206     if (type.isEnum()) {
207         formatArgs->push_back("(int) " + cppCode);
208         return "%d";
209     }
210     if (type == *fContext.fTypes.fInt4 ||
211         type == *fContext.fTypes.fShort4) {
212         formatArgs->push_back(cppCode + ".left()");
213         formatArgs->push_back(cppCode + ".top()");
214         formatArgs->push_back(cppCode + ".right()");
215         formatArgs->push_back(cppCode + ".bottom()");
216         return type.name() + "(%d, %d, %d, %d)";
217     }
218 
219     SkDEBUGFAILF("unsupported runtime value type '%s'\n", String(type.name()).c_str());
220     return "";
221 }
222 
writeSwizzle(const Swizzle & swizzle)223 void DSLCPPCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
224     // Confirm that the component array only contains X/Y/Z/W.
225     SkASSERT(std::all_of(swizzle.components().begin(), swizzle.components().end(),
226              [](int8_t component) {
227                  return component >= SwizzleComponent::X && component <= SwizzleComponent::W;
228              }));
229 
230     if (fCPPMode) {
231         // no support for multiple swizzle components yet
232         SkASSERT(swizzle.components().size() == 1);
233         this->writeExpression(*swizzle.base(), Precedence::kPostfix);
234         switch (swizzle.components().front()) {
235             case SwizzleComponent::X: this->write(".left()");   break;
236             case SwizzleComponent::Y: this->write(".top()");    break;
237             case SwizzleComponent::Z: this->write(".right()");  break;
238             case SwizzleComponent::W: this->write(".bottom()"); break;
239         }
240     } else {
241         if (swizzle.components().size() == 1) {
242             // For single-element swizzles, we can generate nicer-looking code.
243             this->writeExpression(*swizzle.base(), Precedence::kPostfix);
244             switch (swizzle.components().front()) {
245                 case SwizzleComponent::X: this->write(".x()"); break;
246                 case SwizzleComponent::Y: this->write(".y()"); break;
247                 case SwizzleComponent::Z: this->write(".z()"); break;
248                 case SwizzleComponent::W: this->write(".w()"); break;
249             }
250         } else {
251             this->write("Swizzle(");
252             this->writeExpression(*swizzle.base(), Precedence::kSequence);
253             for (int8_t component : swizzle.components()) {
254                 switch (component) {
255                     case SwizzleComponent::X: this->write(", X"); break;
256                     case SwizzleComponent::Y: this->write(", Y"); break;
257                     case SwizzleComponent::Z: this->write(", Z"); break;
258                     case SwizzleComponent::W: this->write(", W"); break;
259                 }
260             }
261             this->write(")");
262         }
263     }
264 }
265 
writeTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)266 void DSLCPPCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
267                                                  Precedence parentPrecedence) {
268     if (fCPPMode) {
269         INHERITED::writeTernaryExpression(t, parentPrecedence);
270     } else {
271         this->write("Select(");
272         this->writeExpression(*t.test(), Precedence::kSequence);
273         this->write(", /*If True:*/ ");
274         this->writeExpression(*t.ifTrue(), Precedence::kSequence);
275         this->write(", /*If False:*/ ");
276         this->writeExpression(*t.ifFalse(), Precedence::kSequence);
277         this->write(")");
278     }
279 }
280 
writeVariableReference(const VariableReference & ref)281 void DSLCPPCodeGenerator::writeVariableReference(const VariableReference& ref) {
282     const Variable& var = *ref.variable();
283     if (fCPPMode) {
284         this->write(var.name());
285         return;
286     }
287 
288     switch (var.modifiers().fLayout.fBuiltin) {
289         case SK_MAIN_COORDS_BUILTIN:
290             this->write("sk_SampleCoord()");
291             fAccessSampleCoordsDirectly = true;
292             return;
293         case SK_FRAGCOORD_BUILTIN:
294             this->write("sk_FragCoord()");
295             return;
296         default:
297             break;
298     }
299 
300     this->write(this->getVariableCppName(var));
301 }
302 
getChildFPIndex(const Variable & var) const303 int DSLCPPCodeGenerator::getChildFPIndex(const Variable& var) const {
304     int index = 0;
305     for (const ProgramElement* p : fProgram.elements()) {
306         if (p->is<GlobalVarDeclaration>()) {
307             const VarDeclaration& decl =
308                     p->as<GlobalVarDeclaration>().declaration()->as<VarDeclaration>();
309             if (&decl.var() == &var) {
310                 return index;
311             } else if (decl.var().type().isFragmentProcessor()) {
312                 ++index;
313             }
314         }
315     }
316     SkDEBUGFAILF("child fragment processor for '%s' not found", var.description().c_str());
317     return 0;
318 }
319 
writeFunctionCall(const FunctionCall & c)320 void DSLCPPCodeGenerator::writeFunctionCall(const FunctionCall& c) {
321     const FunctionDeclaration& function = c.function();
322     if (function.isBuiltin() && function.name() == "sample") {
323         // The first argument to sample() must be a fragment processor. (Old-school samplers are no
324         // longer supported in FP files.)
325         const ExpressionArray& arguments = c.arguments();
326         SkASSERT(arguments.size() >= 1 && arguments.size() <= 3);
327         const Expression& fpArgument = *arguments.front();
328         SkASSERT(fpArgument.type().isFragmentProcessor());
329 
330         // We can't look up the child FP index unless the fragment-processor is a real variable.
331         if (!fpArgument.is<VariableReference>()) {
332             fErrors.error(fpArgument.fOffset,
333                           "sample()'s fragmentProcessor argument must be a variable reference");
334             return;
335         }
336 
337         // Pass the index of the fragment processor, and all the other arguments as-is.
338         int childFPIndex = this->getChildFPIndex(*fpArgument.as<VariableReference>().variable());
339         this->writef("SampleChild(%d", childFPIndex);
340 
341         for (int index = 1; index < arguments.count(); ++index) {
342             this->write(", ");
343             this->writeExpression(*arguments[index], Precedence::kSequence);
344         }
345         this->write(")");
346         return;
347     }
348 
349     if (function.isBuiltin()) {
350         if (fCPPMode) {
351             this->write(function.name());
352         } else {
353             static const auto* kBuiltinNames = new std::unordered_map<IntrinsicKind, const char*>{
354                     {k_abs_IntrinsicKind, "Abs"},
355                     {k_all_IntrinsicKind, "All"},
356                     {k_any_IntrinsicKind, "Any"},
357                     {k_atan_IntrinsicKind, "Atan"},
358                     {k_ceil_IntrinsicKind, "Ceil"},
359                     {k_clamp_IntrinsicKind, "Clamp"},
360                     {k_cos_IntrinsicKind, "Cos"},
361                     {k_cross_IntrinsicKind, "Cross"},
362                     {k_degrees_IntrinsicKind, "Degrees"},
363                     {k_distance_IntrinsicKind, "Distance"},
364                     {k_dot_IntrinsicKind, "Dot"},
365                     {k_equal_IntrinsicKind, "Equal"},
366                     {k_exp_IntrinsicKind, "Exp"},
367                     {k_exp2_IntrinsicKind, "Exp2"},
368                     {k_faceforward_IntrinsicKind, "Faceforward"},
369                     {k_floor_IntrinsicKind, "Floor"},
370                     {k_fract_IntrinsicKind, "SkSL::dsl::Fract"},
371                     {k_greaterThan_IntrinsicKind, "GreaterThan"},
372                     {k_greaterThanEqual_IntrinsicKind, "GreaterThanEqual"},
373                     {k_inversesqrt_IntrinsicKind, "Inversesqrt"},
374                     {k_inverse_IntrinsicKind, "Inverse"},
375                     {k_length_IntrinsicKind, "Length"},
376                     {k_lessThan_IntrinsicKind, "LessThan"},
377                     {k_lessThanEqual_IntrinsicKind, "LessThanEqual"},
378                     {k_log_IntrinsicKind, "Log"},
379                     {k_max_IntrinsicKind, "Max"},
380                     {k_min_IntrinsicKind, "Min"},
381                     {k_mix_IntrinsicKind, "Mix"},
382                     {k_mod_IntrinsicKind, "Mod"},
383                     {k_normalize_IntrinsicKind, "Normalize"},
384                     {k_not_IntrinsicKind, "Not"},
385                     {k_pow_IntrinsicKind, "Pow"},
386                     {k_radians_IntrinsicKind, "Radians"},
387                     {k_reflect_IntrinsicKind, "Reflect"},
388                     {k_refract_IntrinsicKind, "Refract"},
389                     {k_saturate_IntrinsicKind, "Saturate"},
390                     {k_sign_IntrinsicKind, "Sign"},
391                     {k_sin_IntrinsicKind, "Sin"},
392                     {k_smoothstep_IntrinsicKind, "Smoothstep"},
393                     {k_sqrt_IntrinsicKind, "Sqrt"},
394                     {k_step_IntrinsicKind, "Step"},
395                     {k_tan_IntrinsicKind, "Tan"},
396                     {k_unpremul_IntrinsicKind, "Unpremul"}};
397 
398             auto iter = kBuiltinNames->find(function.intrinsicKind());
399             if (iter == kBuiltinNames->end()) {
400                 fErrors.error(c.fOffset,
401                               "unrecognized built-in function '" + function.name() + "'");
402                 return;
403             }
404 
405             this->write(iter->second);
406         }
407 
408         this->write("(");
409         const char* separator = "";
410         for (const std::unique_ptr<Expression>& argument : c.arguments()) {
411             this->write(separator);
412             separator = ", ";
413             this->writeExpression(*argument, Precedence::kSequence);
414         }
415         this->write(")");
416         return;
417     }
418 
419     SK_ABORT("not yet implemented: helper function support for DSL");
420 }
421 
prepareHelperFunction(const FunctionDeclaration & decl)422 void DSLCPPCodeGenerator::prepareHelperFunction(const FunctionDeclaration& decl) {
423     if (decl.isBuiltin() || decl.isMain()) {
424         return;
425     }
426 
427     SK_ABORT("not yet implemented: helper functions in DSL");
428 }
429 
prototypeHelperFunction(const FunctionDeclaration & decl)430 void DSLCPPCodeGenerator::prototypeHelperFunction(const FunctionDeclaration& decl) {
431     SK_ABORT("not yet implemented: function prototypes in DSL");
432 }
433 
writeFunction(const FunctionDefinition & f)434 void DSLCPPCodeGenerator::writeFunction(const FunctionDefinition& f) {
435     const FunctionDeclaration& decl = f.declaration();
436     if (decl.isBuiltin()) {
437         return;
438     }
439     fFunctionHeader.clear();
440     OutputStream* oldOut = fOut;
441     StringStream buffer;
442     fOut = &buffer;
443     if (decl.isMain()) {
444         fInMain = true;
445         this->writeFunctionBody(f.body()->as<Block>());
446         fInMain = false;
447 
448         fOut = oldOut;
449         this->write(fFunctionHeader);
450         this->write(buffer.str());
451     } else {
452         SK_ABORT("not yet implemented: helper functions in DSL");
453     }
454 }
455 
writeFunctionBody(const Block & b)456 void DSLCPPCodeGenerator::writeFunctionBody(const Block& b) {
457     // At the top level of a function, DSL statements need to be emitted as individual C++
458     // statements instead of being comma-separated expressions in a Block. (You could technically
459     // emit the entire function as one big comma-separated Block, and it would work, but you'd wrap
460     // everything with an extra unnecessary scope.)
461     for (const std::unique_ptr<Statement>& stmt : b.children()) {
462         if (!stmt->isEmpty()) {
463             this->writeStatement(*stmt);
464             this->write(";\n");
465         }
466     }
467 }
468 
writeBlock(const Block & b)469 void DSLCPPCodeGenerator::writeBlock(const Block& b) {
470     if (b.isEmpty()) {
471         // Write empty Blocks as an empty Statement, whether or not it was scoped.
472         // This is the simplest way to emit a valid Statement for an unscoped empty Block.
473         this->write("Statement()");
474         return;
475     }
476 
477     if (b.isScope()) {
478         this->write("Block(");
479     }
480 
481     const char* separator = "";
482     for (const std::unique_ptr<Statement>& stmt : b.children()) {
483         if (!stmt->isEmpty()) {
484             this->write(separator);
485             separator = ", ";
486 
487             this->writeStatement(*stmt);
488         }
489     }
490 
491     if (b.isScope()) {
492         this->write(")");
493     }
494 }
495 
writeReturnStatement(const ReturnStatement & r)496 void DSLCPPCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
497     this->write("Return(");
498     if (r.expression()) {
499         this->writeExpression(*r.expression(), Precedence::kTopLevel);
500     }
501     this->write(")");
502 }
503 
writeIfStatement(const IfStatement & stmt)504 void DSLCPPCodeGenerator::writeIfStatement(const IfStatement& stmt) {
505     this->write(stmt.isStatic() ? "StaticIf(" : "If(");
506     this->writeExpression(*stmt.test(), Precedence::kTopLevel);
507     this->write(", /*Then:*/ ");
508     this->writeStatement(*stmt.ifTrue());
509     if (stmt.ifFalse()) {
510         this->write(", /*Else:*/ ");
511         this->writeStatement(*stmt.ifFalse());
512     }
513     this->write(")");
514 }
515 
variable_exists_with_name(const std::unordered_map<const Variable *,String> & varMap,const String & trialName)516 static bool variable_exists_with_name(const std::unordered_map<const Variable*, String>& varMap,
517                                       const String& trialName) {
518     for (const auto& [varPtr, varName] : varMap) {
519         if (varName == trialName) {
520             return true;
521         }
522     }
523     return false;
524 }
525 
getVariableCppName(const Variable & var)526 const char* DSLCPPCodeGenerator::getVariableCppName(const Variable& var) {
527     String& cppName = fVariableCppNames[&var];
528     if (cppName.empty()) {
529         // Append a prefix to the variable name. This serves two purposes:
530         // - disambiguates variables with the same name that live in different SkSL scopes
531         // - gives the DSLVar a distinct name, leaving the original name free to be given to a
532         //   C++ constant, for the case of layout-keys that need to work in C++ `when` expressions
533         // Probing for a unique name could be more efficient, but it really doesn't matter much;
534         // overlapping names are super rare, and we only compile DSLs in skslc at build time.
535         for (int prefix = 0;; ++prefix) {
536             String prefixedName = (prefix > 0 || !islower(var.name()[0]))
537                     ? String::printf("_%d_%.*s", prefix, (int)var.name().size(), var.name().data())
538                     : String::printf("_%.*s", (int)var.name().size(), var.name().data());
539 
540             if (!variable_exists_with_name(fVariableCppNames, prefixedName)) {
541                 cppName = std::move(prefixedName);
542                 break;
543             }
544         }
545     }
546 
547     return cppName.c_str();
548 }
549 
writeCppInitialValue(const Variable & var)550 void DSLCPPCodeGenerator::writeCppInitialValue(const Variable& var) {
551     // `formatRuntimeValue` generates a C++ format string (which we don't need, since
552     // we're not formatting anything at runtime) and a vector of the arguments within
553     // the variable (which we do need, to fill in the Var's initial value).
554     std::vector<String> argumentList;
555     (void) this->formatRuntimeValue(var.type(), var.modifiers().fLayout,
556                                     var.name(), &argumentList);
557 
558     this->write(this->getTypeName(var.type()));
559     this->write("(");
560     const char* separator = "";
561     for (const String& arg : argumentList) {
562         this->write(separator);
563         this->write(arg);
564         separator = ", ";
565     }
566     this->write(")");
567 }
568 
writeVarCtorExpression(const Variable & var)569 void DSLCPPCodeGenerator::writeVarCtorExpression(const Variable& var) {
570     this->write(this->getDSLModifiers(var.modifiers()));
571     this->write(", ");
572     this->write(this->getDSLType(var.type()));
573     this->write(", \"");
574     this->write(var.name());
575     this->write("\"");
576     if (var.initialValue()) {
577         this->write(", ");
578         if (is_private(var)) {
579             // The initial value was calculated in C++ (see writePrivateVarValues). This value can
580             // be baked into the DSL as a constant.
581             this->writeCppInitialValue(var);
582         } else {
583             // Write the variable's initial-value expression in DSL.
584             this->writeExpression(*var.initialValue(), Precedence::kTopLevel);
585         }
586     }
587 }
588 
writeVar(const Variable & var)589 void DSLCPPCodeGenerator::writeVar(const Variable& var) {
590     this->write("Var ");
591     this->write(this->getVariableCppName(var));
592     this->write("(");
593     this->writeVarCtorExpression(var);
594     this->write(");\n");
595 }
596 
writeVarDeclaration(const VarDeclaration & varDecl,bool global)597 void DSLCPPCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl, bool global) {
598     const Variable& var = varDecl.var();
599     if (!global) {
600         // We want to divert our output into fFunctionHeader, but fFunctionHeader is just a
601         // String, not a StringStream. So instead, we divert into a temporary stream and append
602         // that stream into fFunctionHeader afterwards.
603         StringStream stream;
604         AutoOutputStream divert(this, &stream);
605 
606         this->writeVar(var);
607 
608         fFunctionHeader += stream.str();
609     } else {
610         // For global variables, we can write the Var directly into the code stream.
611         this->writeVar(var);
612     }
613 
614     this->write("Declare(");
615     this->write(this->getVariableCppName(var));
616     this->write(")");
617 }
618 
writeForStatement(const ForStatement & f)619 void DSLCPPCodeGenerator::writeForStatement(const ForStatement& f) {
620     // Emit loops of the form 'for (; test;)' as 'while (test)', which is probably how they started.
621     if (!f.initializer() && f.test() && !f.next()) {
622         this->write("While(");
623         this->writeExpression(*f.test(), Precedence::kTopLevel);
624         this->write(", ");
625         this->writeStatement(*f.statement());
626         this->write(")");
627         return;
628     }
629 
630     this->write("For(");
631     if (f.initializer() && !f.initializer()->isEmpty()) {
632         this->writeStatement(*f.initializer());
633         this->write(", ");
634     } else {
635         this->write("Statement(), ");
636     }
637     if (f.test()) {
638         this->writeExpression(*f.test(), Precedence::kTopLevel);
639         this->write(", ");
640     } else {
641         this->write("Expression(), ");
642     }
643     if (f.next()) {
644         this->writeExpression(*f.next(), Precedence::kTopLevel);
645         this->write(", /*Body:*/ ");
646     } else {
647         this->write("Expression(), /*Body:*/ ");
648     }
649     this->writeStatement(*f.statement());
650     this->write(")");
651 }
652 
writeDoStatement(const DoStatement & d)653 void DSLCPPCodeGenerator::writeDoStatement(const DoStatement& d) {
654     this->write("Do(");
655     this->writeStatement(*d.statement());
656     this->write(", /*While:*/ ");
657     this->writeExpression(*d.test(), Precedence::kTopLevel);
658     this->write(")");
659 }
660 
writeSwitchStatement(const SwitchStatement & s)661 void DSLCPPCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
662     this->write(s.isStatic() ? "StaticSwitch(" : "Switch(");
663     this->writeExpression(*s.value(), Precedence::kTopLevel);
664     for (const std::unique_ptr<Statement>& stmt : s.cases()) {
665         const SwitchCase& c = stmt->as<SwitchCase>();
666         if (c.value()) {
667             this->write(",\n    Case(");
668             this->writeExpression(*c.value(), Precedence::kTopLevel);
669             if (!c.statement()->isEmpty()) {
670                 this->write(", ");
671                 this->writeStatement(*c.statement());
672             }
673         } else {
674             this->write(",\n    Default(");
675             if (!c.statement()->isEmpty()) {
676                 this->writeStatement(*c.statement());
677             }
678         }
679         this->write(")");
680     }
681     this->write(")");
682 }
683 
writeCastConstructor(const AnyConstructor & c,Precedence parentPrecedence)684 void DSLCPPCodeGenerator::writeCastConstructor(const AnyConstructor& c,
685                                                Precedence parentPrecedence) {
686     return this->writeAnyConstructor(c, parentPrecedence);
687 }
688 
writeAnyConstructor(const AnyConstructor & c,Precedence parentPrecedence)689 void DSLCPPCodeGenerator::writeAnyConstructor(const AnyConstructor& c,
690                                               Precedence parentPrecedence) {
691     if (c.type().isArray() || c.type().isStruct()) {
692         SK_ABORT("not yet supported: array/struct construction in DSL");
693     }
694 
695     INHERITED::writeAnyConstructor(c, parentPrecedence);
696 }
697 
getTypeName(const Type & type)698 String DSLCPPCodeGenerator::getTypeName(const Type& type) {
699     if (fCPPMode) {
700         return type.name();
701     }
702     switch (type.typeKind()) {
703         case Type::TypeKind::kScalar:
704             return get_scalar_type_name(fContext, type);
705 
706         case Type::TypeKind::kVector: {
707             const Type& component = type.componentType();
708             const char* baseName = get_scalar_type_name(fContext, component);
709             return String::printf("%s%d", baseName, type.columns());
710         }
711         case Type::TypeKind::kMatrix: {
712             const Type& component = type.componentType();
713             const char* baseName = get_scalar_type_name(fContext, component);
714             return String::printf("%s%dx%d", baseName, type.columns(), type.rows());
715         }
716         case Type::TypeKind::kEnum:
717             return "Int";
718 
719         default:
720             SK_ABORT("not yet supported: getTypeName of %s", type.displayName().c_str());
721             return type.name();
722     }
723 }
724 
getDSLType(const Type & type)725 String DSLCPPCodeGenerator::getDSLType(const Type& type) {
726     switch (type.typeKind()) {
727         case Type::TypeKind::kScalar:
728             return String::printf("DSLType(k%s_Type)", get_scalar_type_name(fContext, type));
729 
730         case Type::TypeKind::kVector: {
731             const Type& component = type.componentType();
732             const char* baseName = get_scalar_type_name(fContext, component);
733             return String::printf("DSLType(k%s%d_Type)", baseName, type.columns());
734         }
735         case Type::TypeKind::kMatrix: {
736             const Type& component = type.componentType();
737             const char* baseName = get_scalar_type_name(fContext, component);
738             return String::printf("DSLType(k%s%dx%d_Type)", baseName, type.columns(), type.rows());
739         }
740         case Type::TypeKind::kEnum:
741             return "DSLType(kInt_Type)";
742 
743         case Type::TypeKind::kArray: {
744             const Type& component = type.componentType();
745             SkASSERT(type.columns() != Type::kUnsizedArray);
746             return String::printf("Array(%s, %d)", this->getDSLType(component).c_str(),
747                                                    type.columns());
748         }
749         default:
750             SK_ABORT("not yet supported: getDSLType of %s", type.displayName().c_str());
751             return type.name();
752     }
753 }
754 
getDSLModifiers(const Modifiers & modifiers)755 String DSLCPPCodeGenerator::getDSLModifiers(const Modifiers& modifiers) {
756     String text;
757 
758     // Uniform variables can have `in uniform` flags in FP file; that's not how they are
759     // represented in DSL, however. Transform `in uniform` modifiers to just `uniform`.
760     if (modifiers.fFlags & Modifiers::kUniform_Flag) {
761         text += "kUniform_Modifier | ";
762     } else if (modifiers.fFlags & Modifiers::kIn_Flag) {
763         text += "kIn_Modifier | ";
764     }
765     if ((modifiers.fFlags & Modifiers::kConst_Flag) ||
766         (modifiers.fLayout.fFlags & Layout::kKey_Flag)) {
767         text += "kConst_Modifier | ";
768     }
769     if (modifiers.fFlags & Modifiers::kOut_Flag) {
770         text += "kOut_Modifier | ";
771     }
772     if (modifiers.fFlags & Modifiers::kFlat_Flag) {
773         text += "kFlat_Modifier | ";
774     }
775     if (modifiers.fFlags & Modifiers::kNoPerspective_Flag) {
776         text += "kNoPerspective_Modifier | ";
777     }
778 
779     if (text.empty()) {
780         return "kNo_Modifier";
781     }
782 
783     // Eliminate trailing ` | `.
784     text.pop_back();
785     text.pop_back();
786     text.pop_back();
787     return text;
788 }
789 
getDefaultDSLValue(const Variable & var)790 String DSLCPPCodeGenerator::getDefaultDSLValue(const Variable& var) {
791     // TODO: default_value returns half4(NaN) for colors, but DSL aborts if passed a literal NaN.
792     // Theoretically this really shouldn't matter.
793     switch (var.type().typeKind()) {
794         case Type::TypeKind::kScalar:
795         case Type::TypeKind::kVector: return this->getTypeName(var.type()) + "(0)";
796         case Type::TypeKind::kMatrix: return this->getTypeName(var.type()) + "(1)";
797         default: SK_ABORT("unsupported type: %s", var.type().description().c_str());
798     }
799 }
800 
writeStatement(const Statement & s)801 void DSLCPPCodeGenerator::writeStatement(const Statement& s) {
802     switch (s.kind()) {
803         case Statement::Kind::kBlock:
804             this->writeBlock(s.as<Block>());
805             break;
806         case Statement::Kind::kExpression:
807             this->writeExpression(*s.as<ExpressionStatement>().expression(), Precedence::kTopLevel);
808             break;
809         case Statement::Kind::kReturn:
810             this->writeReturnStatement(s.as<ReturnStatement>());
811             break;
812         case Statement::Kind::kVarDeclaration:
813             this->writeVarDeclaration(s.as<VarDeclaration>(), /*global=*/false);
814             break;
815         case Statement::Kind::kIf:
816             this->writeIfStatement(s.as<IfStatement>());
817             break;
818         case Statement::Kind::kFor:
819             this->writeForStatement(s.as<ForStatement>());
820             break;
821         case Statement::Kind::kDo:
822             this->writeDoStatement(s.as<DoStatement>());
823             break;
824         case Statement::Kind::kSwitch:
825             this->writeSwitchStatement(s.as<SwitchStatement>());
826             break;
827         case Statement::Kind::kBreak:
828             this->write("Break()");
829             break;
830         case Statement::Kind::kContinue:
831             this->write("Continue()");
832             break;
833         case Statement::Kind::kDiscard:
834             this->write("Discard()");
835             break;
836         case Statement::Kind::kInlineMarker:
837         case Statement::Kind::kNop:
838             this->write("Statement()");
839             break;
840         default:
841             SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
842             break;
843     }
844 }
845 
writeFloatLiteral(const FloatLiteral & f)846 void DSLCPPCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
847     this->write(to_string(f.value()));
848     this->write("f");
849 }
850 
writeSetting(const Setting & s)851 void DSLCPPCodeGenerator::writeSetting(const Setting& s) {
852     this->writef("sk_Caps.%s", s.name().c_str());
853 }
854 
writeSection(const char * name,const char * prefix)855 bool DSLCPPCodeGenerator::writeSection(const char* name, const char* prefix) {
856     const Section* s = fSectionAndParameterHelper.getSection(name);
857     if (s) {
858         this->writef("%s%s", prefix, s->text().c_str());
859         return true;
860     }
861     return false;
862 }
863 
writeProgramElement(const ProgramElement & p)864 void DSLCPPCodeGenerator::writeProgramElement(const ProgramElement& p) {
865     switch (p.kind()) {
866         case ProgramElement::Kind::kSection:
867             return;
868 
869         case ProgramElement::Kind::kGlobalVar: {
870             const GlobalVarDeclaration& decl = p.as<GlobalVarDeclaration>();
871             const Variable& var = decl.declaration()->as<VarDeclaration>().var();
872             // Don't write builtin uniforms.
873             if (var.modifiers().fFlags & (Modifiers::kIn_Flag | Modifiers::kUniform_Flag) ||
874                 -1 != var.modifiers().fLayout.fBuiltin) {
875                 return;
876             }
877             this->writeVarDeclaration(decl.declaration()->as<VarDeclaration>(), /*global=*/true);
878             this->write(";\n");
879             return;
880         }
881         case ProgramElement::Kind::kFunctionPrototype:
882             SK_ABORT("not yet implemented: function prototypes in DSL");
883             return;
884 
885         default:
886             break;
887     }
888     INHERITED::writeProgramElement(p);
889 }
890 
addUniform(const Variable & var)891 void DSLCPPCodeGenerator::addUniform(const Variable& var) {
892     if (!needs_uniform_var(var)) {
893         return;
894     }
895 
896     const char* varCppName = this->getVariableCppName(var);
897     if (var.modifiers().fLayout.fWhen.fLength) {
898         // In cases where the `when` clause is true, we set up the Var normally.
899         this->writef(
900                 "Var %s;\n"
901                 "if (%.*s) {\n"
902                 "    Var(",
903                 varCppName,
904                 (int)var.modifiers().fLayout.fWhen.size(), var.modifiers().fLayout.fWhen.data());
905         this->writeVarCtorExpression(var);
906         this->writef(").swap(%s);\n    ", varCppName);
907     } else {
908         this->writeVar(var);
909     }
910 
911     this->writef("%.*sVar = VarUniformHandle(%s);\n",
912                  (int)var.name().size(), var.name().data(), this->getVariableCppName(var));
913 
914     if (var.modifiers().fLayout.fWhen.fLength) {
915         this->writef("    DeclareGlobal(%s);\n", varCppName);
916         // In cases where the `when` is false, we declare the Var as a const with a default value.
917         this->writef("} else {\n"
918                      "    Var(kConst_Modifier, %s, \"%.*s\", %s).swap(%s);\n"
919                      "    Declare(%s);\n"
920                      "}\n",
921                      this->getDSLType(var.type()).c_str(),
922                      (int)var.name().size(), var.name().data(),
923                      this->getDefaultDSLValue(var).c_str(),
924                      varCppName, varCppName);
925     } else {
926         this->writef("DeclareGlobal(%s);\n", varCppName);
927     }
928 }
929 
writeInputVars()930 void DSLCPPCodeGenerator::writeInputVars() {
931 }
932 
writePrivateVars()933 void DSLCPPCodeGenerator::writePrivateVars() {
934     for (const ProgramElement* p : fProgram.elements()) {
935         if (p->is<GlobalVarDeclaration>()) {
936             const GlobalVarDeclaration& global = p->as<GlobalVarDeclaration>();
937             const Variable& var = global.declaration()->as<VarDeclaration>().var();
938             if (is_private(var)) {
939                 if (var.type().isFragmentProcessor()) {
940                     fErrors.error(global.fOffset,
941                                   "fragmentProcessor variables must be declared 'in'");
942                     return;
943                 }
944                 this->writef("%s %.*s = %s;\n",
945                              HCodeGenerator::FieldType(fContext, var.type(),
946                                                        var.modifiers().fLayout).c_str(),
947                              (int)var.name().size(), var.name().data(),
948                              default_value(var).c_str());
949             } else if (var.modifiers().fLayout.fFlags & Layout::kTracked_Flag) {
950                 // An auto-tracked uniform in variable, so add a field to hold onto the prior
951                 // state. Note that tracked variables must be uniform in's and that is validated
952                 // before writePrivateVars() is called.
953                 const UniformCTypeMapper* mapper = UniformCTypeMapper::Get(fContext, var);
954                 SkASSERT(mapper);
955 
956                 String name = HCodeGenerator::FieldName(String(var.name()).c_str());
957                 // The member statement is different if the mapper reports a default value
958                 if (mapper->defaultValue().size() > 0) {
959                     this->writef("%s %sPrev = %s;\n",
960                                     Layout::CTypeToStr(mapper->ctype()), name.c_str(),
961                                     mapper->defaultValue().c_str());
962                 } else {
963                     this->writef("%s %sPrev;\n",
964                                     Layout::CTypeToStr(mapper->ctype()), name.c_str());
965                 }
966             }
967         }
968     }
969 }
970 
writePrivateVarValues()971 void DSLCPPCodeGenerator::writePrivateVarValues() {
972     for (const ProgramElement* p : fProgram.elements()) {
973         if (p->is<GlobalVarDeclaration>()) {
974             const GlobalVarDeclaration& global = p->as<GlobalVarDeclaration>();
975             const VarDeclaration& decl = global.declaration()->as<VarDeclaration>();
976             if (is_private(decl.var()) && decl.value()) {
977                 // This function writes class member variables.
978                 // We need to emit plain C++ names and types, not DSL.
979                 fCPPMode = true;
980                 this->write(decl.var().name());
981                 this->write(" = ");
982                 this->writeExpression(*decl.value(), Precedence::kAssignment);
983                 this->write(";\n");
984                 fCPPMode = false;
985             }
986         }
987     }
988 }
989 
is_accessible(const Variable & var)990 static bool is_accessible(const Variable& var) {
991     const Type& type = var.type();
992     return !type.isFragmentProcessor() &&
993            Type::TypeKind::kOther != type.typeKind();
994 }
995 
// Emits the `emitCode(EmitArgs&)` override of the generated GLSL processor class.
//
// The generated function does the following, in order:
//  1. binds `_outer` to the owning FP and enters DSL mode with StartFragmentProcessor();
//  2. declares each accessible, non-uniform parameter as a const DSL Var initialized from
//     `_outer`;
//  3. assigns private globals (writePrivateVarValues()) and declares uniforms (addUniform());
//  4. splices in the @emitCode section, if present;
//  5. prepares defined helper functions and emits prototypes for the ones that had prototypes;
//  6. emits the program body via INHERITED::generateCode(), then EndFragmentProcessor().
//
// @param uniforms  the `uniform` globals (generateCode() collects these); each goes through
//                  addUniform().
// @return the result of INHERITED::generateCode().
bool DSLCPPCodeGenerator::writeEmitCode(std::vector<const Variable*>& uniforms) {
    this->writef("    void emitCode(EmitArgs& args) override {\n"
                 "        [[maybe_unused]] const %s& _outer = args.fFp.cast<%s>();\n"
                 "\n"
                 "        using namespace SkSL::dsl;\n"
                 "        StartFragmentProcessor(this, &args);\n",
                 fFullName.c_str(), fFullName.c_str());
    for (const ProgramElement* p : fProgram.elements()) {
        if (p->is<GlobalVarDeclaration>()) {
            const GlobalVarDeclaration& global = p->as<GlobalVarDeclaration>();
            const VarDeclaration& decl = global.declaration()->as<VarDeclaration>();
            const Variable& var = decl.var();
            // Uniforms are declared later via addUniform().
            if (var.modifiers().fFlags & Modifiers::kUniform_Flag) {
                continue;
            }
            if (SectionAndParameterHelper::IsParameter(var) && is_accessible(var)) {
                const char* varCppName = this->getVariableCppName(var);
                // Expose the raw C++ field under its SkSL name, then wrap its value in a
                // const DSL Var with the same SkSL name.
                this->writef("[[maybe_unused]] const auto& %.*s = _outer.%.*s;\n"
                             "Var %s(kConst_Modifier, %s, \"%.*s\", ",
                             (int)var.name().size(), var.name().data(),
                             (int)var.name().size(), var.name().data(),
                             varCppName, this->getDSLType(var.type()).c_str(),
                             (int)var.name().size(), var.name().data());
                this->writeCppInitialValue(var);
                this->writef(");\n"
                             "Declare(%s);\n", varCppName);
            }
        }
    }

    this->writePrivateVarValues();
    for (const Variable* u : uniforms) {
        this->addUniform(*u);
    }
    // User-supplied @emitCode text goes after all generated declarations.
    this->writeSection(kEmitCodeSection);

    // Generate mangled names and argument lists for helper functions.
    std::unordered_set<const FunctionDeclaration*> definedHelpers;
    for (const ProgramElement* p : fProgram.elements()) {
        if (p->is<FunctionDefinition>()) {
            const FunctionDeclaration* decl = &p->as<FunctionDefinition>().declaration();
            definedHelpers.insert(decl);
            this->prepareHelperFunction(*decl);
        }
    }

    // Emit prototypes for defined helper functions that originally had prototypes in the FP file.
    // (If a function was prototyped but never defined, we skip it, since it wasn't prepared above.)
    for (const ProgramElement* p : fProgram.elements()) {
        if (p->is<FunctionPrototype>()) {
            const FunctionDeclaration* decl = &p->as<FunctionPrototype>().declaration();
            if (definedHelpers.find(decl) != definedHelpers.end()) {
                this->prototypeHelperFunction(*decl);
            }
        }
    }

    bool result = INHERITED::generateCode();

    this->write("        EndFragmentProcessor();\n"
                "    }\n");
    return result;
}
1059 
// Emits the `onSetData()` override, which uploads `uniform in` values from the owning FP.
//
// For each `uniform in` variable, a UniformCTypeMapper builds the upload statement:
// - uniforms with a `when` clause are guarded by `<handle>.isValid()`;
// - tracked uniforms upload only when the mapper's dirty expression is true, and then save the
//   new value into the `<Field>Prev` member;
// - tracked uniforms, and uniforms whose value the mapper cannot inline, are first copied into
//   a `<name>Value` local.
// If a @setData section exists, each uniform handle and non-FP parameter is first aliased under
// its SkSL name, and then the section text is appended.
//
// @param uniforms  the `uniform` globals (generateCode() collects these).
void DSLCPPCodeGenerator::writeSetData(std::vector<const Variable*>& uniforms) {
    const char* fullName = fFullName.c_str();
    const Section* section = fSectionAndParameterHelper.getSection(kSetDataSection);
    // The @setData section argument names the data-manager parameter; "pdman" is the default.
    const char* pdman = section ? section->argument().c_str() : "pdman";
    this->writef("    void onSetData(const GrGLSLProgramDataManager& %s, "
                                    "const GrFragmentProcessor& _proc) override {\n",
                 pdman);
    // Tracks whether `_outer` has been emitted yet, so it is declared at most once.
    bool wroteProcessor = false;
    for (const Variable* u : uniforms) {
        if (is_uniform_in(*u)) {
            if (!wroteProcessor) {
                this->writef("        const %s& _outer = _proc.cast<%s>();\n", fullName, fullName);
                wroteProcessor = true;
                this->writef("        {\n");
            }

            const UniformCTypeMapper* mapper = UniformCTypeMapper::Get(fContext, *u);
            SkASSERT(mapper);

            String nameString(u->name());
            const char* name = nameString.c_str();

            // Switches for setData behavior in the generated code
            bool conditionalUniform = u->modifiers().fLayout.fWhen != "";
            bool isTracked = u->modifiers().fLayout.fFlags & Layout::kTracked_Flag;
            bool needsValueDeclaration = isTracked || !mapper->canInlineUniformValue();

            // NOTE(review): the handle is referenced here as FieldName(name) + "Var", but the
            // @setData aliases below (and addUniform) use `<name>Var` -- confirm which member name
            // the generated class actually declares.
            String uniformName = HCodeGenerator::FieldName(name) + "Var";

            String indent = "        "; // 8 by default, 12 when nested for conditional uniforms
            if (conditionalUniform) {
                // Add a pre-check to make sure the uniform was emitted
                // before trying to send any data to the GPU
                this->writef("        if (%s.isValid()) {\n", uniformName.c_str());
                indent += "    ";
            }

            String valueVar = "";
            if (needsValueDeclaration) {
                valueVar.appendf("%sValue", name);
                // Use AccessType since that will match the return type of _outer's public API.
                String valueType = HCodeGenerator::AccessType(fContext, u->type(),
                                                              u->modifiers().fLayout);
                this->writef("%s%s %s = _outer.%s;\n",
                             indent.c_str(), valueType.c_str(), valueVar.c_str(), name);
            } else {
                // Not tracked and the mapper only needs to use the value once
                // so send it a safe expression instead of the variable name
                valueVar.appendf("(_outer.%s)", name);
            }

            if (isTracked) {
                // Upload only if the value changed since the last call; remember the new value.
                String prevVar = HCodeGenerator::FieldName(name) + "Prev";
                this->writef("%sif (%s) {\n"
                             "%s    %s;\n"
                             "%s    %s;\n"
                             "%s}\n", indent.c_str(),
                        mapper->dirtyExpression(valueVar, prevVar).c_str(), indent.c_str(),
                        mapper->saveState(valueVar, prevVar).c_str(), indent.c_str(),
                        mapper->setUniform(pdman, uniformName, valueVar).c_str(), indent.c_str());
            } else {
                this->writef("%s%s;\n", indent.c_str(),
                        mapper->setUniform(pdman, uniformName, valueVar).c_str());
            }

            if (conditionalUniform) {
                // Close the earlier precheck block
                this->writef("        }\n");
            }
        }
    }
    if (wroteProcessor) {
        // Close the scope opened after `_outer`; `_outer` itself stays visible below.
        this->writef("        }\n");
    }
    if (section) {
        // Make uniform handles and parameter values available to the @setData text by name.
        for (const ProgramElement* p : fProgram.elements()) {
            if (p->is<GlobalVarDeclaration>()) {
                const GlobalVarDeclaration& global = p->as<GlobalVarDeclaration>();
                const VarDeclaration& decl = global.declaration()->as<VarDeclaration>();
                const Variable& variable = decl.var();

                if (needs_uniform_var(variable)) {
                    this->writef("        [[maybe_unused]] UniformHandle& %.*s = %.*sVar;\n",
                                 (int)variable.name().size(), variable.name().data(),
                                 (int)variable.name().size(), variable.name().data());
                } else if (SectionAndParameterHelper::IsParameter(variable) &&
                           !variable.type().isFragmentProcessor()) {
                    if (!wroteProcessor) {
                        this->writef("        const %s& _outer = _proc.cast<%s>();\n",
                                     fullName, fullName);
                        wroteProcessor = true;
                    }

                    // Always true here: the enclosing condition already excludes FPs.
                    if (!variable.type().isFragmentProcessor()) {
                        this->writef("        [[maybe_unused]] const auto& %.*s = _outer.%.*s;\n",
                                     (int)variable.name().size(), variable.name().data(),
                                     (int)variable.name().size(), variable.name().data());
                    }
                }
            }
        }
        this->writeSection(kSetDataSection);
    }
    this->write("    }\n");
}
1165 
writeClone()1166 void DSLCPPCodeGenerator::writeClone() {
1167     if (!this->writeSection(kCloneSection)) {
1168         if (fSectionAndParameterHelper.getSection(kFieldsSection)) {
1169             fErrors.error(/*offset=*/0, "fragment processors with custom @fields must also have a "
1170                                         "custom @clone");
1171         }
1172         this->writef("%s::%s(const %s& src)\n"
1173                      ": INHERITED(k%s_ClassID, src.optimizationFlags())", fFullName.c_str(),
1174                      fFullName.c_str(), fFullName.c_str(), fFullName.c_str());
1175         for (const Variable* param : fSectionAndParameterHelper.getParameters()) {
1176             String fieldName = HCodeGenerator::FieldName(String(param->name()).c_str());
1177             if (!param->type().isFragmentProcessor()) {
1178                 this->writef("\n, %s(src.%s)",
1179                              fieldName.c_str(),
1180                              fieldName.c_str());
1181             }
1182         }
1183         this->writef(" {\n");
1184         this->writef("        this->cloneAndRegisterAllChildProcessors(src);\n");
1185         if (fAccessSampleCoordsDirectly) {
1186             this->writef("    this->setUsesSampleCoordsDirectly();\n");
1187         }
1188         this->write("}\n");
1189         this->writef("std::unique_ptr<GrFragmentProcessor> %s::clone() const {\n",
1190                      fFullName.c_str());
1191         this->writef("    return std::make_unique<%s>(*this);\n",
1192                      fFullName.c_str());
1193         this->write("}\n");
1194     }
1195 }
1196 
writeDumpInfo()1197 void DSLCPPCodeGenerator::writeDumpInfo() {
1198     this->writef("#if GR_TEST_UTILS\n"
1199                  "SkString %s::onDumpInfo() const {\n", fFullName.c_str());
1200 
1201     if (!this->writeSection(kDumpInfoSection)) {
1202         if (fSectionAndParameterHelper.getSection(kFieldsSection)) {
1203             fErrors.error(/*offset=*/0, "fragment processors with custom @fields must also have a "
1204                                         "custom @dumpInfo");
1205         }
1206 
1207         String formatString;
1208         std::vector<String> argumentList;
1209 
1210         for (const Variable* param : fSectionAndParameterHelper.getParameters()) {
1211             // dumpInfo() doesn't need to log child FPs.
1212             if (param->type().isFragmentProcessor()) {
1213                 continue;
1214             }
1215 
1216             // Add this field onto the format string and argument list.
1217             String fieldName = HCodeGenerator::FieldName(String(param->name()).c_str());
1218             String runtimeValue = this->formatRuntimeValue(param->type(),
1219                                                            param->modifiers().fLayout,
1220                                                            param->name(),
1221                                                            &argumentList);
1222             formatString.appendf("%s%s=%s",
1223                                  formatString.empty() ? "" : ", ",
1224                                  fieldName.c_str(),
1225                                  runtimeValue.c_str());
1226         }
1227 
1228         if (!formatString.empty()) {
1229             // Emit the finished format string and associated arguments.
1230             this->writef("    return SkStringPrintf(\"(%s)\"", formatString.c_str());
1231 
1232             for (const String& argument : argumentList) {
1233                 this->writef(", %s", argument.c_str());
1234             }
1235 
1236             this->write(");");
1237         } else {
1238             // No fields to dump at all; just return an empty string.
1239             this->write("    return SkString();");
1240         }
1241     }
1242 
1243     this->write("\n"
1244                 "}\n"
1245                 "#endif\n");
1246 }
1247 
writeTest()1248 void DSLCPPCodeGenerator::writeTest() {
1249     const Section* test = fSectionAndParameterHelper.getSection(kTestCodeSection);
1250     if (test) {
1251         this->writef(
1252                 "GR_DEFINE_FRAGMENT_PROCESSOR_TEST(%s);\n"
1253                 "#if GR_TEST_UTILS\n"
1254                 "std::unique_ptr<GrFragmentProcessor> %s::TestCreate(GrProcessorTestData* %s) {\n",
1255                 fFullName.c_str(),
1256                 fFullName.c_str(),
1257                 test->argument().c_str());
1258         this->writeSection(kTestCodeSection);
1259         this->write("}\n"
1260                     "#endif\n");
1261     }
1262 }
1263 
// Returns the number of bits needed to represent `v` as an unsigned value. The result is at
// least 1, so 0 still occupies one bit, and at most 32.
// The loop is capped at 32 because, for v >= 0x80000000, the original condition would evaluate
// `1u << 32`. Shifting by the full width of the operand is undefined behavior.
static int bits_needed(uint32_t v) {
    int bits = 1;
    while (bits < 32 && v >= (1u << bits)) {
        bits++;
    }
    return bits;
}
1271 
// Emits onGetGLSLProcessorKey(), adding every `layout(key)` global to the processor key.
//
// For each key variable:
// - `layout(key)` on a uniform is reported as an error;
// - private key variables are first materialized as locals, from their initializer (written in
//   C++ mode) or from default_value();
// - a `when` clause wraps the key write in `if (...) { ... }`;
// - half4 is packed as four halfs into two 32-bit words, half/float are bit-cast to 32 bits,
//   bools use addBool, enums use addBits with bitsForEnum() bits, and other integers use
//   add32. Any other type aborts.
void DSLCPPCodeGenerator::writeGetKey() {
    // Returns how many key bits the enum declared with `type`'s name needs. The range always
    // includes 0 (min/max start at 0). Any negative enumerator forces 32 bits. Aborts if no
    // matching Enum element exists.
    auto bitsForEnum = [&](const Type& type) {
        for (const ProgramElement* e : fProgram.elements()) {
            if (!e->is<Enum>() || type.name() != e->as<Enum>().typeName()) {
                continue;
            }
            SKSL_INT minVal = 0, maxVal = 0;
            auto gatherEnumRange = [&](StringFragment, SKSL_INT value) {
                minVal = std::min(minVal, value);
                maxVal = std::max(maxVal, value);
            };
            e->as<Enum>().foreach(gatherEnumRange);
            if (minVal < 0) {
                // Found a negative value in the enum, just use 32 bits
                return 32;
            }
            SkASSERT(SkTFitsIn<uint32_t>(maxVal));
            return bits_needed(maxVal);
        }
        SK_ABORT("Didn't find declaring element for enum type!");
        return 32;
    };

    this->writef("void %s::onGetGLSLProcessorKey(const GrShaderCaps& caps, "
                                                "GrProcessorKeyBuilder* b) const {\n",
                 fFullName.c_str());
    for (const ProgramElement* p : fProgram.elements()) {
        if (p->is<GlobalVarDeclaration>()) {
            const GlobalVarDeclaration& global = p->as<GlobalVarDeclaration>();
            const VarDeclaration& decl = global.declaration()->as<VarDeclaration>();
            const Variable& var = decl.var();
            const Type& varType = var.type();
            String nameString(var.name());
            const char* name = nameString.c_str();
            if (var.modifiers().fLayout.fFlags & Layout::kKey_Flag) {
                if (var.modifiers().fFlags & Modifiers::kUniform_Flag) {
                    fErrors.error(var.fOffset, "layout(key) may not be specified on uniforms");
                }
                if (is_private(var)) {
                    // Declare a local holding the private variable's value. The initializer is
                    // emitted as plain C++ (fCPPMode) rather than DSL.
                    this->writef("%s %.*s = ",
                                 HCodeGenerator::FieldType(fContext, varType,
                                                           var.modifiers().fLayout).c_str(),
                                 (int)var.name().size(), var.name().data());
                    if (decl.value()) {
                        fCPPMode = true;
                        this->writeExpression(*decl.value(), Precedence::kAssignment);
                        fCPPMode = false;
                    } else {
                        this->write(default_value(var));
                    }
                    this->write(";\n");
                }
                if (var.modifiers().fLayout.fWhen.fLength) {
                    this->writef("if (%s) {\n", String(var.modifiers().fLayout.fWhen).c_str());
                }
                if (varType == *fContext.fTypes.fHalf4) {
                    // Pack the four channels as halfs into two 32-bit key words.
                    this->writef("    uint16_t red = SkFloatToHalf(%s.fR);\n",
                                 HCodeGenerator::FieldName(name).c_str());
                    this->writef("    uint16_t green = SkFloatToHalf(%s.fG);\n",
                                 HCodeGenerator::FieldName(name).c_str());
                    this->writef("    uint16_t blue = SkFloatToHalf(%s.fB);\n",
                                 HCodeGenerator::FieldName(name).c_str());
                    this->writef("    uint16_t alpha = SkFloatToHalf(%s.fA);\n",
                                 HCodeGenerator::FieldName(name).c_str());
                    this->writef("    b->add32(((uint32_t)red << 16) | green, \"%s.rg\");\n", name);
                    this->writef("    b->add32(((uint32_t)blue << 16) | alpha, \"%s.ba\");\n",
                                 name);
                } else if (varType == *fContext.fTypes.fHalf ||
                           varType == *fContext.fTypes.fFloat) {
                    this->writef("    b->add32(sk_bit_cast<uint32_t>(%s), \"%s\");\n",
                                 HCodeGenerator::FieldName(name).c_str(), name);
                } else if (varType.isBoolean()) {
                    this->writef("    b->addBool(%s, \"%s\");\n",
                                 HCodeGenerator::FieldName(name).c_str(), name);
                } else if (varType.isEnum()) {
                    this->writef("    b->addBits(%d, (uint32_t) %s, \"%s\");\n",
                                 bitsForEnum(varType), HCodeGenerator::FieldName(name).c_str(),
                                 name);
                } else if (varType.isInteger()) {
                    this->writef("    b->add32((uint32_t) %s, \"%s\");\n",
                                 HCodeGenerator::FieldName(name).c_str(), name);
                } else {
                    SK_ABORT("NOT YET IMPLEMENTED: automatic key handling for %s\n",
                             varType.displayName().c_str());
                }
                if (var.modifiers().fLayout.fWhen.fLength) {
                    this->write("}\n");
                }
            }
        }
    }
    this->write("}\n");
}
1365 
generateCode()1366 bool DSLCPPCodeGenerator::generateCode() {
1367     std::vector<const Variable*> uniforms;
1368     for (const ProgramElement* p : fProgram.elements()) {
1369         if (p->is<GlobalVarDeclaration>()) {
1370             const GlobalVarDeclaration& global = p->as<GlobalVarDeclaration>();
1371             const VarDeclaration& decl = global.declaration()->as<VarDeclaration>();
1372             SkASSERT(decl.var().type().typeKind() != Type::TypeKind::kSampler);
1373 
1374             if (decl.var().modifiers().fFlags & Modifiers::kUniform_Flag) {
1375                 uniforms.push_back(&decl.var());
1376             }
1377 
1378             if (is_uniform_in(decl.var())) {
1379                 // Validate the "uniform in" declarations to make sure they are fully supported,
1380                 // instead of generating surprising C++
1381                 const UniformCTypeMapper* mapper = UniformCTypeMapper::Get(fContext, decl.var());
1382                 if (mapper == nullptr) {
1383                     fErrors.error(decl.fOffset, String(decl.var().name())
1384                             + "'s type is not supported for use as a 'uniform in'");
1385                     return false;
1386                 }
1387             } else {
1388                 // If it's not a uniform_in, it's an error to be tracked
1389                 if (decl.var().modifiers().fLayout.fFlags & Layout::kTracked_Flag) {
1390                     fErrors.error(decl.fOffset, "Non-'in uniforms' cannot be tracked");
1391                     return false;
1392                 }
1393             }
1394         }
1395     }
1396     const char* baseName = fName.c_str();
1397     const char* fullName = fFullName.c_str();
1398     this->writef("%s\n", HCodeGenerator::GetHeader(fProgram, fErrors).c_str());
1399     this->writef(kFragmentProcessorHeader, fullName);
1400     this->write("/* TODO(skia:11854): DSLCPPCodeGenerator is currently a work in progress. */\n");
1401     this->writef("#include \"%s.h\"\n\n", fullName);
1402     this->writeSection(kCppSection);
1403     this->writef("#include \"src/core/SkUtils.h\"\n"
1404                  "#include \"src/gpu/GrTexture.h\"\n"
1405                  "#include \"src/gpu/glsl/GrGLSLFragmentProcessor.h\"\n"
1406                  "#include \"src/gpu/glsl/GrGLSLFragmentShaderBuilder.h\"\n"
1407                  "#include \"src/gpu/glsl/GrGLSLProgramBuilder.h\"\n"
1408                  "#include \"src/sksl/SkSLCPP.h\"\n"
1409                  "#include \"src/sksl/SkSLUtil.h\"\n"
1410                  "#include \"src/sksl/dsl/priv/DSLFPs.h\"\n"
1411                  "#include \"src/sksl/dsl/priv/DSLWriter.h\"\n"
1412                  "\n"
1413                  "class GrGLSL%s : public GrGLSLFragmentProcessor {\n"
1414                  "public:\n"
1415                  "    GrGLSL%s() {}\n",
1416                  baseName, baseName);
1417     bool result = this->writeEmitCode(uniforms);
1418     this->write("private:\n");
1419     this->writeSetData(uniforms);
1420     this->writePrivateVars();
1421     for (const Variable* u : uniforms) {
1422         if (needs_uniform_var(*u) && !(u->modifiers().fFlags & Modifiers::kIn_Flag)) {
1423             this->writef("    UniformHandle %.*sVar;\n", (int)u->name().size(), u->name().data());
1424         }
1425     }
1426     for (const Variable* param : fSectionAndParameterHelper.getParameters()) {
1427         if (needs_uniform_var(*param)) {
1428             this->writef("    UniformHandle %.*sVar;\n",
1429                          (int)param->name().size(), param->name().data());
1430         }
1431     }
1432     this->writef("};\n"
1433                  "std::unique_ptr<GrGLSLFragmentProcessor> %s::onMakeProgramImpl() const {\n"
1434                  "    return std::make_unique<GrGLSL%s>();\n"
1435                  "}\n",
1436                  fullName, baseName);
1437     this->writeGetKey();
1438     this->writef("bool %s::onIsEqual(const GrFragmentProcessor& other) const {\n"
1439                  "    const %s& that = other.cast<%s>();\n"
1440                  "    (void) that;\n",
1441                  fullName, fullName, fullName);
1442     for (const auto& param : fSectionAndParameterHelper.getParameters()) {
1443         if (param->type().isFragmentProcessor()) {
1444             continue;
1445         }
1446         String nameString(param->name());
1447         const char* name = nameString.c_str();
1448         this->writef("    if (%s != that.%s) return false;\n",
1449                      HCodeGenerator::FieldName(name).c_str(),
1450                      HCodeGenerator::FieldName(name).c_str());
1451     }
1452     this->write("    return true;\n"
1453                 "}\n");
1454     this->writeClone();
1455     this->writeDumpInfo();
1456     this->writeTest();
1457     this->writeSection(kCppEndSection);
1458 
1459     result &= 0 == fErrors.errorCount();
1460     return result;
1461 }
1462 
1463 }  // namespace SkSL
1464 
1465 #endif // defined(SKSL_STANDALONE) || GR_TEST_UTILS
1466