1 /*
2  * Copyright 2022 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/codegen/SkSLWGSLCodeGenerator.h"
9 
10 #include <memory>
11 #include <optional>
12 #include <string>
13 #include <vector>
14 
15 #include "include/core/SkSpan.h"
16 #include "include/core/SkTypes.h"
17 #include "include/private/SkBitmaskEnum.h"
18 #include "include/private/SkSLIRNode.h"
19 #include "include/private/SkSLLayout.h"
20 #include "include/private/SkSLModifiers.h"
21 #include "include/private/SkSLProgramElement.h"
22 #include "include/private/SkSLStatement.h"
23 #include "include/private/SkSLString.h"
24 #include "include/private/SkSLSymbol.h"
25 #include "include/private/base/SkTArray.h"
26 #include "include/private/base/SkTo.h"
27 #include "include/sksl/SkSLErrorReporter.h"
28 #include "include/sksl/SkSLOperator.h"
29 #include "include/sksl/SkSLPosition.h"
30 #include "src/sksl/SkSLAnalysis.h"
31 #include "src/sksl/SkSLBuiltinTypes.h"
32 #include "src/sksl/SkSLCompiler.h"
33 #include "src/sksl/SkSLContext.h"
34 #include "src/sksl/SkSLOutputStream.h"
35 #include "src/sksl/SkSLProgramSettings.h"
36 #include "src/sksl/SkSLStringStream.h"
37 #include "src/sksl/SkSLUtil.h"
38 #include "src/sksl/analysis/SkSLProgramVisitor.h"
39 #include "src/sksl/ir/SkSLBinaryExpression.h"
40 #include "src/sksl/ir/SkSLBlock.h"
41 #include "src/sksl/ir/SkSLConstructor.h"
42 #include "src/sksl/ir/SkSLConstructorCompound.h"
43 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
44 #include "src/sksl/ir/SkSLExpression.h"
45 #include "src/sksl/ir/SkSLExpressionStatement.h"
46 #include "src/sksl/ir/SkSLFieldAccess.h"
47 #include "src/sksl/ir/SkSLFunctionCall.h"
48 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
49 #include "src/sksl/ir/SkSLFunctionDefinition.h"
50 #include "src/sksl/ir/SkSLIfStatement.h"
51 #include "src/sksl/ir/SkSLIndexExpression.h"
52 #include "src/sksl/ir/SkSLInterfaceBlock.h"
53 #include "src/sksl/ir/SkSLLiteral.h"
54 #include "src/sksl/ir/SkSLProgram.h"
55 #include "src/sksl/ir/SkSLReturnStatement.h"
56 #include "src/sksl/ir/SkSLStructDefinition.h"
57 #include "src/sksl/ir/SkSLSwizzle.h"
58 #include "src/sksl/ir/SkSLSymbolTable.h"
59 #include "src/sksl/ir/SkSLTernaryExpression.h"
60 #include "src/sksl/ir/SkSLType.h"
61 #include "src/sksl/ir/SkSLVarDeclarations.h"
62 #include "src/sksl/ir/SkSLVariable.h"
63 #include "src/sksl/ir/SkSLVariableReference.h"
64 
// TODO(skia:13092): This is a temporary debug feature. Remove when the implementation is
// complete and this is no longer needed.
// When set to a non-zero value, generateCode() appends a "Source IR:" section containing the
// description of every program element after the generated WGSL body.
#define DUMP_SRC_IR 0
68 
69 namespace SkSL {
70 
// Forward declaration of the program-kind enum, which is taken by value in
// pipeline_struct_prefix() below. NOTE(review): the full definition presumably also arrives via
// one of the included headers — confirm whether this declaration is still required.
enum class ProgramKind : int8_t;
72 
73 namespace {
74 
// Address spaces this generator can spell inside a WGSL `ptr<...>` type.
// Reference: https://www.w3.org/TR/WGSL/#memory-view-types
enum class PtrAddressSpace {
    kFunction,  // emitted as "function"
    kPrivate,   // emitted as "private"
    kStorage,   // emitted as "storage"
};
81 
pipeline_struct_prefix(ProgramKind kind)82 std::string_view pipeline_struct_prefix(ProgramKind kind) {
83     if (ProgramConfig::IsVertex(kind)) {
84         return "VS";
85     }
86     if (ProgramConfig::IsFragment(kind)) {
87         return "FS";
88     }
89     return "";
90 }
91 
address_space_to_str(PtrAddressSpace addressSpace)92 std::string_view address_space_to_str(PtrAddressSpace addressSpace) {
93     switch (addressSpace) {
94         case PtrAddressSpace::kFunction:
95             return "function";
96         case PtrAddressSpace::kPrivate:
97             return "private";
98         case PtrAddressSpace::kStorage:
99             return "storage";
100     }
101     SkDEBUGFAIL("unsupported ptr address space");
102     return "unsupported";
103 }
104 
to_scalar_type(const Type & type)105 std::string_view to_scalar_type(const Type& type) {
106     SkASSERT(type.typeKind() == Type::TypeKind::kScalar);
107     switch (type.numberKind()) {
108         // Floating-point numbers in WebGPU currently always have 32-bit footprint and
109         // relaxed-precision is not supported without extensions. f32 is the only floating-point
110         // number type in WGSL (see the discussion on https://github.com/gpuweb/gpuweb/issues/658).
111         case Type::NumberKind::kFloat:
112             return "f32";
113         case Type::NumberKind::kSigned:
114             return "i32";
115         case Type::NumberKind::kUnsigned:
116             return "u32";
117         case Type::NumberKind::kBoolean:
118             return "bool";
119         case Type::NumberKind::kNonnumeric:
120             [[fallthrough]];
121         default:
122             break;
123     }
124     return type.name();
125 }
126 
127 // Convert a SkSL type to a WGSL type. Handles all plain types except structure types
128 // (see https://www.w3.org/TR/WGSL/#plain-types-section).
to_wgsl_type(const Type & type)129 std::string to_wgsl_type(const Type& type) {
130     switch (type.typeKind()) {
131         case Type::TypeKind::kScalar:
132             return std::string(to_scalar_type(type));
133         case Type::TypeKind::kVector: {
134             std::string_view ct = to_scalar_type(type.componentType());
135             return String::printf("vec%d<%.*s>", type.columns(), (int)ct.length(), ct.data());
136         }
137         case Type::TypeKind::kMatrix: {
138             std::string_view ct = to_scalar_type(type.componentType());
139             return String::printf(
140                     "mat%dx%d<%.*s>", type.columns(), type.rows(), (int)ct.length(), ct.data());
141         }
142         case Type::TypeKind::kArray: {
143             std::string elementType = to_wgsl_type(type.componentType());
144             if (type.isUnsizedArray()) {
145                 return String::printf("array<%s>", elementType.c_str());
146             }
147             return String::printf("array<%s, %d>", elementType.c_str(), type.columns());
148         }
149         default:
150             break;
151     }
152     return std::string(type.name());
153 }
154 
155 // Create a mangled WGSL type name that can be used in function and variable declarations (regular
156 // type names cannot be used in this manner since they may contain tokens that are not allowed in
157 // symbol names).
to_mangled_wgsl_type_name(const Type & type)158 std::string to_mangled_wgsl_type_name(const Type& type) {
159     switch (type.typeKind()) {
160         case Type::TypeKind::kScalar:
161             return std::string(to_scalar_type(type));
162         case Type::TypeKind::kVector: {
163             std::string_view ct = to_scalar_type(type.componentType());
164             return String::printf("vec%d%.*s", type.columns(), (int)ct.length(), ct.data());
165         }
166         case Type::TypeKind::kMatrix: {
167             std::string_view ct = to_scalar_type(type.componentType());
168             return String::printf(
169                     "mat%dx%d%.*s", type.columns(), type.rows(), (int)ct.length(), ct.data());
170         }
171         case Type::TypeKind::kArray: {
172             std::string elementType = to_wgsl_type(type.componentType());
173             if (type.isUnsizedArray()) {
174                 return String::printf("arrayof%s", elementType.c_str());
175             }
176             return String::printf("array%dof%s", type.columns(), elementType.c_str());
177         }
178         default:
179             break;
180     }
181     return std::string(type.name());
182 }
183 
to_ptr_type(const Type & type,PtrAddressSpace addressSpace=PtrAddressSpace::kFunction)184 std::string to_ptr_type(const Type& type,
185                         PtrAddressSpace addressSpace = PtrAddressSpace::kFunction) {
186     return "ptr<" + std::string(address_space_to_str(addressSpace)) + ", " + to_wgsl_type(type) +
187            ">";
188 }
189 
wgsl_builtin_name(WGSLCodeGenerator::Builtin builtin)190 std::string_view wgsl_builtin_name(WGSLCodeGenerator::Builtin builtin) {
191     using Builtin = WGSLCodeGenerator::Builtin;
192     switch (builtin) {
193         case Builtin::kVertexIndex:
194             return "vertex_index";
195         case Builtin::kInstanceIndex:
196             return "instance_index";
197         case Builtin::kPosition:
198             return "position";
199         case Builtin::kFrontFacing:
200             return "front_facing";
201         case Builtin::kSampleIndex:
202             return "sample_index";
203         case Builtin::kFragDepth:
204             return "frag_depth";
205         case Builtin::kSampleMask:
206             return "sample_mask";
207         case Builtin::kLocalInvocationId:
208             return "local_invocation_id";
209         case Builtin::kLocalInvocationIndex:
210             return "local_invocation_index";
211         case Builtin::kGlobalInvocationId:
212             return "global_invocation_id";
213         case Builtin::kWorkgroupId:
214             return "workgroup_id";
215         case Builtin::kNumWorkgroups:
216             return "num_workgroups";
217         default:
218             break;
219     }
220 
221     SkDEBUGFAIL("unsupported builtin");
222     return "unsupported";
223 }
224 
wgsl_builtin_type(WGSLCodeGenerator::Builtin builtin)225 std::string_view wgsl_builtin_type(WGSLCodeGenerator::Builtin builtin) {
226     using Builtin = WGSLCodeGenerator::Builtin;
227     switch (builtin) {
228         case Builtin::kVertexIndex:
229             return "u32";
230         case Builtin::kInstanceIndex:
231             return "u32";
232         case Builtin::kPosition:
233             return "vec4<f32>";
234         case Builtin::kFrontFacing:
235             return "bool";
236         case Builtin::kSampleIndex:
237             return "u32";
238         case Builtin::kFragDepth:
239             return "f32";
240         case Builtin::kSampleMask:
241             return "u32";
242         case Builtin::kLocalInvocationId:
243             return "vec3<u32>";
244         case Builtin::kLocalInvocationIndex:
245             return "u32";
246         case Builtin::kGlobalInvocationId:
247             return "vec3<u32>";
248         case Builtin::kWorkgroupId:
249             return "vec3<u32>";
250         case Builtin::kNumWorkgroups:
251             return "vec3<u32>";
252         default:
253             break;
254     }
255 
256     SkDEBUGFAIL("unsupported builtin");
257     return "unsupported";
258 }
259 
260 // Some built-in variables have a type that differs from their SkSL counterpart (e.g. signed vs
261 // unsigned integer). We handle these cases with an explicit type conversion during a variable
262 // reference. Returns the WGSL type of the conversion target if conversion is needed, otherwise
263 // returns std::nullopt.
needs_builtin_type_conversion(const Variable & v)264 std::optional<std::string_view> needs_builtin_type_conversion(const Variable& v) {
265     switch (v.modifiers().fLayout.fBuiltin) {
266         case SK_VERTEXID_BUILTIN:
267         case SK_INSTANCEID_BUILTIN:
268             return {"i32"};
269         default:
270             break;
271     }
272     return std::nullopt;
273 }
274 
275 // Map a SkSL builtin flag to a WGSL builtin kind. Returns std::nullopt if `builtin` is not
276 // not supported for WGSL.
277 //
278 // Also see //src/sksl/sksl_vert.sksl and //src/sksl/sksl_frag.sksl for supported built-ins.
builtin_from_sksl_name(int builtin)279 std::optional<WGSLCodeGenerator::Builtin> builtin_from_sksl_name(int builtin) {
280     using Builtin = WGSLCodeGenerator::Builtin;
281     switch (builtin) {
282         case SK_POSITION_BUILTIN:
283             [[fallthrough]];
284         case SK_FRAGCOORD_BUILTIN:
285             return {Builtin::kPosition};
286         case SK_VERTEXID_BUILTIN:
287             return {Builtin::kVertexIndex};
288         case SK_INSTANCEID_BUILTIN:
289             return {Builtin::kInstanceIndex};
290         case SK_CLOCKWISE_BUILTIN:
291             // TODO(skia:13092): While `front_facing` is the corresponding built-in, it does not
292             // imply a particular winding order. We correctly compute the face orientation based
293             // on how Skia configured the render pipeline for all references to this built-in
294             // variable (see `SkSL::Program::Inputs::fUseFlipRTUniform`).
295             return {Builtin::kFrontFacing};
296         default:
297             break;
298     }
299     return std::nullopt;
300 }
301 
top_level_symbol_table(const FunctionDefinition & f)302 const SymbolTable* top_level_symbol_table(const FunctionDefinition& f) {
303     return f.body()->as<Block>().symbolTable()->fParent.get();
304 }
305 
delimiter_to_str(WGSLCodeGenerator::Delimiter delimiter)306 const char* delimiter_to_str(WGSLCodeGenerator::Delimiter delimiter) {
307     using Delim = WGSLCodeGenerator::Delimiter;
308     switch (delimiter) {
309         case Delim::kComma:
310             return ",";
311         case Delim::kSemicolon:
312             return ";";
313         case Delim::kNone:
314         default:
315             break;
316     }
317     return "";
318 }
319 
320 // FunctionDependencyResolver visits the IR tree rooted at a particular function definition and
321 // computes that function's dependencies on pipeline stage IO parameters. These are later used to
322 // synthesize arguments when writing out function definitions.
323 class FunctionDependencyResolver : public ProgramVisitor {
324 public:
325     using Deps = WGSLCodeGenerator::FunctionDependencies;
326     using DepsMap = WGSLCodeGenerator::ProgramRequirements::DepsMap;
327 
FunctionDependencyResolver(const Program * p,const FunctionDeclaration * f,DepsMap * programDependencyMap)328     FunctionDependencyResolver(const Program* p,
329                                const FunctionDeclaration* f,
330                                DepsMap* programDependencyMap)
331             : fProgram(p), fFunction(f), fDependencyMap(programDependencyMap) {}
332 
resolve()333     Deps resolve() {
334         fDeps = Deps::kNone;
335         this->visit(*fProgram);
336         return fDeps;
337     }
338 
339 private:
visitProgramElement(const ProgramElement & p)340     bool visitProgramElement(const ProgramElement& p) override {
341         // Only visit the program that matches the requested function.
342         if (p.is<FunctionDefinition>() && &p.as<FunctionDefinition>().declaration() == fFunction) {
343             return INHERITED::visitProgramElement(p);
344         }
345         // Continue visiting other program elements.
346         return false;
347     }
348 
visitExpression(const Expression & e)349     bool visitExpression(const Expression& e) override {
350         if (e.is<VariableReference>()) {
351             const VariableReference& v = e.as<VariableReference>();
352             const Modifiers& modifiers = v.variable()->modifiers();
353             if (v.variable()->storage() == Variable::Storage::kGlobal) {
354                 if (modifiers.fFlags & Modifiers::kIn_Flag) {
355                     fDeps |= Deps::kPipelineInputs;
356                 }
357                 if (modifiers.fFlags & Modifiers::kOut_Flag) {
358                     fDeps |= Deps::kPipelineOutputs;
359                 }
360             }
361         } else if (e.is<FunctionCall>()) {
362             // The current function that we're processing (`fFunction`) inherits the dependencies of
363             // functions that it makes calls to, because the pipeline stage IO parameters need to be
364             // passed down as an argument.
365             const FunctionCall& callee = e.as<FunctionCall>();
366 
367             // Don't process a function again if we have already resolved it.
368             Deps* found = fDependencyMap->find(&callee.function());
369             if (found) {
370                 fDeps |= *found;
371             } else {
372                 // Store the dependencies that have been discovered for the current function so far.
373                 // If `callee` directly or indirectly calls the current function, then this value
374                 // will prevent an infinite recursion.
375                 fDependencyMap->set(fFunction, fDeps);
376 
377                 // Separately traverse the called function's definition and determine its
378                 // dependencies.
379                 FunctionDependencyResolver resolver(fProgram, &callee.function(), fDependencyMap);
380                 Deps calleeDeps = resolver.resolve();
381 
382                 // Store the callee's dependencies in the global map to avoid processing
383                 // the function again for future calls.
384                 fDependencyMap->set(&callee.function(), calleeDeps);
385 
386                 // Add to the current function's dependencies.
387                 fDeps |= calleeDeps;
388             }
389         }
390         return INHERITED::visitExpression(e);
391     }
392 
393     const Program* const fProgram;
394     const FunctionDeclaration* const fFunction;
395     DepsMap* const fDependencyMap;
396     Deps fDeps = Deps::kNone;
397 
398     using INHERITED = ProgramVisitor;
399 };
400 
resolve_program_requirements(const Program * program)401 WGSLCodeGenerator::ProgramRequirements resolve_program_requirements(const Program* program) {
402     bool mainNeedsCoordsArgument = false;
403     WGSLCodeGenerator::ProgramRequirements::DepsMap dependencies;
404 
405     for (const ProgramElement* e : program->elements()) {
406         if (!e->is<FunctionDefinition>()) {
407             continue;
408         }
409 
410         const FunctionDeclaration& decl = e->as<FunctionDefinition>().declaration();
411         if (decl.isMain()) {
412             for (const Variable* v : decl.parameters()) {
413                 if (v->modifiers().fLayout.fBuiltin == SK_MAIN_COORDS_BUILTIN) {
414                     mainNeedsCoordsArgument = true;
415                     break;
416                 }
417             }
418         }
419 
420         FunctionDependencyResolver resolver(program, &decl, &dependencies);
421         dependencies.set(&decl, resolver.resolve());
422     }
423 
424     return WGSLCodeGenerator::ProgramRequirements(std::move(dependencies), mainNeedsCoordsArgument);
425 }
426 
count_pipeline_inputs(const Program * program)427 int count_pipeline_inputs(const Program* program) {
428     int inputCount = 0;
429     for (const ProgramElement* e : program->elements()) {
430         if (e->is<GlobalVarDeclaration>()) {
431             const Variable* v = e->as<GlobalVarDeclaration>().varDeclaration().var();
432             if (v->modifiers().fFlags & Modifiers::kIn_Flag) {
433                 inputCount++;
434             }
435         } else if (e->is<InterfaceBlock>()) {
436             const Variable* v = e->as<InterfaceBlock>().var();
437             if (v->modifiers().fFlags & Modifiers::kIn_Flag) {
438                 inputCount++;
439             }
440         }
441     }
442     return inputCount;
443 }
444 
is_in_global_uniforms(const Variable & var)445 static bool is_in_global_uniforms(const Variable& var) {
446     SkASSERT(var.storage() == VariableStorage::kGlobal);
447     return var.modifiers().fFlags & Modifiers::kUniform_Flag && !var.type().isOpaque();
448 }
449 
450 }  // namespace
451 
generateCode()452 bool WGSLCodeGenerator::generateCode() {
453     // The resources of a WGSL program are structured in the following way:
454     // - Vertex and fragment stage attribute inputs and outputs are bundled
455     //   inside synthetic structs called VSIn/VSOut/FSIn/FSOut.
456     // - All uniform and storage type resources are declared in global scope.
457     this->preprocessProgram();
458 
459     StringStream header;
460     {
461         AutoOutputStream outputToHeader(this, &header, &fIndentation);
462         // TODO(skia:13092): Implement the following:
463         // - global uniform/storage resource declarations, including interface blocks.
464         this->writeStageInputStruct();
465         this->writeStageOutputStruct();
466         this->writeNonBlockUniformsForTests();
467     }
468     StringStream body;
469     {
470         AutoOutputStream outputToBody(this, &body, &fIndentation);
471         for (const ProgramElement* e : fProgram.elements()) {
472             this->writeProgramElement(*e);
473         }
474 
475 // TODO(skia:13092): This is a temporary debug feature. Remove when the implementation is
476 // complete and this is no longer needed.
477 #if DUMP_SRC_IR
478         this->writeLine("\n----------");
479         this->writeLine("Source IR:\n");
480         for (const ProgramElement* e : fProgram.elements()) {
481             this->writeLine(e->description().c_str());
482         }
483 #endif
484     }
485 
486     write_stringstream(header, *fOut);
487     write_stringstream(fExtraFunctions, *fOut);
488     write_stringstream(body, *fOut);
489     return fContext.fErrors->errorCount() == 0;
490 }
491 
preprocessProgram()492 void WGSLCodeGenerator::preprocessProgram() {
493     fRequirements = resolve_program_requirements(&fProgram);
494     fPipelineInputCount = count_pipeline_inputs(&fProgram);
495 }
496 
write(std::string_view s)497 void WGSLCodeGenerator::write(std::string_view s) {
498     if (s.empty()) {
499         return;
500     }
501     if (fAtLineStart) {
502         for (int i = 0; i < fIndentation; i++) {
503             fOut->writeText("    ");
504         }
505     }
506     fOut->writeText(std::string(s).c_str());
507     fAtLineStart = false;
508 }
509 
writeLine(std::string_view s)510 void WGSLCodeGenerator::writeLine(std::string_view s) {
511     this->write(s);
512     fOut->writeText("\n");
513     fAtLineStart = true;
514 }
515 
finishLine()516 void WGSLCodeGenerator::finishLine() {
517     if (!fAtLineStart) {
518         this->writeLine();
519     }
520 }
521 
writeName(std::string_view name)522 void WGSLCodeGenerator::writeName(std::string_view name) {
523     // Add underscore before name to avoid conflict with reserved words.
524     if (fReservedWords.contains(name)) {
525         this->write("_");
526     }
527     this->write(name);
528 }
529 
writeVariableDecl(const Type & type,std::string_view name,Delimiter delimiter)530 void WGSLCodeGenerator::writeVariableDecl(const Type& type,
531                                           std::string_view name,
532                                           Delimiter delimiter) {
533     this->writeName(name);
534     this->write(": " + to_wgsl_type(type));
535     this->writeLine(delimiter_to_str(delimiter));
536 }
537 
writePipelineIODeclaration(Modifiers modifiers,const Type & type,std::string_view name,Delimiter delimiter)538 void WGSLCodeGenerator::writePipelineIODeclaration(Modifiers modifiers,
539                                                    const Type& type,
540                                                    std::string_view name,
541                                                    Delimiter delimiter) {
542     // In WGSL, an entry-point IO parameter is "one of either a built-in value or
543     // assigned a location". However, some SkSL declarations, specifically sk_FragColor, can
544     // contain both a location and a builtin modifier. In addition, WGSL doesn't have a built-in
545     // equivalent for sk_FragColor as it relies on the user-defined location for a render
546     // target.
547     //
548     // Instead of special-casing sk_FragColor, we just give higher precedence to a location
549     // modifier if a declaration happens to both have a location and it's a built-in.
550     //
551     // Also see:
552     // https://www.w3.org/TR/WGSL/#input-output-locations
553     // https://www.w3.org/TR/WGSL/#attribute-location
554     // https://www.w3.org/TR/WGSL/#builtin-inputs-outputs
555     int location = modifiers.fLayout.fLocation;
556     if (location >= 0) {
557         this->writeUserDefinedIODecl(type, name, location, delimiter);
558     } else if (modifiers.fLayout.fBuiltin >= 0) {
559         auto builtin = builtin_from_sksl_name(modifiers.fLayout.fBuiltin);
560         if (builtin.has_value()) {
561             this->writeBuiltinIODecl(type, name, *builtin, delimiter);
562         }
563     }
564 }
565 
writeUserDefinedIODecl(const Type & type,std::string_view name,int location,Delimiter delimiter)566 void WGSLCodeGenerator::writeUserDefinedIODecl(const Type& type,
567                                                std::string_view name,
568                                                int location,
569                                                Delimiter delimiter) {
570     this->write("@location(" + std::to_string(location) + ") ");
571 
572     // "User-defined IO of scalar or vector integer type must always be specified as
573     // @interpolate(flat)" (see https://www.w3.org/TR/WGSL/#interpolation)
574     if (type.isInteger() || (type.isVector() && type.componentType().isInteger())) {
575         this->write("@interpolate(flat) ");
576     }
577 
578     this->writeVariableDecl(type, name, delimiter);
579 }
580 
writeBuiltinIODecl(const Type & type,std::string_view name,Builtin builtin,Delimiter delimiter)581 void WGSLCodeGenerator::writeBuiltinIODecl(const Type& type,
582                                            std::string_view name,
583                                            Builtin builtin,
584                                            Delimiter delimiter) {
585     this->write("@builtin(");
586     this->write(wgsl_builtin_name(builtin));
587     this->write(") ");
588 
589     this->writeName(name);
590     this->write(": ");
591     this->write(wgsl_builtin_type(builtin));
592     this->writeLine(delimiter_to_str(delimiter));
593 }
594 
writeFunction(const FunctionDefinition & f)595 void WGSLCodeGenerator::writeFunction(const FunctionDefinition& f) {
596     this->writeFunctionDeclaration(f.declaration());
597     this->write(" ");
598     this->writeBlock(f.body()->as<Block>());
599 
600     if (f.declaration().isMain()) {
601         // We just emitted the user-defined main function. Next, we generate a program entry point
602         // that calls the user-defined main.
603         this->writeEntryPoint(f);
604     }
605 }
606 
writeFunctionDeclaration(const FunctionDeclaration & f)607 void WGSLCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& f) {
608     this->write("fn ");
609     this->write(f.mangledName());
610     this->write("(");
611     auto separator = SkSL::String::Separator();
612     if (this->writeFunctionDependencyParams(f)) {
613         separator();  // update the separator as parameters have been written
614     }
615     for (const Variable* param : f.parameters()) {
616         this->write(separator());
617         this->writeName(param->mangledName());
618         this->write(": ");
619 
620         // Declare an "out" function parameter as a pointer.
621         if (param->modifiers().fFlags & Modifiers::kOut_Flag) {
622             this->write(to_ptr_type(param->type()));
623         } else {
624             this->write(to_wgsl_type(param->type()));
625         }
626     }
627     this->write(")");
628     if (!f.returnType().isVoid()) {
629         this->write(" -> ");
630         this->write(to_wgsl_type(f.returnType()));
631     }
632 }
633 
writeEntryPoint(const FunctionDefinition & main)634 void WGSLCodeGenerator::writeEntryPoint(const FunctionDefinition& main) {
635     SkASSERT(main.declaration().isMain());
636 
637     // The input and output parameters for a vertex/fragment stage entry point function have the
638     // FSIn/FSOut/VSIn/VSOut struct types that have been synthesized in generateCode(). An entry
639     // point always has the same signature and acts as a trampoline to the user-defined main
640     // function.
641     std::string outputType;
642     if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
643         this->write("@vertex fn vertexMain(");
644         if (fPipelineInputCount > 0) {
645             this->write("_stageIn: VSIn");
646         }
647         this->writeLine(") -> VSOut {");
648         outputType = "VSOut";
649     } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
650         this->write("@fragment fn fragmentMain(");
651         if (fPipelineInputCount > 0) {
652             this->write("_stageIn: FSIn");
653         }
654         this->writeLine(") -> FSOut {");
655         outputType = "FSOut";
656     } else {
657         fContext.fErrors->error(Position(), "program kind not supported");
658         return;
659     }
660 
661     // Declare the stage output struct.
662     fIndentation++;
663     this->write("var _stageOut: ");
664     this->write(outputType);
665     this->writeLine(";");
666 
667     // Generate assignment to sk_FragColor built-in if the user-defined main returns a color.
668     if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
669         const SymbolTable* symbolTable = top_level_symbol_table(main);
670         const Symbol* symbol = symbolTable->find("sk_FragColor");
671         SkASSERT(symbol);
672         if (main.declaration().returnType().matches(symbol->type())) {
673             this->write("_stageOut.sk_FragColor = ");
674         }
675     }
676 
677     // Generate the function call to the user-defined main:
678     this->write(main.declaration().mangledName());
679     this->write("(");
680     auto separator = SkSL::String::Separator();
681     FunctionDependencies* deps = fRequirements.dependencies.find(&main.declaration());
682     if (deps) {
683         if ((*deps & FunctionDependencies::kPipelineInputs) != FunctionDependencies::kNone) {
684             this->write(separator());
685             this->write("_stageIn");
686         }
687         if ((*deps & FunctionDependencies::kPipelineOutputs) != FunctionDependencies::kNone) {
688             this->write(separator());
689             this->write("&_stageOut");
690         }
691     }
692     // TODO(armansito): Handle arbitrary parameters.
693     if (main.declaration().parameters().size() != 0) {
694         const Variable* v = main.declaration().parameters()[0];
695         const Type& type = v->type();
696         if (v->modifiers().fLayout.fBuiltin == SK_MAIN_COORDS_BUILTIN) {
697             if (!type.matches(*fContext.fTypes.fFloat2)) {
698                 fContext.fErrors->error(
699                         main.fPosition,
700                         "main function has unsupported parameter: " + type.description());
701                 return;
702             }
703 
704             this->write(separator());
705             this->write("_stageIn.sk_FragCoord.xy");
706         }
707     }
708     this->writeLine(");");
709     this->writeLine("return _stageOut;");
710 
711     fIndentation--;
712     this->writeLine("}");
713 }
714 
writeStatement(const Statement & s)715 void WGSLCodeGenerator::writeStatement(const Statement& s) {
716     switch (s.kind()) {
717         case Statement::Kind::kBlock:
718             this->writeBlock(s.as<Block>());
719             break;
720         case Statement::Kind::kExpression:
721             this->writeExpressionStatement(s.as<ExpressionStatement>());
722             break;
723         case Statement::Kind::kIf:
724             this->writeIfStatement(s.as<IfStatement>());
725             break;
726         case Statement::Kind::kReturn:
727             this->writeReturnStatement(s.as<ReturnStatement>());
728             break;
729         case Statement::Kind::kVarDeclaration:
730             this->writeVarDeclaration(s.as<VarDeclaration>());
731             break;
732         default:
733             SkDEBUGFAILF("unsupported statement (kind: %d) %s",
734                          static_cast<int>(s.kind()), s.description().c_str());
735             break;
736     }
737 }
738 
writeStatements(const StatementArray & statements)739 void WGSLCodeGenerator::writeStatements(const StatementArray& statements) {
740     for (const auto& s : statements) {
741         if (!s->isEmpty()) {
742             this->writeStatement(*s);
743             this->finishLine();
744         }
745     }
746 }
747 
writeBlock(const Block & b)748 void WGSLCodeGenerator::writeBlock(const Block& b) {
749     // Write scope markers if this block is a scope, or if the block is empty (since we need to emit
750     // something here to make the code valid).
751     bool isScope = b.isScope() || b.isEmpty();
752     if (isScope) {
753         this->writeLine("{");
754         fIndentation++;
755     }
756     this->writeStatements(b.children());
757     if (isScope) {
758         fIndentation--;
759         this->writeLine("}");
760     }
761 }
762 
writeExpressionStatement(const ExpressionStatement & s)763 void WGSLCodeGenerator::writeExpressionStatement(const ExpressionStatement& s) {
764     if (Analysis::HasSideEffects(*s.expression())) {
765         this->writeExpression(*s.expression(), Precedence::kTopLevel);
766         this->write(";");
767     }
768 }
769 
writeIfStatement(const IfStatement & s)770 void WGSLCodeGenerator::writeIfStatement(const IfStatement& s) {
771     this->write("if (");
772     this->writeExpression(*s.test(), Precedence::kTopLevel);
773     this->write(") ");
774     this->writeStatement(*s.ifTrue());
775     if (s.ifFalse()) {
776         this->write("else ");
777         this->writeStatement(*s.ifFalse());
778     }
779 }
780 
writeReturnStatement(const ReturnStatement & s)781 void WGSLCodeGenerator::writeReturnStatement(const ReturnStatement& s) {
782     this->write("return");
783     if (s.expression()) {
784         this->write(" ");
785         this->writeExpression(*s.expression(), Precedence::kTopLevel);
786     }
787     this->write(";");
788 }
789 
writeVarDeclaration(const VarDeclaration & varDecl)790 void WGSLCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl) {
791     bool isConst = varDecl.var()->modifiers().fFlags & Modifiers::kConst_Flag;
792     if (isConst) {
793         this->write("let ");
794     } else {
795         this->write("var ");
796     }
797     this->writeName(varDecl.var()->mangledName());
798     this->write(": ");
799     this->write(to_wgsl_type(varDecl.var()->type()));
800 
801     if (varDecl.value()) {
802         this->write(" = ");
803         this->writeExpression(*varDecl.value(), Precedence::kTopLevel);
804     } else if (isConst) {
805         SkDEBUGFAILF("A let-declared constant must specify a value");
806     }
807 
808     this->write(";");
809 }
810 
writeExpression(const Expression & e,Precedence parentPrecedence)811 void WGSLCodeGenerator::writeExpression(const Expression& e, Precedence parentPrecedence) {
812     switch (e.kind()) {
813         case Expression::Kind::kBinary:
814             this->writeBinaryExpression(e.as<BinaryExpression>(), parentPrecedence);
815             break;
816         case Expression::Kind::kConstructorCompound:
817             this->writeConstructorCompound(e.as<ConstructorCompound>(), parentPrecedence);
818             break;
819         case Expression::Kind::kConstructorCompoundCast:
820         case Expression::Kind::kConstructorScalarCast:
821         case Expression::Kind::kConstructorSplat:
822             this->writeAnyConstructor(e.asAnyConstructor(), parentPrecedence);
823             break;
824         case Expression::Kind::kConstructorDiagonalMatrix:
825             this->writeConstructorDiagonalMatrix(e.as<ConstructorDiagonalMatrix>(),
826                                                  parentPrecedence);
827             break;
828         case Expression::Kind::kFieldAccess:
829             this->writeFieldAccess(e.as<FieldAccess>());
830             break;
831         case Expression::Kind::kFunctionCall:
832             this->writeFunctionCall(e.as<FunctionCall>());
833             break;
834         case Expression::Kind::kIndex:
835             this->writeIndexExpression(e.as<IndexExpression>());
836             break;
837         case Expression::Kind::kLiteral:
838             this->writeLiteral(e.as<Literal>());
839             break;
840         case Expression::Kind::kSwizzle:
841             this->writeSwizzle(e.as<Swizzle>());
842             break;
843         case Expression::Kind::kTernary:
844             this->writeTernaryExpression(e.as<TernaryExpression>(), parentPrecedence);
845             break;
846         case Expression::Kind::kVariableReference:
847             this->writeVariableReference(e.as<VariableReference>());
848             break;
849         default:
850             SkDEBUGFAILF("unsupported expression (kind: %d) %s",
851                          static_cast<int>(e.kind()),
852                          e.description().c_str());
853             break;
854     }
855 }
856 
writeBinaryExpression(const BinaryExpression & b,Precedence parentPrecedence)857 void WGSLCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
858                                               Precedence parentPrecedence) {
859     const Expression& left = *b.left();
860     const Expression& right = *b.right();
861     Operator op = b.getOperator();
862 
863     // The equality and comparison operators are only supported for scalar and vector types.
864     if (op.isEquality() && !left.type().isScalar() && !left.type().isVector()) {
865         if (left.type().isMatrix()) {
866             if (op.kind() == OperatorKind::NEQ) {
867                 this->write("!");
868             }
869             this->writeMatrixEquality(left, right);
870             return;
871         }
872 
873         // TODO(skia:13092): Synthesize helper functions for structs and arrays.
874         return;
875     }
876 
877     Precedence precedence = op.getBinaryPrecedence();
878     bool needParens = precedence >= parentPrecedence;
879 
880     // The equality operators ('=='/'!=') in WGSL apply component-wise to vectors and result in a
881     // vector. We need to reduce the value to a boolean.
882     if (left.type().isVector()) {
883         if (op.kind() == Operator::Kind::EQEQ) {
884             this->write("all");
885             needParens = true;
886         } else if (op.kind() == Operator::Kind::NEQ) {
887             this->write("any");
888             needParens = true;
889         }
890     }
891 
892     if (needParens) {
893         this->write("(");
894     }
895 
896     // TODO(skia:13092): Correctly handle the case when lhs is a pointer.
897 
898     this->writeExpression(left, precedence);
899     this->write(op.operatorName());
900     this->writeExpression(right, precedence);
901 
902     if (needParens) {
903         this->write(")");
904     }
905 }
906 
writeFieldAccess(const FieldAccess & f)907 void WGSLCodeGenerator::writeFieldAccess(const FieldAccess& f) {
908     const Type::Field* field = &f.base()->type().fields()[f.fieldIndex()];
909     if (FieldAccess::OwnerKind::kDefault == f.ownerKind()) {
910         this->writeExpression(*f.base(), Precedence::kPostfix);
911         this->write(".");
912     } else {
913         // We are accessing a field in an anonymous interface block. If the field refers to a
914         // pipeline IO parameter, then we access it via the synthesized IO structs. We make an
915         // explicit exception for `sk_PointSize` which we declare as a placeholder variable in
916         // global scope as it is not supported by WebGPU as a pipeline IO parameter (see comments
917         // in `writeStageOutputStruct`).
918         const Variable& v = *f.base()->as<VariableReference>().variable();
919         if (v.modifiers().fFlags & Modifiers::kIn_Flag) {
920             this->write("_stageIn.");
921         } else if (v.modifiers().fFlags & Modifiers::kOut_Flag &&
922                    field->fModifiers.fLayout.fBuiltin != SK_POINTSIZE_BUILTIN) {
923             this->write("(*_stageOut).");
924         } else {
925             // TODO(skia:13092): Reference the variable using the base name used for its
926             // uniform/storage block global declaration.
927         }
928     }
929     this->writeName(field->fName);
930 }
931 
writeFunctionCall(const FunctionCall & c)932 void WGSLCodeGenerator::writeFunctionCall(const FunctionCall& c) {
933     const FunctionDeclaration& func = c.function();
934 
935     // TODO(skia:13092): Handle intrinsic call as many of them need to be rewritten.
936 
937     // We implement function out-parameters by declaring them as pointers. SkSL follows GLSL's
938     // out-parameter semantics, in which out-parameters are only written back to the original
939     // variable after the function's execution is complete (see
940     // https://www.khronos.org/opengl/wiki/Core_Language_(GLSL)#Parameters).
941     //
942     // In addition, SkSL supports swizzles and array index expressions to be passed into
943     // out-parameters however WGSL does not allow taking their address into a pointer.
944     //
945     // We support these by wrapping each function call in a special helper, which internally stores
946     // all out parameters in temporaries.
947 
948     // First detect which arguments are passed to out-parameters.
949     const ExpressionArray& args = c.arguments();
950     const std::vector<Variable*>& params = func.parameters();
951     SkASSERT(SkToSizeT(args.size()) == params.size());
952 
953     bool foundOutParam = false;
954     SkSTArray<16, VariableReference*> outVars;
955     outVars.push_back_n(args.size(), static_cast<VariableReference*>(nullptr));
956 
957     for (int i = 0; i < args.size(); ++i) {
958         if (params[i]->modifiers().fFlags & Modifiers::kOut_Flag) {
959             // Find the expression's inner variable being written to. Assignability was verified at
960             // IR generation time, so this should always succeed.
961             Analysis::AssignmentInfo info;
962             SkAssertResult(Analysis::IsAssignable(*args[i], &info));
963             outVars[i] = info.fAssignedVar;
964             foundOutParam = true;
965         }
966     }
967 
968     if (foundOutParam) {
969         this->writeName(this->writeOutParamHelper(c, args, outVars));
970     } else {
971         this->writeName(func.mangledName());
972     }
973 
974     this->write("(");
975     auto separator = SkSL::String::Separator();
976     if (this->writeFunctionDependencyArgs(func)) {
977         separator();
978     }
979     for (int i = 0; i < args.size(); ++i) {
980         this->write(separator());
981         if (outVars[i]) {
982             // We need to take the address of the variable and pass it down as a pointer.
983             this->write("&");
984             this->writeExpression(*outVars[i], Precedence::kSequence);
985         } else {
986             this->writeExpression(*args[i], Precedence::kSequence);
987         }
988     }
989     this->write(")");
990 }
991 
writeIndexExpression(const IndexExpression & i)992 void WGSLCodeGenerator::writeIndexExpression(const IndexExpression& i) {
993     this->writeExpression(*i.base(), Precedence::kPostfix);
994     this->write("[");
995     this->writeExpression(*i.index(), Precedence::kTopLevel);
996     this->write("]");
997 }
998 
writeLiteral(const Literal & l)999 void WGSLCodeGenerator::writeLiteral(const Literal& l) {
1000     const Type& type = l.type();
1001     if (type.isFloat() || type.isBoolean()) {
1002         this->write(l.description(OperatorPrecedence::kTopLevel));
1003         return;
1004     }
1005     SkASSERT(type.isInteger());
1006     if (type.matches(*fContext.fTypes.fUInt)) {
1007         this->write(std::to_string(l.intValue() & 0xffffffff));
1008         this->write("u");
1009     } else if (type.matches(*fContext.fTypes.fUShort)) {
1010         this->write(std::to_string(l.intValue() & 0xffff));
1011         this->write("u");
1012     } else {
1013         this->write(std::to_string(l.intValue()));
1014     }
1015 }
1016 
writeSwizzle(const Swizzle & swizzle)1017 void WGSLCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
1018     this->writeExpression(*swizzle.base(), Precedence::kPostfix);
1019     this->write(".");
1020     for (int c : swizzle.components()) {
1021         SkASSERT(c >= 0 && c <= 3);
1022         this->write(&("x\0y\0z\0w\0"[c * 2]));
1023     }
1024 }
1025 
writeTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)1026 void WGSLCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
1027                                                Precedence parentPrecedence) {
1028     bool needParens = Precedence::kTernary >= parentPrecedence;
1029     if (needParens) {
1030         this->write("(");
1031     }
1032 
1033     // The trivial case is when neither branch has side effects and evaluate to a scalar or vector
1034     // type. This can be represented with a call to the WGSL `select` intrinsic although it doesn't
1035     // support short-circuiting.
1036     if ((t.type().isScalar() || t.type().isVector()) && !Analysis::HasSideEffects(*t.ifTrue()) &&
1037         !Analysis::HasSideEffects(*t.ifFalse())) {
1038         this->write("select(");
1039         this->writeExpression(*t.ifFalse(), Precedence::kTernary);
1040         this->write(", ");
1041         this->writeExpression(*t.ifTrue(), Precedence::kTernary);
1042         this->write(", ");
1043 
1044         bool isVector = t.type().isVector();
1045         if (isVector) {
1046             // Splat the condition expression into a vector.
1047             this->write(String::printf("vec%d<bool>(", t.type().columns()));
1048         }
1049         this->writeExpression(*t.test(), Precedence::kTernary);
1050         if (isVector) {
1051             this->write(")");
1052         }
1053         this->write(")");
1054         if (needParens) {
1055             this->write(")");
1056         }
1057         return;
1058     }
1059 
1060     // TODO(skia:13092): WGSL does not support ternary expressions. To replicate the required
1061     // short-circuting behavior we need to hoist the expression out into the surrounding block,
1062     // convert it into an if statement that writes the result to a synthesized variable, and replace
1063     // the original expression with a reference to that variable.
1064     //
1065     // Once hoisting is supported, we may want to use that for vector type expressions as well,
1066     // since select above does a component-wise select
1067 }
1068 
writeVariableReference(const VariableReference & r)1069 void WGSLCodeGenerator::writeVariableReference(const VariableReference& r) {
1070     // TODO(skia:13092): Correctly handle RTflip for built-ins.
1071     const Variable& v = *r.variable();
1072 
1073     // Insert a conversion expression if this is a built-in variable whose type differs from the
1074     // SkSL.
1075     std::optional<std::string_view> conversion = needs_builtin_type_conversion(v);
1076     if (conversion.has_value()) {
1077         this->write(*conversion);
1078         this->write("(");
1079     }
1080 
1081     bool needsDeref = false;
1082     bool isSynthesizedOutParamArg = fOutParamArgVars.contains(&v);
1083 
1084     // When a variable is referenced in the context of a synthesized out-parameter helper argument,
1085     // two special rules apply:
1086     //     1. If it's accessed via a pipeline I/O or global uniforms struct, it should instead
1087     //        be referenced by name (since it's actually referring to a function parameter).
1088     //     2. Its type should be treated as a pointer and should be dereferenced as such.
1089     if (v.storage() == Variable::Storage::kGlobal && !isSynthesizedOutParamArg) {
1090         if (v.modifiers().fFlags & Modifiers::kIn_Flag) {
1091             this->write("_stageIn.");
1092         } else if (v.modifiers().fFlags & Modifiers::kOut_Flag) {
1093             this->write("(*_stageOut).");
1094         } else if (is_in_global_uniforms(v)) {
1095             this->write("_globalUniforms.");
1096         }
1097     } else if ((v.storage() == Variable::Storage::kParameter &&
1098                 v.modifiers().fFlags & Modifiers::kOut_Flag) ||
1099                isSynthesizedOutParamArg) {
1100         // This is an out-parameter and its type is a pointer, which we need to dereference.
1101         // We wrap the dereference in parentheses in case the value is used in an access expression
1102         // later.
1103         needsDeref = true;
1104         this->write("(*");
1105     }
1106 
1107     this->writeName(v.mangledName());
1108     if (needsDeref) {
1109         this->write(")");
1110     }
1111     if (conversion.has_value()) {
1112         this->write(")");
1113     }
1114 }
1115 
writeAnyConstructor(const AnyConstructor & c,Precedence parentPrecedence)1116 void WGSLCodeGenerator::writeAnyConstructor(const AnyConstructor& c, Precedence parentPrecedence) {
1117     this->write(to_wgsl_type(c.type()));
1118     this->write("(");
1119     auto separator = SkSL::String::Separator();
1120     for (const auto& e : c.argumentSpan()) {
1121         this->write(separator());
1122         this->writeExpression(*e, Precedence::kSequence);
1123     }
1124     this->write(")");
1125 }
1126 
writeConstructorCompound(const ConstructorCompound & c,Precedence parentPrecedence)1127 void WGSLCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
1128                                                  Precedence parentPrecedence) {
1129     // TODO(skia:13092): Support matrix constructors
1130     if (c.type().isVector()) {
1131         this->writeConstructorCompoundVector(c, parentPrecedence);
1132     } else {
1133         fContext.fErrors->error(c.fPosition, "unsupported compound constructor");
1134     }
1135 }
1136 
writeConstructorCompoundVector(const ConstructorCompound & c,Precedence parentPrecedence)1137 void WGSLCodeGenerator::writeConstructorCompoundVector(const ConstructorCompound& c,
1138                                                        Precedence parentPrecedence) {
1139     // WGSL supports constructing vectors from a mix of scalars and vectors but
1140     // not matrices (see https://www.w3.org/TR/WGSL/#type-constructor-expr).
1141     //
1142     // SkSL supports vec4(mat2x2) which we handle specially.
1143     if (c.type().columns() == 4 && c.argumentSpan().size() == 1) {
1144         const Expression& arg = *c.argumentSpan().front();
1145         if (arg.type().isMatrix()) {
1146             // This is the vec4(mat2x2) case.
1147             SkASSERT(arg.type().columns() == 2);
1148             SkASSERT(arg.type().rows() == 2);
1149 
1150             // Generate a helper so that the argument expression gets evaluated once.
1151             std::string name = String::printf("%s_from_%s",
1152                                               to_mangled_wgsl_type_name(c.type()).c_str(),
1153                                               to_mangled_wgsl_type_name(arg.type()).c_str());
1154             if (!fHelpers.contains(name)) {
1155                 fHelpers.add(name);
1156                 std::string returnType = to_wgsl_type(c.type());
1157                 std::string argType = to_wgsl_type(arg.type());
1158                 fExtraFunctions.printf(
1159                         "fn %s(x: %s) -> %s {\n    return %s(x[0].xy, x[1].xy);\n}\n",
1160                         name.c_str(),
1161                         argType.c_str(),
1162                         returnType.c_str(),
1163                         returnType.c_str());
1164             }
1165             this->write(name);
1166             this->write("(");
1167             this->writeExpression(arg, Precedence::kSequence);
1168             this->write(")");
1169             return;
1170         }
1171     }
1172     this->writeAnyConstructor(c, parentPrecedence);
1173 }
1174 
writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix & c,Precedence parentPrecedence)1175 void WGSLCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
1176                                                        Precedence parentPrecedence) {
1177     const Type& type = c.type();
1178     SkASSERT(type.isMatrix());
1179     SkASSERT(c.argument()->type().isScalar());
1180 
1181     // Generate a helper so that the argument expression gets evaluated once.
1182     std::string name = String::printf("%s_diagonal", to_mangled_wgsl_type_name(type).c_str());
1183     if (!fHelpers.contains(name)) {
1184         fHelpers.add(name);
1185 
1186         std::string typeName = to_wgsl_type(type);
1187         fExtraFunctions.printf("fn %s(x: %s) -> %s {\n",
1188                                name.c_str(),
1189                                to_wgsl_type(c.argument()->type()).c_str(),
1190                                typeName.c_str());
1191         fExtraFunctions.printf("    return %s(", typeName.c_str());
1192         auto separator = String::Separator();
1193         for (int col = 0; col < type.columns(); ++col) {
1194             for (int row = 0; row < type.rows(); ++row) {
1195                 fExtraFunctions.printf("%s%s", separator().c_str(), (col == row) ? "x" : "0.0");
1196             }
1197         }
1198         fExtraFunctions.printf(");\n}\n");
1199     }
1200     this->write(name);
1201     this->write("(");
1202     this->writeExpression(*c.argument(), Precedence::kSequence);
1203     this->write(")");
1204 }
1205 
writeMatrixEquality(const Expression & left,const Expression & right)1206 void WGSLCodeGenerator::writeMatrixEquality(const Expression& left, const Expression& right) {
1207     const Type& leftType = left.type();
1208     const Type& rightType = right.type();
1209     SkASSERT(leftType.isMatrix());
1210     SkASSERT(rightType.isMatrix());
1211     SkASSERT(leftType.rows() == rightType.rows());
1212     SkASSERT(leftType.columns() == rightType.columns());
1213 
1214     std::string name = String::printf("%s_eq_%s",
1215                                       to_mangled_wgsl_type_name(leftType).c_str(),
1216                                       to_mangled_wgsl_type_name(rightType).c_str());
1217     if (!fHelpers.contains(name)) {
1218         fHelpers.add(name);
1219         fExtraFunctions.printf("fn %s(left: %s, right: %s) -> bool {\n    return ",
1220                                name.c_str(),
1221                                to_wgsl_type(leftType).c_str(),
1222                                to_wgsl_type(rightType).c_str());
1223         const char* separator = "";
1224         for (int i = 0; i < leftType.columns(); ++i) {
1225             fExtraFunctions.printf("%sall(left[%d] == right[%d])", separator, i, i);
1226             separator = " &&\n           ";
1227         }
1228         fExtraFunctions.printf(";\n}\n");
1229     }
1230     this->write(name);
1231     this->write("(");
1232     this->writeExpression(left, Precedence::kSequence);
1233     this->write(", ");
1234     this->writeExpression(right, Precedence::kSequence);
1235     this->write(")");
1236 }
1237 
writeProgramElement(const ProgramElement & e)1238 void WGSLCodeGenerator::writeProgramElement(const ProgramElement& e) {
1239     switch (e.kind()) {
1240         case ProgramElement::Kind::kExtension:
1241             // TODO(skia:13092): WGSL supports extensions via the "enable" directive
1242             // (https://www.w3.org/TR/WGSL/#language-extensions). While we could easily emit this
1243             // directive, we should first ensure that all possible SkSL extension names are
1244             // converted to their appropriate WGSL extension. Currently there are no known supported
1245             // WGSL extensions aside from the hypotheticals listed in the spec.
1246             break;
1247         case ProgramElement::Kind::kGlobalVar:
1248             this->writeGlobalVarDeclaration(e.as<GlobalVarDeclaration>());
1249             break;
1250         case ProgramElement::Kind::kInterfaceBlock:
1251             // All interface block declarations are handled explicitly as the "program header" in
1252             // generateCode().
1253             break;
1254         case ProgramElement::Kind::kStructDefinition:
1255             this->writeStructDefinition(e.as<StructDefinition>());
1256             break;
1257         case ProgramElement::Kind::kFunctionPrototype:
1258             // A WGSL function declaration must contain its body and the function name is in scope
1259             // for the entire program (see https://www.w3.org/TR/WGSL/#function-declaration and
1260             // https://www.w3.org/TR/WGSL/#declaration-and-scope).
1261             //
1262             // As such, we don't emit function prototypes.
1263             break;
1264         case ProgramElement::Kind::kFunction:
1265             this->writeFunction(e.as<FunctionDefinition>());
1266             break;
1267         default:
1268             SkDEBUGFAILF("unsupported program element: %s\n", e.description().c_str());
1269             break;
1270     }
1271 }
1272 
writeGlobalVarDeclaration(const GlobalVarDeclaration & d)1273 void WGSLCodeGenerator::writeGlobalVarDeclaration(const GlobalVarDeclaration& d) {
1274     const Variable& var = *d.declaration()->as<VarDeclaration>().var();
1275     if ((var.modifiers().fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag)) ||
1276         is_in_global_uniforms(var)) {
1277         // Pipeline stage I/O parameters and top-level (non-block) uniforms are handled specially
1278         // in generateCode().
1279         return;
1280     }
1281 
1282     // TODO(skia:13092): Implement workgroup variable decoration
1283     this->write("var<private> ");
1284     this->writeVariableDecl(var.type(), var.name(), Delimiter::kSemicolon);
1285 }
1286 
writeStructDefinition(const StructDefinition & s)1287 void WGSLCodeGenerator::writeStructDefinition(const StructDefinition& s) {
1288     const Type& type = s.type();
1289     this->writeLine("struct " + type.displayName() + " {");
1290     fIndentation++;
1291     this->writeFields(SkSpan(type.fields()), type.fPosition);
1292     fIndentation--;
1293     this->writeLine("};");
1294 }
1295 
writeFields(SkSpan<const Type::Field> fields,Position parentPos,const MemoryLayout *)1296 void WGSLCodeGenerator::writeFields(SkSpan<const Type::Field> fields,
1297                                     Position parentPos,
1298                                     const MemoryLayout*) {
1299     // TODO(skia:13092): Check alignment against `layout` constraints, if present. A layout
1300     // constraint will be specified for interface blocks and for structs that appear in a block.
1301     for (const Type::Field& field : fields) {
1302         const Type* fieldType = field.fType;
1303         this->writeVariableDecl(*fieldType, field.fName, Delimiter::kComma);
1304     }
1305 }
1306 
writeStageInputStruct()1307 void WGSLCodeGenerator::writeStageInputStruct() {
1308     std::string_view structNamePrefix = pipeline_struct_prefix(fProgram.fConfig->fKind);
1309     if (structNamePrefix.empty()) {
1310         // There's no need to declare pipeline stage outputs.
1311         return;
1312     }
1313 
1314     // It is illegal to declare a struct with no members.
1315     if (fPipelineInputCount < 1) {
1316         return;
1317     }
1318 
1319     this->write("struct ");
1320     this->write(structNamePrefix);
1321     this->writeLine("In {");
1322     fIndentation++;
1323 
1324     bool declaredFragCoordsBuiltin = false;
1325     for (const ProgramElement* e : fProgram.elements()) {
1326         if (e->is<GlobalVarDeclaration>()) {
1327             const Variable* v = e->as<GlobalVarDeclaration>().declaration()
1328                                  ->as<VarDeclaration>().var();
1329             if (v->modifiers().fFlags & Modifiers::kIn_Flag) {
1330                 this->writePipelineIODeclaration(v->modifiers(), v->type(), v->mangledName(),
1331                                                  Delimiter::kComma);
1332                 if (v->modifiers().fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
1333                     declaredFragCoordsBuiltin = true;
1334                 }
1335             }
1336         } else if (e->is<InterfaceBlock>()) {
1337             const Variable* v = e->as<InterfaceBlock>().var();
1338             // Merge all the members of `in` interface blocks to the input struct, which are
1339             // specified as either "builtin" or with a "layout(location=".
1340             //
1341             // TODO(armansito): Is it legal to have an interface block without a storage qualifier
1342             // but with members that have individual storage qualifiers?
1343             if (v->modifiers().fFlags & Modifiers::kIn_Flag) {
1344                 for (const auto& f : v->type().fields()) {
1345                     this->writePipelineIODeclaration(f.fModifiers, *f.fType, f.fName,
1346                                                      Delimiter::kComma);
1347                     if (f.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
1348                         declaredFragCoordsBuiltin = true;
1349                     }
1350                 }
1351             }
1352         }
1353     }
1354 
1355     if (ProgramConfig::IsFragment(fProgram.fConfig->fKind) &&
1356         fRequirements.mainNeedsCoordsArgument && !declaredFragCoordsBuiltin) {
1357         this->writeLine("@builtin(position) sk_FragCoord: vec4<f32>,");
1358     }
1359 
1360     fIndentation--;
1361     this->writeLine("};");
1362 }
1363 
writeStageOutputStruct()1364 void WGSLCodeGenerator::writeStageOutputStruct() {
1365     std::string_view structNamePrefix = pipeline_struct_prefix(fProgram.fConfig->fKind);
1366     if (structNamePrefix.empty()) {
1367         // There's no need to declare pipeline stage outputs.
1368         return;
1369     }
1370 
1371     this->write("struct ");
1372     this->write(structNamePrefix);
1373     this->writeLine("Out {");
1374     fIndentation++;
1375 
1376     // TODO(skia:13092): Remember all variables that are added to the output struct here so they
1377     // can be referenced correctly when handling variable references.
1378     bool declaredPositionBuiltin = false;
1379     bool requiresPointSizeBuiltin = false;
1380     for (const ProgramElement* e : fProgram.elements()) {
1381         if (e->is<GlobalVarDeclaration>()) {
1382             const Variable* v = e->as<GlobalVarDeclaration>().declaration()
1383                                  ->as<VarDeclaration>().var();
1384             if (v->modifiers().fFlags & Modifiers::kOut_Flag) {
1385                 this->writePipelineIODeclaration(v->modifiers(), v->type(), v->mangledName(),
1386                                                  Delimiter::kComma);
1387             }
1388         } else if (e->is<InterfaceBlock>()) {
1389             const Variable* v = e->as<InterfaceBlock>().var();
1390             // Merge all the members of `out` interface blocks to the output struct, which are
1391             // specified as either "builtin" or with a "layout(location=".
1392             //
1393             // TODO(armansito): Is it legal to have an interface block without a storage qualifier
1394             // but with members that have individual storage qualifiers?
1395             if (v->modifiers().fFlags & Modifiers::kOut_Flag) {
1396                 for (const auto& f : v->type().fields()) {
1397                     this->writePipelineIODeclaration(f.fModifiers, *f.fType, f.fName,
1398                                                      Delimiter::kComma);
1399                     if (f.fModifiers.fLayout.fBuiltin == SK_POSITION_BUILTIN) {
1400                         declaredPositionBuiltin = true;
1401                     } else if (f.fModifiers.fLayout.fBuiltin == SK_POINTSIZE_BUILTIN) {
1402                         // sk_PointSize is explicitly not supported by `builtin_from_sksl_name` so
1403                         // writePipelineIODeclaration will never write it. We mark it here if the
1404                         // declaration is needed so we can synthesize it below.
1405                         requiresPointSizeBuiltin = true;
1406                     }
1407                 }
1408             }
1409         }
1410     }
1411 
1412     // A vertex program must include the `position` builtin in its entry point return type.
1413     if (ProgramConfig::IsVertex(fProgram.fConfig->fKind) && !declaredPositionBuiltin) {
1414         this->writeLine("@builtin(position) sk_Position: vec4<f32>,");
1415     }
1416 
1417     fIndentation--;
1418     this->writeLine("};");
1419 
1420     // In WebGPU/WGSL, the vertex stage does not support a point-size output and the size
1421     // of a point primitive is always 1 pixel (see https://github.com/gpuweb/gpuweb/issues/332).
1422     //
1423     // There isn't anything we can do to emulate this correctly at this stage so we
1424     // synthesize a placeholder variable that has no effect. Programs should not rely on
1425     // sk_PointSize when using the Dawn backend.
1426     if (ProgramConfig::IsVertex(fProgram.fConfig->fKind) && requiresPointSizeBuiltin) {
1427         this->writeLine("/* unsupported */ var<private> sk_PointSize: f32;");
1428     }
1429 }
1430 
writeNonBlockUniformsForTests()1431 void WGSLCodeGenerator::writeNonBlockUniformsForTests() {
1432     for (const ProgramElement* e : fProgram.elements()) {
1433         if (e->is<GlobalVarDeclaration>()) {
1434             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
1435             const Variable& var = *decls.varDeclaration().var();
1436             if (is_in_global_uniforms(var)) {
1437                 if (!fDeclaredUniformsStruct) {
1438                     this->write("struct _GlobalUniforms {\n");
1439                     fDeclaredUniformsStruct = true;
1440                 }
1441                 this->write("    ");
1442                 this->writeVariableDecl(var.type(), var.mangledName(), Delimiter::kComma);
1443             }
1444         }
1445     }
1446     if (fDeclaredUniformsStruct) {
1447         int binding = fProgram.fConfig->fSettings.fDefaultUniformBinding;
1448         int set = fProgram.fConfig->fSettings.fDefaultUniformSet;
1449         this->write("};\n");
1450         this->write("@binding(" + std::to_string(binding) + ") ");
1451         this->write("@group(" + std::to_string(set) + ") ");
1452         this->writeLine("var<uniform> _globalUniforms: _GlobalUniforms;");
1453     }
1454 }
1455 
writeFunctionDependencyArgs(const FunctionDeclaration & f)1456 bool WGSLCodeGenerator::writeFunctionDependencyArgs(const FunctionDeclaration& f) {
1457     FunctionDependencies* deps = fRequirements.dependencies.find(&f);
1458     if (!deps || *deps == FunctionDependencies::kNone) {
1459         return false;
1460     }
1461 
1462     const char* separator = "";
1463     if ((*deps & FunctionDependencies::kPipelineInputs) != FunctionDependencies::kNone) {
1464         this->write("_stageIn");
1465         separator = ", ";
1466     }
1467     if ((*deps & FunctionDependencies::kPipelineOutputs) != FunctionDependencies::kNone) {
1468         this->write(separator);
1469         this->write("_stageOut");
1470     }
1471     return true;
1472 }
1473 
writeFunctionDependencyParams(const FunctionDeclaration & f)1474 bool WGSLCodeGenerator::writeFunctionDependencyParams(const FunctionDeclaration& f) {
1475     FunctionDependencies* deps = fRequirements.dependencies.find(&f);
1476     if (!deps || *deps == FunctionDependencies::kNone) {
1477         return false;
1478     }
1479 
1480     std::string_view structNamePrefix = pipeline_struct_prefix(fProgram.fConfig->fKind);
1481     if (structNamePrefix.empty()) {
1482         return false;
1483     }
1484     const char* separator = "";
1485     if ((*deps & FunctionDependencies::kPipelineInputs) != FunctionDependencies::kNone) {
1486         this->write("_stageIn: ");
1487         separator = ", ";
1488         this->write(structNamePrefix);
1489         this->write("In");
1490     }
1491     if ((*deps & FunctionDependencies::kPipelineOutputs) != FunctionDependencies::kNone) {
1492         this->write(separator);
1493         this->write("_stageOut: ptr<function, ");
1494         this->write(structNamePrefix);
1495         this->write("Out>");
1496     }
1497     return true;
1498 }
1499 
writeOutParamHelper(const FunctionCall & c,const ExpressionArray & args,const SkTArray<VariableReference * > & outVars)1500 std::string WGSLCodeGenerator::writeOutParamHelper(const FunctionCall& c,
1501                                                    const ExpressionArray& args,
1502                                                    const SkTArray<VariableReference*>& outVars) {
1503     // It's possible for out-param function arguments to contain an out-param function call
1504     // expression. Emit the function into a temporary stream to prevent the nested helper from
1505     // clobbering the current helper as we recursively evaluate argument expressions.
1506     StringStream tmpStream;
1507     AutoOutputStream outputToExtraFunctions(this, &tmpStream, &fIndentation);
1508 
1509     // Reset the line start state while the AutoOutputStream is active. We restore it later before
1510     // the function returns.
1511     bool atLineStart = fAtLineStart;
1512     fAtLineStart = false;
1513     const FunctionDeclaration& func = c.function();
1514 
1515     // Synthesize a helper function that takes the same inputs as `function`, except in places where
1516     // `outVars` is non-null; in those places, we take the type of the VariableReference.
1517     //
1518     // float _outParamHelper_0_originalFuncName(float _var0, float _var1, float& outParam) {
1519     std::string name =
1520             "_outParamHelper_" + std::to_string(fSwizzleHelperCount++) + "_" + func.mangledName();
1521     auto separator = SkSL::String::Separator();
1522     this->write("fn ");
1523     this->write(name);
1524     this->write("(");
1525     if (this->writeFunctionDependencyParams(func)) {
1526         separator();
1527     }
1528 
1529     SkASSERT(outVars.size() == args.size());
1530     SkASSERT(SkToSizeT(outVars.size()) == func.parameters().size());
1531 
1532     // We need to detect cases where the caller passes the same variable as an out-param more than
1533     // once and avoid redeclaring the variable name. This is also a situation that is not permitted
1534     // by WGSL aliasing rules (see https://www.w3.org/TR/WGSL/#aliasing). Because the parameter is
1535     // redundant and we don't actually ever reference it, we give it a placeholder name.
1536     auto parentOutParamArgVars = std::move(fOutParamArgVars);
1537     SkASSERT(fOutParamArgVars.empty());
1538 
1539     for (int i = 0; i < args.size(); ++i) {
1540         this->write(separator());
1541 
1542         if (outVars[i]) {
1543             const Variable* var = outVars[i]->variable();
1544             if (!fOutParamArgVars.contains(var)) {
1545                 fOutParamArgVars.add(var);
1546                 this->writeName(var->mangledName());
1547             } else {
1548                 this->write("_unused");
1549                 this->write(std::to_string(i));
1550             }
1551         } else {
1552             this->write("_var");
1553             this->write(std::to_string(i));
1554         }
1555 
1556         this->write(": ");
1557 
1558         // Declare the parameter using the type of argument variable. If the complete argument is an
1559         // access or swizzle expression, the target assignment will be resolved below when we copy
1560         // the value to the out-parameter.
1561         const Type& type = outVars[i] ? outVars[i]->type() : args[i]->type();
1562 
1563         // Declare an out-parameter as a pointer.
1564         if (func.parameters()[i]->modifiers().fFlags & Modifiers::kOut_Flag) {
1565             this->write(to_ptr_type(type));
1566         } else {
1567             this->write(to_wgsl_type(type));
1568         }
1569     }
1570 
1571     this->write(")");
1572     if (!func.returnType().isVoid()) {
1573         this->write(" -> ");
1574         this->write(to_wgsl_type(func.returnType()));
1575     }
1576     this->writeLine(" {");
1577     ++fIndentation;
1578 
1579     // Declare a temporary variable for each out-parameter.
1580     for (int i = 0; i < outVars.size(); ++i) {
1581         if (!outVars[i]) {
1582             continue;
1583         }
1584         this->write("var ");
1585         this->write("_var");
1586         this->write(std::to_string(i));
1587         this->write(": ");
1588         this->write(to_wgsl_type(args[i]->type()));
1589 
1590         // If this is an inout parameter then we need to copy the input argument into the parameter
1591         // per https://www.khronos.org/opengl/wiki/Core_Language_(GLSL)#Parameters.
1592         if (func.parameters()[i]->modifiers().fFlags & Modifiers::kIn_Flag) {
1593             this->write(" = ");
1594             this->writeExpression(*args[i], Precedence::kAssignment);
1595         }
1596 
1597         this->writeLine(";");
1598     }
1599 
1600     // Call the function we're wrapping. If it has a return type, then store it so it can be
1601     // returned later.
1602     bool hasReturn = !c.type().isVoid();
1603     if (hasReturn) {
1604         this->write("var _return: ");
1605         this->write(to_wgsl_type(c.type()));
1606         this->write(" = ");
1607     }
1608 
1609     // Write the function call.
1610     this->writeName(func.mangledName());
1611     this->write("(");
1612     auto newSeparator = SkSL::String::Separator();
1613     if (this->writeFunctionDependencyArgs(func)) {
1614         newSeparator();
1615     }
1616     for (int i = 0; i < args.size(); ++i) {
1617         this->write(newSeparator());
1618         // All forwarded arguments now have a name that looks like "_var[i]" (e.g. _var0, var1,
1619         // etc.). All such variables should be of value type and those that have been passed in as
1620         // inout should have been dereferenced when they were stored in a local temporary. We need
1621         // to take their address again when forwarding to a pointer.
1622         if (outVars[i]) {
1623             this->write("&");
1624         }
1625         this->write("_var");
1626         this->write(std::to_string(i));
1627     }
1628     this->writeLine(");");
1629 
1630     // Copy the temporary variables back into the original out-parameters.
1631     for (int i = 0; i < outVars.size(); ++i) {
1632         if (!outVars[i]) {
1633             continue;
1634         }
1635         // TODO(skia:13092): WGSL does not support assigning to a swizzle
1636         // (see https://github.com/gpuweb/gpuweb/issues/737). These will require special treatment
1637         // when they appear on the lhs of an assignment.
1638         this->writeExpression(*args[i], Precedence::kAssignment);
1639         this->write(" = _var");
1640         this->write(std::to_string(i));
1641         this->writeLine(";");
1642     }
1643 
1644     // Return
1645     if (hasReturn) {
1646         this->writeLine("return _return;");
1647     }
1648 
1649     --fIndentation;
1650     this->writeLine("}");
1651 
1652     // Write the function out to `fExtraFunctions`.
1653     write_stringstream(tmpStream, fExtraFunctions);
1654 
1655     // Restore any global state
1656     fOutParamArgVars = std::move(parentOutParamArgVars);
1657     fAtLineStart = atLineStart;
1658     return name;
1659 }
1660 
1661 }  // namespace SkSL
1662