/*
 * Copyright 2022 Google Inc.
 *
 * Use of this source code is governed by a BSD-style license that can be
 * found in the LICENSE file.
 */
7 
8 #include "src/sksl/codegen/SkSLWGSLCodeGenerator.h"
9 
10 #include "include/core/SkSpan.h"
11 #include "include/core/SkTypes.h"
12 #include "include/private/base/SkTArray.h"
13 #include "include/private/base/SkTo.h"
14 #include "src/base/SkEnumBitMask.h"
15 #include "src/base/SkStringView.h"
16 #include "src/core/SkTHash.h"
17 #include "src/core/SkTraceEvent.h"
18 #include "src/sksl/SkSLAnalysis.h"
19 #include "src/sksl/SkSLBuiltinTypes.h"
20 #include "src/sksl/SkSLCompiler.h"
21 #include "src/sksl/SkSLConstantFolder.h"
22 #include "src/sksl/SkSLContext.h"
23 #include "src/sksl/SkSLDefines.h"
24 #include "src/sksl/SkSLErrorReporter.h"
25 #include "src/sksl/SkSLIntrinsicList.h"
26 #include "src/sksl/SkSLMemoryLayout.h"
27 #include "src/sksl/SkSLOperator.h"
28 #include "src/sksl/SkSLOutputStream.h"
29 #include "src/sksl/SkSLPosition.h"
30 #include "src/sksl/SkSLProgramSettings.h"
31 #include "src/sksl/SkSLString.h"
32 #include "src/sksl/SkSLStringStream.h"
33 #include "src/sksl/SkSLUtil.h"
34 #include "src/sksl/analysis/SkSLProgramUsage.h"
35 #include "src/sksl/analysis/SkSLProgramVisitor.h"
36 #include "src/sksl/codegen/SkSLCodeGenerator.h"
37 #include "src/sksl/ir/SkSLBinaryExpression.h"
38 #include "src/sksl/ir/SkSLBlock.h"
39 #include "src/sksl/ir/SkSLConstructor.h"
40 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
41 #include "src/sksl/ir/SkSLConstructorCompound.h"
42 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
43 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
44 #include "src/sksl/ir/SkSLDoStatement.h"
45 #include "src/sksl/ir/SkSLExpression.h"
46 #include "src/sksl/ir/SkSLExpressionStatement.h"
47 #include "src/sksl/ir/SkSLFieldAccess.h"
48 #include "src/sksl/ir/SkSLForStatement.h"
49 #include "src/sksl/ir/SkSLFunctionCall.h"
50 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
51 #include "src/sksl/ir/SkSLFunctionDefinition.h"
52 #include "src/sksl/ir/SkSLIRNode.h"
53 #include "src/sksl/ir/SkSLIfStatement.h"
54 #include "src/sksl/ir/SkSLIndexExpression.h"
55 #include "src/sksl/ir/SkSLInterfaceBlock.h"
56 #include "src/sksl/ir/SkSLLayout.h"
57 #include "src/sksl/ir/SkSLLiteral.h"
58 #include "src/sksl/ir/SkSLModifierFlags.h"
59 #include "src/sksl/ir/SkSLModifiersDeclaration.h"
60 #include "src/sksl/ir/SkSLPostfixExpression.h"
61 #include "src/sksl/ir/SkSLPrefixExpression.h"
62 #include "src/sksl/ir/SkSLProgram.h"
63 #include "src/sksl/ir/SkSLProgramElement.h"
64 #include "src/sksl/ir/SkSLReturnStatement.h"
65 #include "src/sksl/ir/SkSLSetting.h"
66 #include "src/sksl/ir/SkSLStatement.h"
67 #include "src/sksl/ir/SkSLStructDefinition.h"
68 #include "src/sksl/ir/SkSLSwitchCase.h"
69 #include "src/sksl/ir/SkSLSwitchStatement.h"
70 #include "src/sksl/ir/SkSLSwizzle.h"
71 #include "src/sksl/ir/SkSLTernaryExpression.h"
72 #include "src/sksl/ir/SkSLType.h"
73 #include "src/sksl/ir/SkSLVarDeclarations.h"
74 #include "src/sksl/ir/SkSLVariable.h"
75 #include "src/sksl/ir/SkSLVariableReference.h"
76 #include "src/sksl/spirv.h"
77 #include "src/sksl/transform/SkSLTransform.h"
78 
79 #include <algorithm>
80 #include <cstddef>
81 #include <cstdint>
82 #include <initializer_list>
83 #include <iterator>
84 #include <memory>
85 #include <optional>
86 #include <string>
87 #include <string_view>
88 #include <utility>
89 
90 #ifdef SK_ENABLE_WGSL_VALIDATION
91 #include "tint/tint.h"
92 #include "src/tint/lang/wgsl/reader/options.h"
93 #include "src/tint/lang/wgsl/extension.h"
94 #endif
95 
96 using namespace skia_private;
97 
98 namespace {
99 
100 // Represents a function's dependencies that are not accessible in global scope. For instance,
101 // pipeline stage input and output parameters must be passed in as an argument.
102 //
103 // This is a bitmask enum. (It would be inside `class WGSLCodeGenerator`, but this leads to build
104 // errors in MSVC.)
105 enum class WGSLFunctionDependency : uint8_t {
106     kNone = 0,
107     kPipelineInputs  = 1 << 0,
108     kPipelineOutputs = 1 << 1,
109 };
110 using WGSLFunctionDependencies = SkEnumBitMask<WGSLFunctionDependency>;
111 
112 SK_MAKE_BITMASK_OPS(WGSLFunctionDependency)
113 
114 }  // namespace
115 
116 namespace SkSL {
117 
118 class WGSLCodeGenerator : public CodeGenerator {
119 public:
120     // See https://www.w3.org/TR/WGSL/#builtin-values
121     enum class Builtin {
122         // Vertex stage:
123         kVertexIndex,    // input
124         kInstanceIndex,  // input
125         kPosition,       // output, fragment stage input
126 
127         // Fragment stage:
128         kLastFragColor,  // input
129         kFrontFacing,    // input
130         kSampleIndex,    // input
131         kFragDepth,      // output
132         kSampleMaskIn,   // input
133         kSampleMask,     // output
134 
135         // Compute stage:
136         kLocalInvocationId,     // input
137         kLocalInvocationIndex,  // input
138         kGlobalInvocationId,    // input
139         kWorkgroupId,           // input
140         kNumWorkgroups,         // input
141     };
142 
143     // Variable declarations can be terminated by:
144     //   - comma (","), e.g. in struct member declarations or function parameters
145     //   - semicolon (";"), e.g. in function scope variables
146     // A "none" option is provided to skip the delimiter when not needed, e.g. at the end of a list
147     // of declarations.
148     enum class Delimiter {
149         kComma,
150         kSemicolon,
151         kNone,
152     };
153 
154     struct ProgramRequirements {
155         using DepsMap = skia_private::THashMap<const FunctionDeclaration*,
156                                                WGSLFunctionDependencies>;
157 
158         // Mappings used to synthesize function parameters according to dependencies on pipeline
159         // input/output variables.
160         DepsMap fDependencies;
161 
162         // These flags track extensions that will need to be enabled.
163         bool fPixelLocalExtension = false;
164     };
165 
WGSLCodeGenerator(const Context * context,const ShaderCaps * caps,const Program * program,OutputStream * out)166     WGSLCodeGenerator(const Context* context,
167                       const ShaderCaps* caps,
168                       const Program* program,
169                       OutputStream* out)
170             : INHERITED(context, caps, program, out) {}
171 
172     bool generateCode() override;
173 
174 private:
175     using INHERITED = CodeGenerator;
176     using Precedence = OperatorPrecedence;
177 
178     // Called by generateCode() as the first step.
179     void preprocessProgram();
180 
181     // Write output content while correctly handling indentation.
182     void write(std::string_view s);
183     void writeLine(std::string_view s = std::string_view());
184     void finishLine();
185 
186     // Helpers to declare a pipeline stage IO parameter declaration.
187     void writePipelineIODeclaration(const Layout& layout,
188                                     const Type& type,
189                                     std::string_view name,
190                                     Delimiter delimiter);
191     void writeUserDefinedIODecl(const Layout& layout,
192                                 const Type& type,
193                                 std::string_view name,
194                                 Delimiter delimiter);
195     void writeBuiltinIODecl(const Type& type,
196                             std::string_view name,
197                             Builtin builtin,
198                             Delimiter delimiter);
199     void writeVariableDecl(const Layout& layout,
200                            const Type& type,
201                            std::string_view name,
202                            Delimiter delimiter);
203 
204     // Write a function definition.
205     void writeFunction(const FunctionDefinition& f);
206     void writeFunctionDeclaration(const FunctionDeclaration& f,
207                                   SkSpan<const bool> paramNeedsDedicatedStorage);
208 
209     // Write the program entry point.
210     void writeEntryPoint(const FunctionDefinition& f);
211 
212     // Writers for supported statement types.
213     void writeStatement(const Statement& s);
214     void writeStatements(const StatementArray& statements);
215     void writeBlock(const Block& b);
216     void writeDoStatement(const DoStatement& expr);
217     void writeExpressionStatement(const Expression& expr);
218     void writeForStatement(const ForStatement& s);
219     void writeIfStatement(const IfStatement& s);
220     void writeReturnStatement(const ReturnStatement& s);
221     void writeSwitchStatement(const SwitchStatement& s);
222     void writeSwitchCases(SkSpan<const SwitchCase* const> cases);
223     void writeEmulatedSwitchFallthroughCases(SkSpan<const SwitchCase* const> cases,
224                                              std::string_view switchValue);
225     void writeSwitchCaseList(SkSpan<const SwitchCase* const> cases);
226     void writeVarDeclaration(const VarDeclaration& varDecl);
227 
228     // Synthesizes an LValue for an expression.
229     class LValue;
230     class PointerLValue;
231     class SwizzleLValue;
232     class VectorComponentLValue;
233     std::unique_ptr<LValue> makeLValue(const Expression& e);
234 
235     std::string variableReferenceNameForLValue(const VariableReference& r);
236     std::string variablePrefix(const Variable& v);
237 
238     bool binaryOpNeedsComponentwiseMatrixPolyfill(const Type& left, const Type& right, Operator op);
239 
240     // Writers for expressions. These return the final expression text as a string, and emit any
241     // necessary setup code directly into the program as necessary. The returned expression may be
242     // a `let`-alias that cannot be assigned-into; use `makeLValue` for an assignable expression.
243     std::string assembleExpression(const Expression& e, Precedence parentPrecedence);
244     std::string assembleBinaryExpression(const BinaryExpression& b, Precedence parentPrecedence);
245     std::string assembleBinaryExpression(const Expression& left,
246                                          Operator op,
247                                          const Expression& right,
248                                          const Type& resultType,
249                                          Precedence parentPrecedence);
250     std::string assembleFieldAccess(const FieldAccess& f);
251     std::string assembleFunctionCall(const FunctionCall& call, Precedence parentPrecedence);
252     std::string assembleIndexExpression(const IndexExpression& i);
253     std::string assembleLiteral(const Literal& l);
254     std::string assemblePostfixExpression(const PostfixExpression& p, Precedence parentPrecedence);
255     std::string assemblePrefixExpression(const PrefixExpression& p, Precedence parentPrecedence);
256     std::string assembleSwizzle(const Swizzle& swizzle);
257     std::string assembleTernaryExpression(const TernaryExpression& t, Precedence parentPrecedence);
258     std::string assembleVariableReference(const VariableReference& r);
259     std::string assembleName(std::string_view name);
260 
261     std::string assembleIncrementExpr(const Type& type);
262 
263     // Intrinsic helper functions.
264     std::string assembleIntrinsicCall(const FunctionCall& call,
265                                       IntrinsicKind kind,
266                                       Precedence parentPrecedence);
267     std::string assembleSimpleIntrinsic(std::string_view intrinsicName, const FunctionCall& call);
268     std::string assembleUnaryOpIntrinsic(Operator op,
269                                          const FunctionCall& call,
270                                          Precedence parentPrecedence);
271     std::string assembleBinaryOpIntrinsic(Operator op,
272                                           const FunctionCall& call,
273                                           Precedence parentPrecedence);
274     std::string assembleVectorizedIntrinsic(std::string_view intrinsicName,
275                                             const FunctionCall& call);
276     std::string assembleOutAssignedIntrinsic(std::string_view intrinsicName,
277                                              std::string_view returnFieldName,
278                                              std::string_view outFieldName,
279                                              const FunctionCall& call);
280     std::string assemblePartialSampleCall(std::string_view intrinsicName,
281                                           const Expression& sampler,
282                                           const Expression& coords);
283     std::string assembleInversePolyfill(const FunctionCall& call);
284     std::string assembleComponentwiseMatrixBinary(const Type& leftType,
285                                                   const Type& rightType,
286                                                   const std::string& left,
287                                                   const std::string& right,
288                                                   Operator op);
289 
290     // Constructor expressions
291     std::string assembleAnyConstructor(const AnyConstructor& c);
292     std::string assembleConstructorCompound(const ConstructorCompound& c);
293     std::string assembleConstructorCompoundVector(const ConstructorCompound& c);
294     std::string assembleConstructorCompoundMatrix(const ConstructorCompound& c);
295     std::string assembleConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c);
296     std::string assembleConstructorMatrixResize(const ConstructorMatrixResize& ctor);
297 
298     // Synthesized helper functions for comparison operators that are not supported by WGSL.
299     std::string assembleEqualityExpression(const Type& left,
300                                            const std::string& leftName,
301                                            const Type& right,
302                                            const std::string& rightName,
303                                            Operator op,
304                                            Precedence parentPrecedence);
305     std::string assembleEqualityExpression(const Expression& left,
306                                            const Expression& right,
307                                            Operator op,
308                                            Precedence parentPrecedence);
309 
310     // Writes a scratch variable into the program and returns its name (e.g. `_skTemp123`).
311     std::string writeScratchVar(const Type& type, const std::string& value = "");
312 
313     // Writes a scratch let-variable into the program, gives it the value of `expr`, and returns its
314     // name (e.g. `_skTemp123`).
315     std::string writeScratchLet(const std::string& expr);
316     std::string writeScratchLet(const Expression& expr, Precedence parentPrecedence);
317 
318     // Converts `expr` into a string and returns a scratch let-variable associated with the
319     // expression. Compile-time constants and plain variable references will return the expression
320     // directly and omit the let-variable.
321     std::string writeNontrivialScratchLet(const Expression& expr, Precedence parentPrecedence);
322 
323     // Generic recursive ProgramElement visitor.
324     void writeProgramElement(const ProgramElement& e);
325     void writeGlobalVarDeclaration(const GlobalVarDeclaration& d);
326     void writeStructDefinition(const StructDefinition& s);
327     void writeModifiersDeclaration(const ModifiersDeclaration&);
328 
329     // Writes the WGSL struct fields for SkSL structs and interface blocks. Enforces WGSL address
330     // space layout constraints
331     // (https://www.w3.org/TR/WGSL/#address-space-layout-constraints) if a `layout` is
332     // provided. A struct that does not need to be host-shareable does not require a `layout`.
333     void writeFields(SkSpan<const Field> fields, const MemoryLayout* memoryLayout = nullptr);
334 
335     // We bundle uniforms, and all varying pipeline stage inputs and outputs, into separate structs.
336     bool needsStageInputStruct() const;
337     void writeStageInputStruct();
338     bool needsStageOutputStruct() const;
339     void writeStageOutputStruct();
340     void writeUniformsAndBuffers();
341     void prepareUniformPolyfillsForInterfaceBlock(const InterfaceBlock* interfaceBlock,
342                                                   std::string_view instanceName,
343                                                   MemoryLayout::Standard nativeLayout);
344     void writeEnables();
345     void writeUniformPolyfills();
346 
347     void writeTextureOrSampler(const Variable& var,
348                                int bindingLocation,
349                                std::string_view suffix,
350                                std::string_view wgslType);
351 
352     // Writes all top-level non-opaque global uniform declarations (i.e. not part of an interface
353     // block) into a single uniform block binding.
354     //
355     // In complete fragment/vertex/compute programs, uniforms will be declared only as interface
356     // blocks and global opaque types (like textures and samplers) which we expect to be declared
357     // with a unique binding and descriptor set index. However, test files that are declared as RTE
358     // programs may contain OpenGL-style global uniform declarations with no clear binding index to
359     // use for the containing synthesized block.
360     //
361     // Since we are handling these variables only to generate gold files from RTEs and never run
362     // them, we always declare them at the default bind group and binding index.
363     void writeNonBlockUniformsForTests();
364 
365     // For a given function declaration, writes out any implicitly required pipeline stage arguments
366     // based on the function's pre-determined dependencies. These are expected to be written out as
367     // the first parameters for a function that requires them. Returns true if any arguments were
368     // written.
369     std::string functionDependencyArgs(const FunctionDeclaration&);
370     bool writeFunctionDependencyParams(const FunctionDeclaration&);
371 
372     // Code in the header appears before the main body of code.
373     StringStream fHeader;
374 
375     // We assign unique names to anonymous interface blocks based on the type.
376     skia_private::THashMap<const Type*, std::string> fInterfaceBlockNameMap;
377 
378     // Stores the functions which use stage inputs/outputs as well as required WGSL extensions.
379     ProgramRequirements fRequirements;
380     skia_private::TArray<const Variable*> fPipelineInputs;
381     skia_private::TArray<const Variable*> fPipelineOutputs;
382 
383     // These fields track whether we have written the polyfill for `inverse()` for a given matrix
384     // type.
385     bool fWrittenInverse2 = false;
386     bool fWrittenInverse3 = false;
387     bool fWrittenInverse4 = false;
388 
389     // These fields control uniform polyfill support in cases where WGSL and std140 disagree.
390     // In std140 layout, matrices need to be represented as arrays of @size(16)-aligned vectors, and
391     // array elements are wrapped in a struct containing a single @size(16)-aligned element. Arrays
392     // of matrices combine both wrappers. These wrapper structs are unpacked into natively-typed
393     // globals at the shader entrypoint.
394     struct FieldPolyfillInfo {
395         const InterfaceBlock* fInterfaceBlock;
396         std::string fReplacementName;
397         bool fIsArray = false;
398         bool fIsMatrix = false;
399         bool fWasAccessed = false;
400     };
401     using FieldPolyfillMap = skia_private::THashMap<const Field*, FieldPolyfillInfo>;
402     FieldPolyfillMap fFieldPolyfillMap;
403 
404     // Output processing state.
405     int fIndentation = 0;
406     bool fAtLineStart = false;
407     bool fHasUnconditionalReturn = false;
408     bool fAtFunctionScope = false;
409     int fConditionalScopeDepth = 0;
410     int fLocalSizeX = 1;
411     int fLocalSizeY = 1;
412     int fLocalSizeZ = 1;
413 
414     int fScratchCount = 0;
415 };
416 
// Forward declaration only; the enumerators are defined outside this file.
enum class ProgramKind : int8_t;
418 
419 namespace {
420 
// Name suffixes for sampler and texture declarations. NOTE(review): presumably these are the
// `suffix` values handed to WGSLCodeGenerator::writeTextureOrSampler -- the uses are outside this
// view, confirm there.
static constexpr char kSamplerSuffix[] = "_Sampler";
static constexpr char kTextureSuffix[] = "_Texture";
423 
// Address spaces that can appear in a WGSL `ptr<...>` type; address_space_to_str() maps each one
// to its WGSL spelling.
// See https://www.w3.org/TR/WGSL/#memory-view-types
enum class PtrAddressSpace {
    kFunction,
    kPrivate,
    kStorage,
};
430 
operator_name(Operator op)431 const char* operator_name(Operator op) {
432     switch (op.kind()) {
433         case Operator::Kind::LOGICALXOR:  return " != ";
434         default:                          return op.operatorName();
435     }
436 }
437 
is_reserved_word(std::string_view word)438 bool is_reserved_word(std::string_view word) {
439     static const THashSet<std::string_view> kReservedWords{
440             // Used by SkSL:
441             "FSIn",
442             "FSOut",
443             "VSIn",
444             "VSOut",
445             "CSIn",
446             "_globalUniforms",
447             "_GlobalUniforms",
448             "_return",
449             "_stageIn",
450             "_stageOut",
451             // Keywords: https://www.w3.org/TR/WGSL/#keyword-summary
452             "alias",
453             "break",
454             "case",
455             "const",
456             "const_assert",
457             "continue",
458             "continuing",
459             "default",
460             "diagnostic",
461             "discard",
462             "else",
463             "enable",
464             "false",
465             "fn",
466             "for",
467             "if",
468             "let",
469             "loop",
470             "override",
471             "requires",
472             "return",
473             "struct",
474             "switch",
475             "true",
476             "var",
477             "while",
478             // Pre-declared types: https://www.w3.org/TR/WGSL/#predeclared-types
479             "bool",
480             "f16",
481             "f32",
482             "i32",
483             "u32",
484             // ... and pre-declared type generators:
485             "array",
486             "atomic",
487             "mat2x2",
488             "mat2x3",
489             "mat2x4",
490             "mat3x2",
491             "mat3x3",
492             "mat3x4",
493             "mat4x2",
494             "mat4x3",
495             "mat4x4",
496             "ptr",
497             "texture_1d",
498             "texture_2d",
499             "texture_2d_array",
500             "texture_3d",
501             "texture_cube",
502             "texture_cube_array",
503             "texture_multisampled_2d",
504             "texture_storage_1d",
505             "texture_storage_2d",
506             "texture_storage_2d_array",
507             "texture_storage_3d",
508             "vec2",
509             "vec3",
510             "vec4",
511             // Pre-declared enumerants: https://www.w3.org/TR/WGSL/#predeclared-enumerants
512             "read",
513             "write",
514             "read_write",
515             "function",
516             "private",
517             "workgroup",
518             "uniform",
519             "storage",
520             "perspective",
521             "linear",
522             "flat",
523             "center",
524             "centroid",
525             "sample",
526             "vertex_index",
527             "instance_index",
528             "position",
529             "front_facing",
530             "frag_depth",
531             "local_invocation_id",
532             "local_invocation_index",
533             "global_invocation_id",
534             "workgroup_id",
535             "num_workgroups",
536             "sample_index",
537             "sample_mask",
538             "rgba8unorm",
539             "rgba8snorm",
540             "rgba8uint",
541             "rgba8sint",
542             "rgba16uint",
543             "rgba16sint",
544             "rgba16float",
545             "r32uint",
546             "r32sint",
547             "r32float",
548             "rg32uint",
549             "rg32sint",
550             "rg32float",
551             "rgba32uint",
552             "rgba32sint",
553             "rgba32float",
554             "bgra8unorm",
555             // Reserved words: https://www.w3.org/TR/WGSL/#reserved-words
556             "_",
557             "NULL",
558             "Self",
559             "abstract",
560             "active",
561             "alignas",
562             "alignof",
563             "as",
564             "asm",
565             "asm_fragment",
566             "async",
567             "attribute",
568             "auto",
569             "await",
570             "become",
571             "binding_array",
572             "cast",
573             "catch",
574             "class",
575             "co_await",
576             "co_return",
577             "co_yield",
578             "coherent",
579             "column_major",
580             "common",
581             "compile",
582             "compile_fragment",
583             "concept",
584             "const_cast",
585             "consteval",
586             "constexpr",
587             "constinit",
588             "crate",
589             "debugger",
590             "decltype",
591             "delete",
592             "demote",
593             "demote_to_helper",
594             "do",
595             "dynamic_cast",
596             "enum",
597             "explicit",
598             "export",
599             "extends",
600             "extern",
601             "external",
602             "fallthrough",
603             "filter",
604             "final",
605             "finally",
606             "friend",
607             "from",
608             "fxgroup",
609             "get",
610             "goto",
611             "groupshared",
612             "highp",
613             "impl",
614             "implements",
615             "import",
616             "inline",
617             "instanceof",
618             "interface",
619             "layout",
620             "lowp",
621             "macro",
622             "macro_rules",
623             "match",
624             "mediump",
625             "meta",
626             "mod",
627             "module",
628             "move",
629             "mut",
630             "mutable",
631             "namespace",
632             "new",
633             "nil",
634             "noexcept",
635             "noinline",
636             "nointerpolation",
637             "noperspective",
638             "null",
639             "nullptr",
640             "of",
641             "operator",
642             "package",
643             "packoffset",
644             "partition",
645             "pass",
646             "patch",
647             "pixelfragment",
648             "precise",
649             "precision",
650             "premerge",
651             "priv",
652             "protected",
653             "pub",
654             "public",
655             "readonly",
656             "ref",
657             "regardless",
658             "register",
659             "reinterpret_cast",
660             "require",
661             "resource",
662             "restrict",
663             "self",
664             "set",
665             "shared",
666             "sizeof",
667             "smooth",
668             "snorm",
669             "static",
670             "static_assert",
671             "static_cast",
672             "std",
673             "subroutine",
674             "super",
675             "target",
676             "template",
677             "this",
678             "thread_local",
679             "throw",
680             "trait",
681             "try",
682             "type",
683             "typedef",
684             "typeid",
685             "typename",
686             "typeof",
687             "union",
688             "unless",
689             "unorm",
690             "unsafe",
691             "unsized",
692             "use",
693             "using",
694             "varying",
695             "virtual",
696             "volatile",
697             "wgsl",
698             "where",
699             "with",
700             "writeonly",
701             "yield",
702     };
703 
704     return kReservedWords.contains(word);
705 }
706 
pipeline_struct_prefix(ProgramKind kind)707 std::string_view pipeline_struct_prefix(ProgramKind kind) {
708     if (ProgramConfig::IsVertex(kind)) {
709         return "VS";
710     }
711     if (ProgramConfig::IsFragment(kind)) {
712         return "FS";
713     }
714     if (ProgramConfig::IsCompute(kind)) {
715         return "CS";
716     }
717     // Compute programs don't have stage-in/stage-out pipeline structs.
718     return "";
719 }
720 
address_space_to_str(PtrAddressSpace addressSpace)721 std::string_view address_space_to_str(PtrAddressSpace addressSpace) {
722     switch (addressSpace) {
723         case PtrAddressSpace::kFunction:
724             return "function";
725         case PtrAddressSpace::kPrivate:
726             return "private";
727         case PtrAddressSpace::kStorage:
728             return "storage";
729     }
730     SkDEBUGFAIL("unsupported ptr address space");
731     return "unsupported";
732 }
733 
to_scalar_type(const Type & type)734 std::string_view to_scalar_type(const Type& type) {
735     SkASSERT(type.typeKind() == Type::TypeKind::kScalar);
736     switch (type.numberKind()) {
737         // Floating-point numbers in WebGPU currently always have 32-bit footprint and
738         // relaxed-precision is not supported without extensions. f32 is the only floating-point
739         // number type in WGSL (see the discussion on https://github.com/gpuweb/gpuweb/issues/658).
740         case Type::NumberKind::kFloat:
741             return "f32";
742         case Type::NumberKind::kSigned:
743             return "i32";
744         case Type::NumberKind::kUnsigned:
745             return "u32";
746         case Type::NumberKind::kBoolean:
747             return "bool";
748         case Type::NumberKind::kNonnumeric:
749             [[fallthrough]];
750         default:
751             break;
752     }
753     return type.name();
754 }
755 
756 // Convert a SkSL type to a WGSL type. Handles all plain types except structure types
757 // (see https://www.w3.org/TR/WGSL/#plain-types-section).
to_wgsl_type(const Context & context,const Type & raw,const Layout * layout=nullptr)758 std::string to_wgsl_type(const Context& context, const Type& raw, const Layout* layout = nullptr) {
759     const Type& type = raw.resolve().scalarTypeForLiteral();
760     switch (type.typeKind()) {
761         case Type::TypeKind::kScalar:
762             return std::string(to_scalar_type(type));
763 
764         case Type::TypeKind::kAtomic:
765             SkASSERT(type.matches(*context.fTypes.fAtomicUInt));
766             return "atomic<u32>";
767 
768         case Type::TypeKind::kVector: {
769             std::string_view ct = to_scalar_type(type.componentType());
770             return String::printf("vec%d<%.*s>", type.columns(), (int)ct.length(), ct.data());
771         }
772         case Type::TypeKind::kMatrix: {
773             std::string_view ct = to_scalar_type(type.componentType());
774             return String::printf("mat%dx%d<%.*s>",
775                                   type.columns(), type.rows(), (int)ct.length(), ct.data());
776         }
777         case Type::TypeKind::kArray: {
778             std::string result = "array<" + to_wgsl_type(context, type.componentType(), layout);
779             if (!type.isUnsizedArray()) {
780                 result += ", ";
781                 result += std::to_string(type.columns());
782             }
783             return result + '>';
784         }
785         case Type::TypeKind::kTexture: {
786             if (type.matches(*context.fTypes.fWriteOnlyTexture2D)) {
787                 std::string result = "texture_storage_2d<";
788                 // Write-only storage texture types require a pixel format, which is in the layout.
789                 SkASSERT(layout);
790                 LayoutFlags pixelFormat = layout->fFlags & LayoutFlag::kAllPixelFormats;
791                 switch (pixelFormat.value()) {
792                     case (int)LayoutFlag::kRGBA8:
793                         return result + "rgba8unorm, write>";
794 
795                     case (int)LayoutFlag::kRGBA32F:
796                         return result + "rgba32float, write>";
797 
798                     case (int)LayoutFlag::kR32F:
799                         return result + "r32float, write>";
800 
801                     default:
802                         // The front-end should have rejected this.
803                         return result + "write>";
804                 }
805             }
806             if (type.matches(*context.fTypes.fReadOnlyTexture2D)) {
807                 return "texture_2d<f32>";
808             }
809             break;
810         }
811         default:
812             break;
813     }
814     return std::string(type.name());
815 }
816 
to_ptr_type(const Context & context,const Type & type,const Layout * layout,PtrAddressSpace addressSpace=PtrAddressSpace::kFunction)817 std::string to_ptr_type(const Context& context,
818                         const Type& type,
819                         const Layout* layout,
820                         PtrAddressSpace addressSpace = PtrAddressSpace::kFunction) {
821     return "ptr<" + std::string(address_space_to_str(addressSpace)) + ", " +
822            to_wgsl_type(context, type, layout) + '>';
823 }
824 
wgsl_builtin_name(WGSLCodeGenerator::Builtin builtin)825 std::string_view wgsl_builtin_name(WGSLCodeGenerator::Builtin builtin) {
826     using Builtin = WGSLCodeGenerator::Builtin;
827     switch (builtin) {
828         case Builtin::kVertexIndex:
829             return "@builtin(vertex_index)";
830         case Builtin::kInstanceIndex:
831             return "@builtin(instance_index)";
832         case Builtin::kPosition:
833             return "@builtin(position)";
834         case Builtin::kLastFragColor:
835             return "@color(0)";
836         case Builtin::kFrontFacing:
837             return "@builtin(front_facing)";
838         case Builtin::kSampleIndex:
839             return "@builtin(sample_index)";
840         case Builtin::kFragDepth:
841             return "@builtin(frag_depth)";
842         case Builtin::kSampleMask:
843         case Builtin::kSampleMaskIn:
844             return "@builtin(sample_mask)";
845         case Builtin::kLocalInvocationId:
846             return "@builtin(local_invocation_id)";
847         case Builtin::kLocalInvocationIndex:
848             return "@builtin(local_invocation_index)";
849         case Builtin::kGlobalInvocationId:
850             return "@builtin(global_invocation_id)";
851         case Builtin::kWorkgroupId:
852             return "@builtin(workgroup_id)";
853         case Builtin::kNumWorkgroups:
854             return "@builtin(num_workgroups)";
855         default:
856             break;
857     }
858 
859     SkDEBUGFAIL("unsupported builtin");
860     return "unsupported";
861 }
862 
wgsl_builtin_type(WGSLCodeGenerator::Builtin builtin)863 std::string_view wgsl_builtin_type(WGSLCodeGenerator::Builtin builtin) {
864     using Builtin = WGSLCodeGenerator::Builtin;
865     switch (builtin) {
866         case Builtin::kVertexIndex:
867             return "u32";
868         case Builtin::kInstanceIndex:
869             return "u32";
870         case Builtin::kPosition:
871             return "vec4<f32>";
872         case Builtin::kLastFragColor:
873             return "vec4<f32>";
874         case Builtin::kFrontFacing:
875             return "bool";
876         case Builtin::kSampleIndex:
877             return "u32";
878         case Builtin::kFragDepth:
879             return "f32";
880         case Builtin::kSampleMask:
881             return "u32";
882         case Builtin::kSampleMaskIn:
883             return "u32";
884         case Builtin::kLocalInvocationId:
885             return "vec3<u32>";
886         case Builtin::kLocalInvocationIndex:
887             return "u32";
888         case Builtin::kGlobalInvocationId:
889             return "vec3<u32>";
890         case Builtin::kWorkgroupId:
891             return "vec3<u32>";
892         case Builtin::kNumWorkgroups:
893             return "vec3<u32>";
894         default:
895             break;
896     }
897 
898     SkDEBUGFAIL("unsupported builtin");
899     return "unsupported";
900 }
901 
902 // Some built-in variables have a type that differs from their SkSL counterpart (e.g. signed vs
903 // unsigned integer). We handle these cases with an explicit type conversion during a variable
904 // reference. Returns the WGSL type of the conversion target if conversion is needed, otherwise
905 // returns std::nullopt.
needs_builtin_type_conversion(const Variable & v)906 std::optional<std::string_view> needs_builtin_type_conversion(const Variable& v) {
907     switch (v.layout().fBuiltin) {
908         case SK_VERTEXID_BUILTIN:
909         case SK_INSTANCEID_BUILTIN:
910             return {"i32"};
911         default:
912             break;
913     }
914     return std::nullopt;
915 }
916 
917 // Map a SkSL builtin flag to a WGSL builtin kind. Returns std::nullopt if `builtin` is not
918 // not supported for WGSL.
919 //
920 // Also see //src/sksl/sksl_vert.sksl and //src/sksl/sksl_frag.sksl for supported built-ins.
builtin_from_sksl_name(int builtin)921 std::optional<WGSLCodeGenerator::Builtin> builtin_from_sksl_name(int builtin) {
922     using Builtin = WGSLCodeGenerator::Builtin;
923     switch (builtin) {
924         case SK_POSITION_BUILTIN:
925             [[fallthrough]];
926         case SK_FRAGCOORD_BUILTIN:
927             return Builtin::kPosition;
928         case SK_VERTEXID_BUILTIN:
929             return Builtin::kVertexIndex;
930         case SK_INSTANCEID_BUILTIN:
931             return Builtin::kInstanceIndex;
932         case SK_LASTFRAGCOLOR_BUILTIN:
933             return Builtin::kLastFragColor;
934         case SK_CLOCKWISE_BUILTIN:
935             // TODO(skia:13092): While `front_facing` is the corresponding built-in, it does not
936             // imply a particular winding order. We correctly compute the face orientation based
937             // on how Skia configured the render pipeline for all references to this built-in
938             // variable (see `SkSL::Program::Interface::fRTFlipUniform`).
939             return Builtin::kFrontFacing;
940         case SK_SAMPLEMASKIN_BUILTIN:
941             return Builtin::kSampleMaskIn;
942         case SK_SAMPLEMASK_BUILTIN:
943             return Builtin::kSampleMask;
944         case SK_NUMWORKGROUPS_BUILTIN:
945             return Builtin::kNumWorkgroups;
946         case SK_WORKGROUPID_BUILTIN:
947             return Builtin::kWorkgroupId;
948         case SK_LOCALINVOCATIONID_BUILTIN:
949             return Builtin::kLocalInvocationId;
950         case SK_GLOBALINVOCATIONID_BUILTIN:
951             return Builtin::kGlobalInvocationId;
952         case SK_LOCALINVOCATIONINDEX_BUILTIN:
953             return Builtin::kLocalInvocationIndex;
954         default:
955             break;
956     }
957     return std::nullopt;
958 }
959 
delimiter_to_str(WGSLCodeGenerator::Delimiter delimiter)960 const char* delimiter_to_str(WGSLCodeGenerator::Delimiter delimiter) {
961     using Delim = WGSLCodeGenerator::Delimiter;
962     switch (delimiter) {
963         case Delim::kComma:
964             return ",";
965         case Delim::kSemicolon:
966             return ";";
967         case Delim::kNone:
968         default:
969             break;
970     }
971     return "";
972 }
973 
974 // FunctionDependencyResolver visits the IR tree rooted at a particular function definition and
975 // computes that function's dependencies on pipeline stage IO parameters. These are later used to
976 // synthesize arguments when writing out function definitions.
977 class FunctionDependencyResolver : public ProgramVisitor {
978 public:
979     using Deps = WGSLFunctionDependencies;
980     using DepsMap = WGSLCodeGenerator::ProgramRequirements::DepsMap;
981 
FunctionDependencyResolver(const Program * p,const FunctionDeclaration * f,DepsMap * programDependencyMap)982     FunctionDependencyResolver(const Program* p,
983                                const FunctionDeclaration* f,
984                                DepsMap* programDependencyMap)
985             : fProgram(p), fFunction(f), fDependencyMap(programDependencyMap) {}
986 
resolve()987     Deps resolve() {
988         fDeps = WGSLFunctionDependency::kNone;
989         this->visit(*fProgram);
990         return fDeps;
991     }
992 
993 private:
visitProgramElement(const ProgramElement & p)994     bool visitProgramElement(const ProgramElement& p) override {
995         // Only visit the program that matches the requested function.
996         if (p.is<FunctionDefinition>() && &p.as<FunctionDefinition>().declaration() == fFunction) {
997             return INHERITED::visitProgramElement(p);
998         }
999         // Continue visiting other program elements.
1000         return false;
1001     }
1002 
visitExpression(const Expression & e)1003     bool visitExpression(const Expression& e) override {
1004         if (e.is<VariableReference>()) {
1005             const VariableReference& v = e.as<VariableReference>();
1006             if (v.variable()->storage() == Variable::Storage::kGlobal) {
1007                 ModifierFlags flags = v.variable()->modifierFlags();
1008                 if (flags & ModifierFlag::kIn) {
1009                     fDeps |= WGSLFunctionDependency::kPipelineInputs;
1010                 }
1011                 if (flags & ModifierFlag::kOut) {
1012                     fDeps |= WGSLFunctionDependency::kPipelineOutputs;
1013                 }
1014             }
1015         } else if (e.is<FunctionCall>()) {
1016             // The current function that we're processing (`fFunction`) inherits the dependencies of
1017             // functions that it makes calls to, because the pipeline stage IO parameters need to be
1018             // passed down as an argument.
1019             const FunctionCall& callee = e.as<FunctionCall>();
1020 
1021             // Don't process a function again if we have already resolved it.
1022             Deps* found = fDependencyMap->find(&callee.function());
1023             if (found) {
1024                 fDeps |= *found;
1025             } else {
1026                 // Store the dependencies that have been discovered for the current function so far.
1027                 // If `callee` directly or indirectly calls the current function, then this value
1028                 // will prevent an infinite recursion.
1029                 fDependencyMap->set(fFunction, fDeps);
1030 
1031                 // Separately traverse the called function's definition and determine its
1032                 // dependencies.
1033                 FunctionDependencyResolver resolver(fProgram, &callee.function(), fDependencyMap);
1034                 Deps calleeDeps = resolver.resolve();
1035 
1036                 // Store the callee's dependencies in the global map to avoid processing
1037                 // the function again for future calls.
1038                 fDependencyMap->set(&callee.function(), calleeDeps);
1039 
1040                 // Add to the current function's dependencies.
1041                 fDeps |= calleeDeps;
1042             }
1043         }
1044         return INHERITED::visitExpression(e);
1045     }
1046 
1047     const Program* const fProgram;
1048     const FunctionDeclaration* const fFunction;
1049     DepsMap* const fDependencyMap;
1050     Deps fDeps = WGSLFunctionDependency::kNone;
1051 
1052     using INHERITED = ProgramVisitor;
1053 };
1054 
resolve_program_requirements(const Program * program)1055 WGSLCodeGenerator::ProgramRequirements resolve_program_requirements(const Program* program) {
1056     WGSLCodeGenerator::ProgramRequirements requirements;
1057 
1058     for (const ProgramElement* e : program->elements()) {
1059         switch (e->kind()) {
1060             case ProgramElement::Kind::kFunction: {
1061                 const FunctionDeclaration& decl = e->as<FunctionDefinition>().declaration();
1062 
1063                 FunctionDependencyResolver resolver(program, &decl, &requirements.fDependencies);
1064                 requirements.fDependencies.set(&decl, resolver.resolve());
1065                 break;
1066             }
1067             case ProgramElement::Kind::kGlobalVar: {
1068                 const GlobalVarDeclaration& decl = e->as<GlobalVarDeclaration>();
1069                 if (decl.varDeclaration().var()->modifierFlags().isPixelLocal()) {
1070                     requirements.fPixelLocalExtension = true;
1071                 }
1072                 break;
1073             }
1074             default:
1075                 break;
1076         }
1077     }
1078 
1079     return requirements;
1080 }
1081 
collect_pipeline_io_vars(const Program * program,TArray<const Variable * > * ioVars,ModifierFlag ioType)1082 void collect_pipeline_io_vars(const Program* program,
1083                               TArray<const Variable*>* ioVars,
1084                               ModifierFlag ioType) {
1085     for (const ProgramElement* e : program->elements()) {
1086         if (e->is<GlobalVarDeclaration>()) {
1087             const Variable* v = e->as<GlobalVarDeclaration>().varDeclaration().var();
1088             if (v->modifierFlags() & ioType) {
1089                 ioVars->push_back(v);
1090             }
1091         } else if (e->is<InterfaceBlock>()) {
1092             const Variable* v = e->as<InterfaceBlock>().var();
1093             if (v->modifierFlags() & ioType) {
1094                 ioVars->push_back(v);
1095             }
1096         }
1097     }
1098 }
1099 
is_in_global_uniforms(const Variable & var)1100 bool is_in_global_uniforms(const Variable& var) {
1101     SkASSERT(var.storage() == VariableStorage::kGlobal);
1102     return  var.modifierFlags().isUniform() &&
1103            !var.type().isOpaque() &&
1104            !var.interfaceBlock();
1105 }
1106 
1107 }  // namespace
1108 
1109 class WGSLCodeGenerator::LValue {
1110 public:
1111     virtual ~LValue() = default;
1112 
1113     // Returns a WGSL expression that loads from the lvalue with no side effects.
1114     // (e.g. `array[index].field`)
1115     virtual std::string load() = 0;
1116 
1117     // Returns a WGSL statement that stores into the lvalue with no side effects.
1118     // (e.g. `array[index].field = the_passed_in_value_string;`)
1119     virtual std::string store(const std::string& value) = 0;
1120 };
1121 
1122 class WGSLCodeGenerator::PointerLValue : public WGSLCodeGenerator::LValue {
1123 public:
1124     // `name` must be a WGSL expression with no side-effects, which we can safely take the address
1125     // of. (e.g. `array[index].field` would be valid, but `array[Func()]` or `vector.x` are not.)
PointerLValue(std::string name)1126     PointerLValue(std::string name) : fName(std::move(name)) {}
1127 
load()1128     std::string load() override {
1129         return fName;
1130     }
1131 
store(const std::string & value)1132     std::string store(const std::string& value) override {
1133         return fName + " = " + value + ";";
1134     }
1135 
1136 private:
1137     std::string fName;
1138 };
1139 
1140 class WGSLCodeGenerator::VectorComponentLValue : public WGSLCodeGenerator::LValue {
1141 public:
1142     // `name` must be a WGSL expression with no side-effects that points to a single component of a
1143     // WGSL vector.
VectorComponentLValue(std::string name)1144     VectorComponentLValue(std::string name) : fName(std::move(name)) {}
1145 
load()1146     std::string load() override {
1147         return fName;
1148     }
1149 
store(const std::string & value)1150     std::string store(const std::string& value) override {
1151         return fName + " = " + value + ";";
1152     }
1153 
1154 private:
1155     std::string fName;
1156 };
1157 
// LValue for a swizzle of a vector (e.g. `name.zx`). Loads read the swizzle directly. Stores
// assign the whole vector: a replacement vector is built from the new value plus the untouched
// components of the original, and a second swizzle puts every component back in its slot.
1158 class WGSLCodeGenerator::SwizzleLValue : public WGSLCodeGenerator::LValue {
1159 public:
1160     // `name` must be a WGSL expression with no side-effects that points to a WGSL vector.
    // `t` is the type of that vector and `c` lists the components selected by the swizzle.
SwizzleLValue(const Context & ctx,std::string name,const Type & t,const ComponentArray & c)1161     SwizzleLValue(const Context& ctx, std::string name, const Type& t, const ComponentArray& c)
1162             : fContext(ctx)
1163             , fName(std::move(name))
1164             , fType(t)
1165             , fComponents(c) {
1166         // If the component array doesn't cover the entire value, we need to create masks for
1167         // writing back into the lvalue. For example, if the type is vec4 and the component array
1168         // holds `zx`, a GLSL assignment would look like:
1169         //     name.zx = new_value;
1170         //
1171         // The equivalent WGSL assignment statement would look like:
1172         //     name = vec4<f32>((new_value), name.yw).yzxw;
1173         //
1174         // This replaces name.zx with new_value.xy, and leaves name.yw at their original values.
1175         // The new value is placed first when the `x` slot is overwritten (as in this example),
1176         // and the original values are placed first otherwise; it might be possible to find
1177         // better arrangements which simplify the assignment overall, but we don't attempt this.
1178         int fullSlotCount = fType.slotCount();
1179         SkASSERT(fullSlotCount <= 4);
1180 
1181         // First, see which components are used.
1182         // The assignment swizzle must not reuse components.
1183         bool used[4] = {};
1184         for (int8_t component : fComponents) {
1185             SkASSERT(!used[component]);
1186             used[component] = true;
1187         }
1188 
1189         // Any untouched components will need to be fetched from the original value.
1190         for (int index = 0; index < fullSlotCount; ++index) {
1191             if (!used[index]) {
1192                 fUntouchedComponents.push_back(index);
1193             }
1194         }
1195 
1196         // The reintegration swizzle needs to move the components back into their proper slots.
        // Entry [slot] holds the position, within the constructed vector, of the value that
        // belongs in that slot; `reintegrateIndex` counts positions in the constructed vector.
1197         fReintegrationSwizzle.resize(fullSlotCount);
1198         int reintegrateIndex = 0;
1199 
1200         // This refills the untouched slots with the original values.
1201         auto refillUntouchedSlots = [&] {
1202             for (int index = 0; index < fullSlotCount; ++index) {
1203                 if (!used[index]) {
1204                     fReintegrationSwizzle[index] = reintegrateIndex++;
1205                 }
1206             }
1207         };
1208 
1209         // This places the new-value components into the proper slots.
1210         auto insertNewValuesIntoSlots = [&] {
1211             for (int index = 0; index < fComponents.size(); ++index) {
1212                 fReintegrationSwizzle[fComponents[index]] = reintegrateIndex++;
1213             }
1214         };
1215 
1216         // When reintegrating the untouched and new values, if the `x` slot is overwritten, we
1217         // reintegrate the new value first. Otherwise, we reintegrate the original value first.
1218         // This increases our odds of getting an identity swizzle for the reintegration.
1219         if (used[0]) {
1220             fReintegrateNewValueFirst = true;
1221             insertNewValuesIntoSlots();
1222             refillUntouchedSlots();
1223         } else {
1224             fReintegrateNewValueFirst = false;
1225             refillUntouchedSlots();
1226             insertNewValuesIntoSlots();
1227         }
1228     }
1229 
    // Returns the swizzled read expression, `<name>.<components>`.
load()1230     std::string load() override {
1231         return fName + "." + Swizzle::MaskString(fComponents);
1232     }
1233 
    // Returns a statement assigning the entire vector. `value` is WGSL expression text for the
    // new contents of the swizzled components; it is always wrapped in parentheses.
store(const std::string & value)1234     std::string store(const std::string& value) override {
1235         // `variable = `
1236         std::string result = fName;
1237         result += " = ";
1238 
1239         if (fUntouchedComponents.empty()) {
1240             // `(new_value);`
1241             result += '(';
1242             result += value;
1243             result += ")";
1244         } else if (fReintegrateNewValueFirst) {
1245             // `vec4<f32>((new_value), `
1246             result += to_wgsl_type(fContext, fType);
1247             result += "((";
1248             result += value;
1249             result += "), ";
1250 
1251             // `variable.yz)`
1252             result += fName;
1253             result += '.';
1254             result += Swizzle::MaskString(fUntouchedComponents);
1255             result += ')';
1256         } else {
1257             // `vec4<f32>(variable.yz`
1258             result += to_wgsl_type(fContext, fType);
1259             result += '(';
1260             result += fName;
1261             result += '.';
1262             result += Swizzle::MaskString(fUntouchedComponents);
1263 
1264             // `, (new_value))`
1265             result += ", (";
1266             result += value;
1267             result += "))";
1268         }
1269 
        // The trailing swizzle is omitted when the components are already in their slots.
1270         if (!Swizzle::IsIdentity(fReintegrationSwizzle)) {
1271             // `.wzyx`
1272             result += '.';
1273             result += Swizzle::MaskString(fReintegrationSwizzle);
1274         }
1275 
1276         return result + ';';
1277     }
1278 
1279 private:
1280     const Context& fContext;
1281     std::string fName;
    // Held by reference; the Type passed to the constructor must outlive this object.
1282     const Type& fType;
    // Components written by the swizzle, in swizzle order.
1283     ComponentArray fComponents;
    // Components of the vector that the swizzle does not write, in ascending order.
1284     ComponentArray fUntouchedComponents;
    // Swizzle applied to the constructed vector to return each value to its slot.
1285     ComponentArray fReintegrationSwizzle;
    // True when the new value precedes the original components in the constructed vector.
1286     bool fReintegrateNewValueFirst = false;
1287 };
1288 
generateCode()1289 bool WGSLCodeGenerator::generateCode() {
1290     // The resources of a WGSL program are structured in the following way:
1291     // - Stage attribute inputs and outputs are bundled inside synthetic structs called
1292     //   VSIn/VSOut/FSIn/FSOut/CSIn.
1293     // - All uniform and storage type resources are declared in global scope.
1294     this->preprocessProgram();
1295 
1296     {
1297         AutoOutputStream outputToHeader(this, &fHeader, &fIndentation);
1298         this->writeEnables();
1299         this->writeStageInputStruct();
1300         this->writeStageOutputStruct();
1301         this->writeUniformsAndBuffers();
1302         this->writeNonBlockUniformsForTests();
1303     }
1304     StringStream body;
1305     {
1306         // Emit the program body.
1307         AutoOutputStream outputToBody(this, &body, &fIndentation);
1308         const FunctionDefinition* mainFunc = nullptr;
1309         for (const ProgramElement* e : fProgram.elements()) {
1310             this->writeProgramElement(*e);
1311 
1312             if (e->is<FunctionDefinition>()) {
1313                 const FunctionDefinition& func = e->as<FunctionDefinition>();
1314                 if (func.declaration().isMain()) {
1315                     mainFunc = &func;
1316                 }
1317             }
1318         }
1319 
1320         // At the bottom of the program body, emit the entrypoint function.
1321         // The entrypoint relies on state that has been collected while we emitted the rest of the
1322         // program, so it's important to do it last to make sure we don't miss anything.
1323         if (mainFunc) {
1324             this->writeEntryPoint(*mainFunc);
1325         }
1326     }
1327 
1328     write_stringstream(fHeader, *fOut);
1329     write_stringstream(body, *fOut);
1330 
1331     this->writeUniformPolyfills();
1332 
1333     return fContext.fErrors->errorCount() == 0;
1334 }
1335 
// Writes the declarations that back the fields recorded in fFieldPolyfillMap:
//  - `_skArrayElement_*`, `_skRow*` and `_skMatrix*` helper structs (each emitted once),
//  - a `var<private>` replacement global for every field that was actually accessed,
//  - `_skInitializePolyfilledUniforms()`, which assigns each of those globals from the
//    corresponding uniform field.
// Nothing is written when the map is empty.
writeUniformPolyfills()1336 void WGSLCodeGenerator::writeUniformPolyfills() {
1337     // If we didn't encounter any uniforms that need polyfilling, there is nothing to do.
1338     if (fFieldPolyfillMap.empty()) {
1339         return;
1340     }
1341 
1342     // We store the list of polyfilled fields as pointers in a hash-map, so the order can be
1343     // inconsistent across runs. For determinism, we sort the polyfilled objects by name here.
1344     TArray<const FieldPolyfillMap::Pair*> orderedFields;
1345     orderedFields.reserve_exact(fFieldPolyfillMap.count());
1346 
1347     fFieldPolyfillMap.foreach([&](const FieldPolyfillMap::Pair& pair) {
1348         orderedFields.push_back(&pair);
1349     });
1350 
1351     std::sort(orderedFields.begin(),
1352               orderedFields.end(),
1353               [](const FieldPolyfillMap::Pair* a, const FieldPolyfillMap::Pair* b) {
1354                   return a->second.fReplacementName < b->second.fReplacementName;
1355               });
1356 
    // These track which helper structs have been emitted so that each is written only once.
1357     THashSet<const Type*> writtenArrayElementPolyfill;
1358     bool writtenUniformMatrixPolyfill[5][5] = {};  // m[column][row] for each matrix type
1359     bool writtenUniformRowPolyfill[5] = {};        // for each matrix row-size
1360     bool anyFieldAccessed = false;
    // First pass: emit helper structs and the replacement globals.
1361     for (const FieldPolyfillMap::Pair* pair : orderedFields) {
1362         const auto& [field, info] = *pair;
1363         const Type* fieldType = field->fType;
1364         const Layout* fieldLayout = &field->fLayout;
1365 
1366         if (info.fIsArray) {
            // From here on, `fieldType` refers to the array's element type.
1367             fieldType = &fieldType->componentType();
1368             if (!writtenArrayElementPolyfill.contains(fieldType)) {
1369                 writtenArrayElementPolyfill.add(fieldType);
1370                 this->write("struct _skArrayElement_");
1371                 this->write(fieldType->abbreviatedName());
1372                 this->writeLine(" {");
1373 
1374                 if (info.fIsMatrix) {
1375                     // Create a struct representing the array containing std140-padded matrices.
1376                     this->write("  e : _skMatrix");
1377                     this->write(std::to_string(fieldType->columns()));
1378                     this->writeLine(std::to_string(fieldType->rows()));
1379                 } else {
1380                     // Create a struct representing the array with extra padding between elements.
1381                     this->write("  @size(16) e : ");
1382                     this->writeLine(to_wgsl_type(fContext, *fieldType, fieldLayout));
1383                 }
1384                 this->writeLine("};");
1385             }
1386         }
1387 
1388         if (info.fIsMatrix) {
1389             // Create structs representing the matrix as an array of vectors, whether or not the
1390             // matrix is ever accessed by the SkSL. (The struct itself is mentioned in the list of
1391             // uniforms.)
1392             int c = fieldType->columns();
1393             int r = fieldType->rows();
1394             if (!writtenUniformRowPolyfill[r]) {
1395                 writtenUniformRowPolyfill[r] = true;
1396 
1397                 this->write("struct _skRow");
1398                 this->write(std::to_string(r));
1399                 this->writeLine(" {");
1400                 this->write("  @size(16) r : vec");
1401                 this->write(std::to_string(r));
1402                 this->write("<");
1403                 this->write(to_wgsl_type(fContext, fieldType->componentType(), fieldLayout));
1404                 this->writeLine(">");
1405                 this->writeLine("};");
1406             }
1407 
1408             if (!writtenUniformMatrixPolyfill[c][r]) {
1409                 writtenUniformMatrixPolyfill[c][r] = true;
1410 
1411                 this->write("struct _skMatrix");
1412                 this->write(std::to_string(c));
1413                 this->write(std::to_string(r));
1414                 this->writeLine(" {");
1415                 this->write("  c : array<_skRow");
1416                 this->write(std::to_string(r));
1417                 this->write(", ");
1418                 this->write(std::to_string(c));
1419                 this->writeLine(">");
1420                 this->writeLine("};");
1421             }
1422         }
1423 
1424         // We create a polyfill variable only if the uniform was actually accessed.
1425         if (!info.fWasAccessed) {
1426             continue;
1427         }
1428         anyFieldAccessed = true;
1429         this->write("var<private> ");
1430         this->write(info.fReplacementName);
1431         this->write(": ");
1432 
        // The global uses the field's full declared type (`field->fType`), not the element type.
        // If the interface block is itself an array, the global is an array of that type.
1433         const Type& interfaceBlockType = info.fInterfaceBlock->var()->type();
1434         if (interfaceBlockType.isArray()) {
1435             this->write("array<");
1436             this->write(to_wgsl_type(fContext, *field->fType, fieldLayout));
1437             this->write(", ");
1438             this->write(std::to_string(interfaceBlockType.columns()));
1439             this->write(">");
1440         } else {
1441             this->write(to_wgsl_type(fContext, *field->fType, fieldLayout));
1442         }
1443         this->writeLine(";");
1444     }
1445 
1446     // If no fields were actually accessed, _skInitializePolyfilledUniforms will not be called and
1447     // we can avoid emitting an empty, dead function.
1448     if (!anyFieldAccessed) {
1449         return;
1450     }
1451 
    // Second pass: emit the function that fills in the replacement globals.
1452     this->writeLine("fn _skInitializePolyfilledUniforms() {");
1453     ++fIndentation;
1454 
1455     for (const FieldPolyfillMap::Pair* pair : orderedFields) {
1456         // Only initialize a polyfill global if the uniform was actually accessed.
1457         const auto& [field, info] = *pair;
1458         if (!info.fWasAccessed) {
1459             continue;
1460         }
1461 
1462         // Synthesize the name of this uniform variable
1463         std::string_view instanceName = info.fInterfaceBlock->instanceName();
1464         const Type& interfaceBlockType = info.fInterfaceBlock->var()->type();
1465         if (instanceName.empty()) {
            // Blocks without an instance name are looked up in fInterfaceBlockNameMap instead.
1466             instanceName = fInterfaceBlockNameMap[&interfaceBlockType.componentType()];
1467         }
1468 
1469         // Initialize the global variable associated with this uniform.
1470         // If the interface block is arrayed, the associated global will be arrayed as well.
1471         int numIBElements = interfaceBlockType.isArray() ? interfaceBlockType.columns() : 1;
1472         for (int ibIdx = 0; ibIdx < numIBElements; ++ibIdx) {
1473             this->write(info.fReplacementName);
1474             if (interfaceBlockType.isArray()) {
1475                 this->write("[");
1476                 this->write(std::to_string(ibIdx));
1477                 this->write("]");
1478             }
1479             this->write(" = ");
1480 
1481             const Type* fieldType = field->fType;
1482             const Layout* fieldLayout = &field->fLayout;
1483 
1484             int numArrayElements;
1485             if (info.fIsArray) {
                // Open an array constructor; its elements are written by the loop below.
1486                 this->write(to_wgsl_type(fContext, *fieldType, fieldLayout));
1487                 this->write("(");
1488                 numArrayElements = fieldType->columns();
1489                 fieldType = &fieldType->componentType();
1490             } else {
1491                 numArrayElements = 1;
1492             }
1493 
1494             auto arraySeparator = String::Separator();
1495             for (int arrayIdx = 0; arrayIdx < numArrayElements; arrayIdx++) {
1496                 this->write(arraySeparator());
1497 
                // Build the path to the uniform value: `instance[ib].field[elem].e`, where the
                // bracketed parts are present only for arrayed blocks / arrayed fields.
1498                 std::string fieldName{instanceName};
1499                 if (interfaceBlockType.isArray()) {
1500                     fieldName += '[';
1501                     fieldName += std::to_string(ibIdx);
1502                     fieldName += ']';
1503                 }
1504                 fieldName += '.';
1505                 fieldName += this->assembleName(field->fName);
1506 
1507                 if (info.fIsArray) {
1508                     fieldName += '[';
1509                     fieldName += std::to_string(arrayIdx);
1510                     fieldName += "].e";
1511                 }
1512 
1513                 if (info.fIsMatrix) {
                    // Rebuild the matrix from the `.c[N].r` vector of each column.
1514                     this->write(to_wgsl_type(fContext, *fieldType, fieldLayout));
1515                     this->write("(");
1516                     int numColumns = fieldType->columns();
1517                     auto matrixSeparator = String::Separator();
1518                     for (int column = 0; column < numColumns; column++) {
1519                         this->write(matrixSeparator());
1520                         this->write(fieldName);
1521                         this->write(".c[");
1522                         this->write(std::to_string(column));
1523                         this->write("].r");
1524                     }
1525                     this->write(")");
1526                 } else {
1527                     this->write(fieldName);
1528                 }
1529             }
1530 
1531             if (info.fIsArray) {
1532                 this->write(")");
1533             }
1534 
1535             this->writeLine(";");
1536         }
1537     }
1538 
1539     --fIndentation;
1540     this->writeLine("}");
1541 }
1542 
1543 
preprocessProgram()1544 void WGSLCodeGenerator::preprocessProgram() {
1545     fRequirements = resolve_program_requirements(&fProgram);
1546     collect_pipeline_io_vars(&fProgram, &fPipelineInputs, ModifierFlag::kIn);
1547     collect_pipeline_io_vars(&fProgram, &fPipelineOutputs, ModifierFlag::kOut);
1548 }
1549 
write(std::string_view s)1550 void WGSLCodeGenerator::write(std::string_view s) {
1551     if (s.empty()) {
1552         return;
1553     }
1554 #if defined(SK_DEBUG) || defined(SKSL_STANDALONE)
1555     if (fAtLineStart) {
1556         for (int i = 0; i < fIndentation; i++) {
1557             fOut->writeText("  ");
1558         }
1559     }
1560 #endif
1561     fOut->writeText(std::string(s).c_str());
1562     fAtLineStart = false;
1563 }
1564 
writeLine(std::string_view s)1565 void WGSLCodeGenerator::writeLine(std::string_view s) {
1566     this->write(s);
1567     fOut->writeText("\n");
1568     fAtLineStart = true;
1569 }
1570 
finishLine()1571 void WGSLCodeGenerator::finishLine() {
1572     if (!fAtLineStart) {
1573         this->writeLine();
1574     }
1575 }
1576 
assembleName(std::string_view name)1577 std::string WGSLCodeGenerator::assembleName(std::string_view name) {
1578     if (name.empty()) {
1579         // WGSL doesn't allow anonymous function parameters.
1580         return "_skAnonymous" + std::to_string(fScratchCount++);
1581     }
1582     // Add `R_` before reserved names to avoid any potential reserved-word conflict.
1583     return (skstd::starts_with(name, "_sk") ||
1584             skstd::starts_with(name, "R_") ||
1585             is_reserved_word(name))
1586                    ? std::string("R_") + std::string(name)
1587                    : std::string(name);
1588 }
1589 
writeVariableDecl(const Layout & layout,const Type & type,std::string_view name,Delimiter delimiter)1590 void WGSLCodeGenerator::writeVariableDecl(const Layout& layout,
1591                                           const Type& type,
1592                                           std::string_view name,
1593                                           Delimiter delimiter) {
1594     this->write(this->assembleName(name));
1595     this->write(": " + to_wgsl_type(fContext, type, &layout));
1596     this->writeLine(delimiter_to_str(delimiter));
1597 }
1598 
writePipelineIODeclaration(const Layout & layout,const Type & type,std::string_view name,Delimiter delimiter)1599 void WGSLCodeGenerator::writePipelineIODeclaration(const Layout& layout,
1600                                                    const Type& type,
1601                                                    std::string_view name,
1602                                                    Delimiter delimiter) {
1603     // In WGSL, an entry-point IO parameter is "one of either a built-in value or assigned a
1604     // location". However, some SkSL declarations, specifically sk_FragColor, can contain both a
1605     // location and a builtin modifier. In addition, WGSL doesn't have a built-in equivalent for
1606     // sk_FragColor as it relies on the user-defined location for a render target.
1607     //
1608     // Instead of special-casing sk_FragColor, we just give higher precedence to a location modifier
1609     // if a declaration happens to both have a location and it's a built-in.
1610     //
1611     // Also see:
1612     // https://www.w3.org/TR/WGSL/#input-output-locations
1613     // https://www.w3.org/TR/WGSL/#attribute-location
1614     // https://www.w3.org/TR/WGSL/#builtin-inputs-outputs
1615     if (layout.fLocation >= 0) {
1616         this->writeUserDefinedIODecl(layout, type, name, delimiter);
1617         return;
1618     }
1619     if (layout.fBuiltin >= 0) {
1620         if (layout.fBuiltin == SK_POINTSIZE_BUILTIN) {
1621             // WebGPU does not support the point-size builtin, but we silently replace it with a
1622             // global variable when it is used, instead of reporting an error.
1623             return;
1624         }
1625         auto builtin = builtin_from_sksl_name(layout.fBuiltin);
1626         if (builtin.has_value()) {
1627             this->writeBuiltinIODecl(type, name, *builtin, delimiter);
1628             return;
1629         }
1630     }
1631     fContext.fErrors->error(Position(), "declaration '" + std::string(name) + "' is not supported");
1632 }
1633 
writeUserDefinedIODecl(const Layout & layout,const Type & type,std::string_view name,Delimiter delimiter)1634 void WGSLCodeGenerator::writeUserDefinedIODecl(const Layout& layout,
1635                                                const Type& type,
1636                                                std::string_view name,
1637                                                Delimiter delimiter) {
1638     this->write("@location(" + std::to_string(layout.fLocation) + ") ");
1639 
1640     // @blend_src is only allowed when doing dual-source blending, and only on color attachment 0.
1641     if (layout.fLocation == 0 && layout.fIndex >= 0 && fProgram.fInterface.fOutputSecondaryColor) {
1642         this->write("@blend_src(" + std::to_string(layout.fIndex) + ") ");
1643     }
1644 
1645     // "User-defined IO of scalar or vector integer type must always be specified as
1646     // @interpolate(flat)" (see https://www.w3.org/TR/WGSL/#interpolation)
1647     if (type.isInteger() || (type.isVector() && type.componentType().isInteger())) {
1648         this->write("@interpolate(flat) ");
1649     }
1650 
1651     this->writeVariableDecl(layout, type, name, delimiter);
1652 }
1653 
writeBuiltinIODecl(const Type & type,std::string_view name,Builtin builtin,Delimiter delimiter)1654 void WGSLCodeGenerator::writeBuiltinIODecl(const Type& type,
1655                                            std::string_view name,
1656                                            Builtin builtin,
1657                                            Delimiter delimiter) {
1658     this->write(wgsl_builtin_name(builtin));
1659     this->write(" ");
1660     this->write(this->assembleName(name));
1661     this->write(": ");
1662     this->write(wgsl_builtin_type(builtin));
1663     this->writeLine(delimiter_to_str(delimiter));
1664 }
1665 
writeFunction(const FunctionDefinition & f)1666 void WGSLCodeGenerator::writeFunction(const FunctionDefinition& f) {
1667     const FunctionDeclaration& decl = f.declaration();
1668     fHasUnconditionalReturn = false;
1669     fConditionalScopeDepth = 0;
1670 
1671     SkASSERT(!fAtFunctionScope);
1672     fAtFunctionScope = true;
1673 
1674     // WGSL parameters are immutable and are considered as taking no storage, but SkSL parameters
1675     // are real variables. To work around this, we make var-based copies of parameters. It's
1676     // wasteful to make a copy of every single parameter--even if the compiler can eventually
1677     // optimize them all away, that takes time and generates bloated code. So, we only make
1678     // parameter copies if the variable is actually written-to.
1679     STArray<32, bool> paramNeedsDedicatedStorage;
1680     paramNeedsDedicatedStorage.push_back_n(decl.parameters().size(), true);
1681 
1682     for (size_t index = 0; index < decl.parameters().size(); ++index) {
1683         const Variable& param = *decl.parameters()[index];
1684         if (param.type().isOpaque() || param.name().empty()) {
1685             // Opaque-typed or anonymous parameters don't need dedicated storage.
1686             paramNeedsDedicatedStorage[index] = false;
1687             continue;
1688         }
1689 
1690         const ProgramUsage::VariableCounts counts = fProgram.fUsage->get(param);
1691         if ((param.modifierFlags() & ModifierFlag::kOut) || counts.fWrite == 0) {
1692             // Variables which are never written-to don't need dedicated storage.
1693             // Out-parameters are passed as pointers; the pointer itself is never modified, so
1694             // it doesn't need dedicated storage.
1695             paramNeedsDedicatedStorage[index] = false;
1696         }
1697     }
1698 
1699     this->writeFunctionDeclaration(decl, paramNeedsDedicatedStorage);
1700     this->writeLine(" {");
1701     ++fIndentation;
1702 
1703     // The parameters were given generic names like `_skParam1`, because WGSL parameters don't have
1704     // storage and are immutable. If mutability is required, we create variables here; otherwise, we
1705     // create properly-named `let` aliases.
1706     for (size_t index = 0; index < decl.parameters().size(); ++index) {
1707         if (paramNeedsDedicatedStorage[index]) {
1708             const Variable& param = *decl.parameters()[index];
1709             this->write("var ");
1710             this->write(this->assembleName(param.mangledName()));
1711             this->write(" = _skParam");
1712             this->write(std::to_string(index));
1713             this->writeLine(";");
1714         }
1715     }
1716 
1717     this->writeBlock(f.body()->as<Block>());
1718 
1719     // If fConditionalScopeDepth isn't zero, we have an unbalanced +1 or -1 when updating the depth.
1720     SkASSERT(fConditionalScopeDepth == 0);
1721     if (!fHasUnconditionalReturn && !f.declaration().returnType().isVoid()) {
1722         this->write("return ");
1723         this->write(to_wgsl_type(fContext, f.declaration().returnType()));
1724         this->writeLine("();");
1725     }
1726 
1727     --fIndentation;
1728     this->writeLine("}");
1729 
1730     SkASSERT(fAtFunctionScope);
1731     fAtFunctionScope = false;
1732 }
1733 
writeFunctionDeclaration(const FunctionDeclaration & decl,SkSpan<const bool> paramNeedsDedicatedStorage)1734 void WGSLCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& decl,
1735                                                  SkSpan<const bool> paramNeedsDedicatedStorage) {
1736     this->write("fn ");
1737     if (decl.isMain()) {
1738         this->write("_skslMain(");
1739     } else {
1740         this->write(this->assembleName(decl.mangledName()));
1741         this->write("(");
1742     }
1743     auto separator = SkSL::String::Separator();
1744     if (this->writeFunctionDependencyParams(decl)) {
1745         separator();  // update the separator as parameters have been written
1746     }
1747     for (size_t index = 0; index < decl.parameters().size(); ++index) {
1748         this->write(separator());
1749 
1750         const Variable& param = *decl.parameters()[index];
1751         if (param.type().isOpaque()) {
1752             SkASSERT(!paramNeedsDedicatedStorage[index]);
1753             if (param.type().isSampler()) {
1754                 // Create parameters for both the texture and associated sampler.
1755                 this->write(param.name());
1756                 this->write(kTextureSuffix);
1757                 this->write(": texture_2d<f32>, ");
1758                 this->write(param.name());
1759                 this->write(kSamplerSuffix);
1760                 this->write(": sampler");
1761             } else {
1762                 // Create a parameter for the opaque object.
1763                 this->write(param.name());
1764                 this->write(": ");
1765                 this->write(to_wgsl_type(fContext, param.type(), &param.layout()));
1766             }
1767         } else {
1768             if (paramNeedsDedicatedStorage[index] || param.name().empty()) {
1769                 // Create an unnamed parameter. If the parameter needs dedicated storage, it will
1770                 // later be assigned a `var` in the function body. (If it's anonymous, a var isn't
1771                 // needed.)
1772                 this->write("_skParam");
1773                 this->write(std::to_string(index));
1774             } else {
1775                 // Use the name directly from the SkSL program.
1776                 this->write(this->assembleName(param.name()));
1777             }
1778             this->write(": ");
1779             if (param.modifierFlags() & ModifierFlag::kOut) {
1780                 // Declare an "out" function parameter as a pointer.
1781                 this->write(to_ptr_type(fContext, param.type(), &param.layout()));
1782             } else {
1783                 this->write(to_wgsl_type(fContext, param.type(), &param.layout()));
1784             }
1785         }
1786     }
1787     this->write(")");
1788     if (!decl.returnType().isVoid()) {
1789         this->write(" -> ");
1790         this->write(to_wgsl_type(fContext, decl.returnType()));
1791     }
1792 }
1793 
writeEntryPoint(const FunctionDefinition & main)1794 void WGSLCodeGenerator::writeEntryPoint(const FunctionDefinition& main) {
1795     SkASSERT(main.declaration().isMain());
1796     const ProgramKind programKind = fProgram.fConfig->fKind;
1797 
1798 #if defined(SKSL_STANDALONE)
1799     if (ProgramConfig::IsRuntimeShader(programKind)) {
1800         // Synthesize a basic entrypoint which just calls straight through to main.
1801         // This is only used by skslc and just needs to pass the WGSL validator; Skia won't ever
1802         // emit functions like this.
1803         this->writeLine("@fragment fn main(@location(0) _coords: vec2<f32>) -> "
1804                                      "@location(0) vec4<f32> {");
1805         ++fIndentation;
1806         this->writeLine("return _skslMain(_coords);");
1807         --fIndentation;
1808         this->writeLine("}");
1809         return;
1810     }
1811 #endif
1812 
1813     // The input and output parameters for a vertex/fragment stage entry point function have the
1814     // FSIn/FSOut/VSIn/VSOut/CSIn struct types that have been synthesized in generateCode(). An
1815     // entrypoint always has a predictable signature and acts as a trampoline to the user-defined
1816     // main function.
1817     if (ProgramConfig::IsVertex(programKind)) {
1818         this->write("@vertex");
1819     } else if (ProgramConfig::IsFragment(programKind)) {
1820         this->write("@fragment");
1821     } else if (ProgramConfig::IsCompute(programKind)) {
1822         this->write("@compute @workgroup_size(");
1823         this->write(std::to_string(fLocalSizeX));
1824         this->write(", ");
1825         this->write(std::to_string(fLocalSizeY));
1826         this->write(", ");
1827         this->write(std::to_string(fLocalSizeZ));
1828         this->write(")");
1829     } else {
1830         fContext.fErrors->error(Position(), "program kind not supported");
1831         return;
1832     }
1833 
1834     this->write(" fn main(");
1835     // The stage input struct is a parameter passed to main().
1836     if (this->needsStageInputStruct()) {
1837         this->write("_stageIn: ");
1838         this->write(pipeline_struct_prefix(programKind));
1839         this->write("In");
1840     }
1841     // The stage output struct is returned from main().
1842     if (this->needsStageOutputStruct()) {
1843         this->write(") -> ");
1844         this->write(pipeline_struct_prefix(programKind));
1845         this->writeLine("Out {");
1846     } else {
1847         this->writeLine(") {");
1848     }
1849     // Initialize polyfilled matrix uniforms if any were used.
1850     fIndentation++;
1851     for (const auto& [field, info] : fFieldPolyfillMap) {
1852         if (info.fWasAccessed) {
1853             this->writeLine("_skInitializePolyfilledUniforms();");
1854             break;
1855         }
1856     }
1857     // Declare the stage output struct.
1858     if (this->needsStageOutputStruct()) {
1859         this->write("var _stageOut: ");
1860         this->write(pipeline_struct_prefix(programKind));
1861         this->writeLine("Out;");
1862     }
1863 
1864 #if defined(SKSL_STANDALONE)
1865     // We are compiling a Runtime Effect as a fragment shader, for testing purposes. We assign the
1866     // result from _skslMain into sk_FragColor if the user-defined main returns a color. This
1867     // doesn't actually matter, but it is more indicative of what a real program would do.
1868     // `addImplicitFragColorWrite` from Transform::FindAndDeclareBuiltinVariables has already
1869     // injected sk_FragColor into our stage outputs even if it wasn't explicitly referenced.
1870     if (ProgramConfig::IsFragment(programKind)) {
1871         if (main.declaration().returnType().matches(*fContext.fTypes.fHalf4)) {
1872             this->write("_stageOut.sk_FragColor = ");
1873         }
1874     }
1875 #endif
1876 
1877     // Generate a function call to the user-defined main.
1878     this->write("_skslMain(");
1879     auto separator = SkSL::String::Separator();
1880     WGSLFunctionDependencies* deps = fRequirements.fDependencies.find(&main.declaration());
1881     if (deps) {
1882         if (*deps & WGSLFunctionDependency::kPipelineInputs) {
1883             this->write(separator());
1884             this->write("_stageIn");
1885         }
1886         if (*deps & WGSLFunctionDependency::kPipelineOutputs) {
1887             this->write(separator());
1888             this->write("&_stageOut");
1889         }
1890     }
1891 
1892 #if defined(SKSL_STANDALONE)
1893     if (const Variable* v = main.declaration().getMainCoordsParameter()) {
1894         // We are compiling a Runtime Effect as a fragment shader, for testing purposes.
1895         // We need to synthesize a coordinates parameter, but the coordinates don't matter.
1896         SkASSERT(ProgramConfig::IsFragment(programKind));
1897         const Type& type = v->type();
1898         if (!type.matches(*fContext.fTypes.fFloat2)) {
1899             fContext.fErrors->error(main.fPosition, "main function has unsupported parameter: " +
1900                                                     type.description());
1901             return;
1902         }
1903         this->write(separator());
1904         this->write("/*fragcoord*/ vec2<f32>()");
1905     }
1906 #endif
1907 
1908     this->writeLine(");");
1909 
1910     if (this->needsStageOutputStruct()) {
1911         // Return the stage output struct.
1912         this->writeLine("return _stageOut;");
1913     }
1914 
1915     fIndentation--;
1916     this->writeLine("}");
1917 }
1918 
writeStatement(const Statement & s)1919 void WGSLCodeGenerator::writeStatement(const Statement& s) {
1920     switch (s.kind()) {
1921         case Statement::Kind::kBlock:
1922             this->writeBlock(s.as<Block>());
1923             break;
1924         case Statement::Kind::kBreak:
1925             this->writeLine("break;");
1926             break;
1927         case Statement::Kind::kContinue:
1928             this->writeLine("continue;");
1929             break;
1930         case Statement::Kind::kDiscard:
1931             this->writeLine("discard;");
1932             break;
1933         case Statement::Kind::kDo:
1934             this->writeDoStatement(s.as<DoStatement>());
1935             break;
1936         case Statement::Kind::kExpression:
1937             this->writeExpressionStatement(*s.as<ExpressionStatement>().expression());
1938             break;
1939         case Statement::Kind::kFor:
1940             this->writeForStatement(s.as<ForStatement>());
1941             break;
1942         case Statement::Kind::kIf:
1943             this->writeIfStatement(s.as<IfStatement>());
1944             break;
1945         case Statement::Kind::kNop:
1946             this->writeLine(";");
1947             break;
1948         case Statement::Kind::kReturn:
1949             this->writeReturnStatement(s.as<ReturnStatement>());
1950             break;
1951         case Statement::Kind::kSwitch:
1952             this->writeSwitchStatement(s.as<SwitchStatement>());
1953             break;
1954         case Statement::Kind::kSwitchCase:
1955             SkDEBUGFAIL("switch-case statements should only be present inside a switch");
1956             break;
1957         case Statement::Kind::kVarDeclaration:
1958             this->writeVarDeclaration(s.as<VarDeclaration>());
1959             break;
1960     }
1961 }
1962 
writeStatements(const StatementArray & statements)1963 void WGSLCodeGenerator::writeStatements(const StatementArray& statements) {
1964     for (const auto& s : statements) {
1965         if (!s->isEmpty()) {
1966             this->writeStatement(*s);
1967             this->finishLine();
1968         }
1969     }
1970 }
1971 
writeBlock(const Block & b)1972 void WGSLCodeGenerator::writeBlock(const Block& b) {
1973     // Write scope markers if this block is a scope, or if the block is empty (since we need to emit
1974     // something here to make the code valid).
1975     bool isScope = b.isScope() || b.isEmpty();
1976     if (isScope) {
1977         this->writeLine("{");
1978         fIndentation++;
1979     }
1980     this->writeStatements(b.children());
1981     if (isScope) {
1982         fIndentation--;
1983         this->writeLine("}");
1984     }
1985 }
1986 
writeExpressionStatement(const Expression & expr)1987 void WGSLCodeGenerator::writeExpressionStatement(const Expression& expr) {
1988     // Any expression-related side effects must be emitted as separate statements when
1989     // `assembleExpression` is called.
1990     // The final result of the expression will be a variable, let-reference, or an expression with
1991     // no side effects (`foo + bar`). Discarding this result is safe, as the program never uses it.
1992     (void)this->assembleExpression(expr, Precedence::kStatement);
1993 }
1994 
writeDoStatement(const DoStatement & s)1995 void WGSLCodeGenerator::writeDoStatement(const DoStatement& s) {
1996     // Generate a loop structure like this:
1997     //   loop {
1998     //       body-statement;
1999     //       continuing {
2000     //           break if inverted-test-expression;
2001     //       }
2002     //   }
2003 
2004     ++fConditionalScopeDepth;
2005 
2006     std::unique_ptr<Expression> invertedTestExpr = PrefixExpression::Make(
2007             fContext, s.test()->fPosition, OperatorKind::LOGICALNOT, s.test()->clone());
2008 
2009     this->writeLine("loop {");
2010     fIndentation++;
2011     this->writeStatement(*s.statement());
2012     this->finishLine();
2013 
2014     this->writeLine("continuing {");
2015     fIndentation++;
2016     std::string breakIfExpr = this->assembleExpression(*invertedTestExpr, Precedence::kExpression);
2017     this->write("break if ");
2018     this->write(breakIfExpr);
2019     this->writeLine(";");
2020     fIndentation--;
2021     this->writeLine("}");
2022     fIndentation--;
2023     this->writeLine("}");
2024 
2025     --fConditionalScopeDepth;
2026 }
2027 
writeForStatement(const ForStatement & s)2028 void WGSLCodeGenerator::writeForStatement(const ForStatement& s) {
2029     // Generate a loop structure wrapped in an extra scope:
2030     //   {
2031     //     initializer-statement;
2032     //     loop;
2033     //   }
2034     // The outer scope is necessary to prevent the initializer-variable from leaking out into the
2035     // rest of the code. In practice, the generated code actually tends to be scoped even more
2036     // deeply, as the body-statement almost always contributes an extra block.
2037 
2038     ++fConditionalScopeDepth;
2039 
2040     if (s.initializer()) {
2041         this->writeLine("{");
2042         fIndentation++;
2043         this->writeStatement(*s.initializer());
2044         this->writeLine();
2045     }
2046 
2047     this->writeLine("loop {");
2048     fIndentation++;
2049 
2050     if (s.unrollInfo()) {
2051         if (s.unrollInfo()->fCount <= 0) {
2052             // Loops which are known to never execute don't need to be emitted at all.
2053             // (However, the front end should have already replaced this loop with a Nop.)
2054         } else {
2055             // Loops which are known to execute at least once can use this form:
2056             //
2057             //     loop {
2058             //         body-statement;
2059             //         continuing {
2060             //             next-expression;
2061             //             break if inverted-test-expression;
2062             //         }
2063             //     }
2064 
2065             this->writeStatement(*s.statement());
2066             this->finishLine();
2067             this->writeLine("continuing {");
2068             ++fIndentation;
2069 
2070             if (s.next()) {
2071                 this->writeExpressionStatement(*s.next());
2072                 this->finishLine();
2073             }
2074 
2075             if (s.test()) {
2076                 std::unique_ptr<Expression> invertedTestExpr = PrefixExpression::Make(
2077                         fContext, s.test()->fPosition, OperatorKind::LOGICALNOT, s.test()->clone());
2078 
2079                 std::string breakIfExpr =
2080                         this->assembleExpression(*invertedTestExpr, Precedence::kExpression);
2081                 this->write("break if ");
2082                 this->write(breakIfExpr);
2083                 this->writeLine(";");
2084             }
2085 
2086             --fIndentation;
2087             this->writeLine("}");
2088         }
2089     } else {
2090         // Loops without a known execution count are emitted in this form:
2091         //
2092         //     loop {
2093         //         if test-expression {
2094         //             body-statement;
2095         //         } else {
2096         //             break;
2097         //         }
2098         //         continuing {
2099         //             next-expression;
2100         //         }
2101         //     }
2102 
2103         if (s.test()) {
2104             std::string testExpr = this->assembleExpression(*s.test(), Precedence::kExpression);
2105             this->write("if ");
2106             this->write(testExpr);
2107             this->writeLine(" {");
2108 
2109             fIndentation++;
2110             this->writeStatement(*s.statement());
2111             this->finishLine();
2112             fIndentation--;
2113 
2114             this->writeLine("} else {");
2115 
2116             fIndentation++;
2117             this->writeLine("break;");
2118             fIndentation--;
2119 
2120             this->writeLine("}");
2121         }
2122         else {
2123             this->writeStatement(*s.statement());
2124             this->finishLine();
2125         }
2126 
2127         if (s.next()) {
2128             this->writeLine("continuing {");
2129             fIndentation++;
2130             this->writeExpressionStatement(*s.next());
2131             this->finishLine();
2132             fIndentation--;
2133             this->writeLine("}");
2134         }
2135     }
2136 
2137     // This matches an open-brace at the top of the loop.
2138     fIndentation--;
2139     this->writeLine("}");
2140 
2141     if (s.initializer()) {
2142         // This matches an open-brace before the initializer-statement.
2143         fIndentation--;
2144         this->writeLine("}");
2145     }
2146 
2147     --fConditionalScopeDepth;
2148 }
2149 
writeIfStatement(const IfStatement & s)2150 void WGSLCodeGenerator::writeIfStatement(const IfStatement& s) {
2151     ++fConditionalScopeDepth;
2152 
2153     std::string testExpr = this->assembleExpression(*s.test(), Precedence::kExpression);
2154     this->write("if ");
2155     this->write(testExpr);
2156     this->writeLine(" {");
2157     fIndentation++;
2158     this->writeStatement(*s.ifTrue());
2159     this->finishLine();
2160     fIndentation--;
2161     if (s.ifFalse()) {
2162         this->writeLine("} else {");
2163         fIndentation++;
2164         this->writeStatement(*s.ifFalse());
2165         this->finishLine();
2166         fIndentation--;
2167     }
2168     this->writeLine("}");
2169 
2170     --fConditionalScopeDepth;
2171 }
2172 
writeReturnStatement(const ReturnStatement & s)2173 void WGSLCodeGenerator::writeReturnStatement(const ReturnStatement& s) {
2174     fHasUnconditionalReturn |= (fConditionalScopeDepth == 0);
2175 
2176     std::string expr = s.expression()
2177                                ? this->assembleExpression(*s.expression(), Precedence::kExpression)
2178                                : std::string();
2179     this->write("return ");
2180     this->write(expr);
2181     this->write(";");
2182 }
2183 
writeSwitchCaseList(SkSpan<const SwitchCase * const> cases)2184 void WGSLCodeGenerator::writeSwitchCaseList(SkSpan<const SwitchCase* const> cases) {
2185     auto separator = SkSL::String::Separator();
2186     for (const SwitchCase* const sc : cases) {
2187         this->write(separator());
2188         if (sc->isDefault()) {
2189             this->write("default");
2190         } else {
2191             this->write(std::to_string(sc->value()));
2192         }
2193     }
2194 }
2195 
writeSwitchCases(SkSpan<const SwitchCase * const> cases)2196 void WGSLCodeGenerator::writeSwitchCases(SkSpan<const SwitchCase* const> cases) {
2197     if (!cases.empty()) {
2198         // Only the last switch-case should have a non-empty statement.
2199         SkASSERT(std::all_of(cases.begin(), std::prev(cases.end()), [](const SwitchCase* sc) {
2200             return sc->statement()->isEmpty();
2201         }));
2202 
2203         // Emit the cases in a comma-separated list.
2204         this->write("case ");
2205         this->writeSwitchCaseList(cases);
2206         this->writeLine(" {");
2207         ++fIndentation;
2208 
2209         // Emit the switch-case body.
2210         this->writeStatement(*cases.back()->statement());
2211         this->finishLine();
2212 
2213         --fIndentation;
2214         this->writeLine("}");
2215     }
2216 }
2217 
writeEmulatedSwitchFallthroughCases(SkSpan<const SwitchCase * const> cases,std::string_view switchValue)2218 void WGSLCodeGenerator::writeEmulatedSwitchFallthroughCases(SkSpan<const SwitchCase* const> cases,
2219                                                             std::string_view switchValue) {
2220     // There's no need for fallthrough handling unless we actually have multiple case blocks.
2221     if (cases.size() < 2) {
2222         this->writeSwitchCases(cases);
2223         return;
2224     }
2225 
2226     // Match against the entire case group.
2227     this->write("case ");
2228     this->writeSwitchCaseList(cases);
2229     this->writeLine(" {");
2230     ++fIndentation;
2231 
2232     std::string fallthroughVar = this->writeScratchVar(*fContext.fTypes.fBool, "false");
2233     const size_t secondToLastCaseIndex = cases.size() - 2;
2234     const size_t lastCaseIndex = cases.size() - 1;
2235 
2236     for (size_t index = 0; index < cases.size(); ++index) {
2237         const SwitchCase& sc = *cases[index];
2238         if (index < lastCaseIndex) {
2239             // The default case must come last in SkSL, and this case isn't the last one, so it
2240             // can't possibly be the default.
2241             SkASSERT(!sc.isDefault());
2242 
2243             this->write("if ");
2244             if (index > 0) {
2245                 this->write(fallthroughVar);
2246                 this->write(" || ");
2247             }
2248             this->write(switchValue);
2249             this->write(" == ");
2250             this->write(std::to_string(sc.value()));
2251             this->writeLine(" {");
2252             fIndentation++;
2253 
2254             // We write the entire case-block statement here, and then set `switchFallthrough`
2255             // to 1. If the case-block had a break statement in it, we break out of the outer
2256             // for-loop entirely, meaning the `switchFallthrough` assignment never occurs, nor
2257             // does any code after it inside the switch. We've forbidden `continue` statements
2258             // inside switch case-blocks entirely, so we don't need to consider their effect on
2259             // control flow; see the Finalizer in FunctionDefinition::Convert.
2260             this->writeStatement(*sc.statement());
2261             this->finishLine();
2262 
2263             if (index < secondToLastCaseIndex) {
2264                 // Set a variable to indicate falling through to the next block. The very last
2265                 // case-block is reached by process of elimination and doesn't need this
2266                 // variable, so we don't actually need to set it if we are on the second-to-last
2267                 // case block.
2268                 this->write(fallthroughVar);
2269                 this->write(" = true;  ");
2270             }
2271             this->writeLine("// fallthrough");
2272 
2273             fIndentation--;
2274             this->writeLine("}");
2275         } else {
2276             // This is the final case. Since it's always last, we can just dump in the code.
2277             // (If we didn't match any of the other values, we must have matched this one by
2278             // process of elimination. If we did match one of the other values, we either hit a
2279             // `break` statement earlier--and won't get this far--or we're falling through.)
2280             this->writeStatement(*sc.statement());
2281             this->finishLine();
2282         }
2283     }
2284 
2285     --fIndentation;
2286     this->writeLine("}");
2287 }
2288 
writeSwitchStatement(const SwitchStatement & s)2289 void WGSLCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
2290     // WGSL supports the `switch` statement in a limited capacity. A default case must always be
2291     // specified. Each switch-case must be scoped inside braces. Fallthrough is not supported; a
2292     // trailing break is implied at the end of each switch-case block. (Explicit breaks are also
2293     // allowed.)  One minor improvement over a traditional switch is that switch-cases take a list
2294     // of values to match, instead of a single value:
2295     //   case 1, 2       { foo(); }
2296     //   case 3, default { bar(); }
2297     //
2298     // We will use the native WGSL switch statement for any switch-cases in the SkSL which can be
2299     // made to conform to these limitations. The remaining cases which cannot conform will be
2300     // emulated with if-else blocks (similar to our GLSL ES2 switch-statement emulation path). This
2301     // should give us good performance in the common case, since most switches naturally conform.
2302 
2303     // First, let's emit the switch itself.
2304     std::string valueExpr = this->writeNontrivialScratchLet(*s.value(), Precedence::kExpression);
2305     this->write("switch ");
2306     this->write(valueExpr);
2307     this->writeLine(" {");
2308     ++fIndentation;
2309 
2310     // Now let's go through the switch-cases, and emit the ones that don't fall through.
2311     TArray<const SwitchCase*> nativeCases;
2312     TArray<const SwitchCase*> fallthroughCases;
2313     bool previousCaseFellThrough = false;
2314     bool foundNativeDefault = false;
2315     [[maybe_unused]] bool foundFallthroughDefault = false;
2316 
2317     const int lastSwitchCaseIdx = s.cases().size() - 1;
2318     for (int index = 0; index <= lastSwitchCaseIdx; ++index) {
2319         const SwitchCase& sc = s.cases()[index]->as<SwitchCase>();
2320 
2321         if (sc.statement()->isEmpty()) {
2322             // This is a `case X:` that immediately falls through to the next case.
2323             // If we aren't already falling through, we can handle this via a comma-separated list.
2324             if (previousCaseFellThrough) {
2325                 fallthroughCases.push_back(&sc);
2326                 foundFallthroughDefault |= sc.isDefault();
2327             } else {
2328                 nativeCases.push_back(&sc);
2329                 foundNativeDefault |= sc.isDefault();
2330             }
2331             continue;
2332         }
2333 
2334         if (index == lastSwitchCaseIdx || Analysis::SwitchCaseContainsUnconditionalExit(sc)) {
2335             // This is a `case X:` that never falls through.
2336             if (previousCaseFellThrough) {
2337                 // Because the previous cases fell through, we can't use a native switch-case here.
2338                 fallthroughCases.push_back(&sc);
2339                 foundFallthroughDefault |= sc.isDefault();
2340 
2341                 this->writeEmulatedSwitchFallthroughCases(fallthroughCases, valueExpr);
2342                 fallthroughCases.clear();
2343 
2344                 // Fortunately, we're no longer falling through blocks, so we might be able to use a
2345                 // native switch-case list again.
2346                 previousCaseFellThrough = false;
2347             } else {
2348                 // Emit a native switch-case block with a comma-separated case list.
2349                 nativeCases.push_back(&sc);
2350                 foundNativeDefault |= sc.isDefault();
2351 
2352                 this->writeSwitchCases(nativeCases);
2353                 nativeCases.clear();
2354             }
2355             continue;
2356         }
2357 
2358         // This case falls through, so it will need to be handled via emulation.
2359         // If we have put together a collection of "native" cases (cases that fall through with no
2360         // actual case-body), we will need to slide them over into the fallthrough-case list.
2361         fallthroughCases.push_back_n(nativeCases.size(), nativeCases.data());
2362         nativeCases.clear();
2363 
2364         fallthroughCases.push_back(&sc);
2365         foundFallthroughDefault |= sc.isDefault();
2366         previousCaseFellThrough = true;
2367     }
2368 
2369     // Finish out the remaining switch-cases.
2370     this->writeSwitchCases(nativeCases);
2371     nativeCases.clear();
2372 
2373     this->writeEmulatedSwitchFallthroughCases(fallthroughCases, valueExpr);
2374     fallthroughCases.clear();
2375 
2376     // WGSL requires a default case.
2377     if (!foundNativeDefault && !foundFallthroughDefault) {
2378         this->writeLine("case default {}");
2379     }
2380 
2381     --fIndentation;
2382     this->writeLine("}");
2383 }
2384 
writeVarDeclaration(const VarDeclaration & varDecl)2385 void WGSLCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl) {
2386     std::string initialValue =
2387             varDecl.value() ? this->assembleExpression(*varDecl.value(), Precedence::kAssignment)
2388                             : std::string();
2389 
2390     if (varDecl.var()->modifierFlags().isConst()) {
2391         // Use `const` at global scope, or if the value is a compile-time constant.
2392         SkASSERTF(varDecl.value(), "a constant variable must specify a value");
2393         this->write((!fAtFunctionScope || Analysis::IsCompileTimeConstant(*varDecl.value()))
2394                             ? "const "
2395                             : "let ");
2396     } else {
2397         this->write("var ");
2398     }
2399     this->write(this->assembleName(varDecl.var()->mangledName()));
2400     this->write(": ");
2401     this->write(to_wgsl_type(fContext, varDecl.var()->type(), &varDecl.var()->layout()));
2402 
2403     if (varDecl.value()) {
2404         this->write(" = ");
2405         this->write(initialValue);
2406     }
2407 
2408     this->write(";");
2409 }
2410 
makeLValue(const Expression & e)2411 std::unique_ptr<WGSLCodeGenerator::LValue> WGSLCodeGenerator::makeLValue(const Expression& e) {
2412     if (e.is<VariableReference>()) {
2413         return std::make_unique<PointerLValue>(
2414                 this->variableReferenceNameForLValue(e.as<VariableReference>()));
2415     }
2416     if (e.is<FieldAccess>()) {
2417         return std::make_unique<PointerLValue>(this->assembleFieldAccess(e.as<FieldAccess>()));
2418     }
2419     if (e.is<IndexExpression>()) {
2420         const IndexExpression& idx = e.as<IndexExpression>();
2421         if (idx.base()->type().isVector()) {
2422             // Rewrite indexed-swizzle accesses like `myVec.zyx[i]` into an index onto `myVec`.
2423             if (std::unique_ptr<Expression> rewrite =
2424                         Transform::RewriteIndexedSwizzle(fContext, idx)) {
2425                 return std::make_unique<VectorComponentLValue>(
2426                         this->assembleExpression(*rewrite, Precedence::kAssignment));
2427             } else {
2428                 return std::make_unique<VectorComponentLValue>(this->assembleIndexExpression(idx));
2429             }
2430         } else {
2431             return std::make_unique<PointerLValue>(this->assembleIndexExpression(idx));
2432         }
2433     }
2434     if (e.is<Swizzle>()) {
2435         const Swizzle& swizzle = e.as<Swizzle>();
2436         if (swizzle.components().size() == 1) {
2437             return std::make_unique<VectorComponentLValue>(this->assembleSwizzle(swizzle));
2438         } else {
2439             return std::make_unique<SwizzleLValue>(
2440                     fContext,
2441                     this->assembleExpression(*swizzle.base(), Precedence::kAssignment),
2442                     swizzle.base()->type(),
2443                     swizzle.components());
2444         }
2445     }
2446 
2447     fContext.fErrors->error(e.fPosition, "unsupported lvalue type");
2448     return nullptr;
2449 }
2450 
assembleExpression(const Expression & e,Precedence parentPrecedence)2451 std::string WGSLCodeGenerator::assembleExpression(const Expression& e,
2452                                                   Precedence parentPrecedence) {
2453     switch (e.kind()) {
2454         case Expression::Kind::kBinary:
2455             return this->assembleBinaryExpression(e.as<BinaryExpression>(), parentPrecedence);
2456 
2457         case Expression::Kind::kConstructorCompound:
2458             return this->assembleConstructorCompound(e.as<ConstructorCompound>());
2459 
2460         case Expression::Kind::kConstructorArrayCast:
2461             // This is a no-op, since WGSL 1.0 doesn't have any concept of precision qualifiers.
2462             // When we add support for f16, this will need to copy the array contents.
2463             return this->assembleExpression(*e.as<ConstructorArrayCast>().argument(),
2464                                             parentPrecedence);
2465 
2466         case Expression::Kind::kConstructorArray:
2467         case Expression::Kind::kConstructorCompoundCast:
2468         case Expression::Kind::kConstructorScalarCast:
2469         case Expression::Kind::kConstructorSplat:
2470         case Expression::Kind::kConstructorStruct:
2471             return this->assembleAnyConstructor(e.asAnyConstructor());
2472 
2473         case Expression::Kind::kConstructorDiagonalMatrix:
2474             return this->assembleConstructorDiagonalMatrix(e.as<ConstructorDiagonalMatrix>());
2475 
2476         case Expression::Kind::kConstructorMatrixResize:
2477             return this->assembleConstructorMatrixResize(e.as<ConstructorMatrixResize>());
2478 
2479         case Expression::Kind::kEmpty:
2480             return "false";
2481 
2482         case Expression::Kind::kFieldAccess:
2483             return this->assembleFieldAccess(e.as<FieldAccess>());
2484 
2485         case Expression::Kind::kFunctionCall:
2486             return this->assembleFunctionCall(e.as<FunctionCall>(), parentPrecedence);
2487 
2488         case Expression::Kind::kIndex:
2489             return this->assembleIndexExpression(e.as<IndexExpression>());
2490 
2491         case Expression::Kind::kLiteral:
2492             return this->assembleLiteral(e.as<Literal>());
2493 
2494         case Expression::Kind::kPrefix:
2495             return this->assemblePrefixExpression(e.as<PrefixExpression>(), parentPrecedence);
2496 
2497         case Expression::Kind::kPostfix:
2498             return this->assemblePostfixExpression(e.as<PostfixExpression>(), parentPrecedence);
2499 
2500         case Expression::Kind::kSetting:
2501             return this->assembleExpression(*e.as<Setting>().toLiteral(fCaps), parentPrecedence);
2502 
2503         case Expression::Kind::kSwizzle:
2504             return this->assembleSwizzle(e.as<Swizzle>());
2505 
2506         case Expression::Kind::kTernary:
2507             return this->assembleTernaryExpression(e.as<TernaryExpression>(), parentPrecedence);
2508 
2509         case Expression::Kind::kVariableReference:
2510             return this->assembleVariableReference(e.as<VariableReference>());
2511 
2512         default:
2513             SkDEBUGFAILF("unsupported expression:\n%s", e.description().c_str());
2514             return {};
2515     }
2516 }
2517 
is_nontrivial_expression(const Expression & expr)2518 static bool is_nontrivial_expression(const Expression& expr) {
2519     // We consider a "trivial expression" one which we can repeat multiple times in the output
2520     // without being dangerous or spammy. We avoid emitting temporary variables for very trivial
2521     // expressions: literals, unadorned variable references, or constant vectors.
2522     if (expr.is<VariableReference>() || expr.is<Literal>()) {
2523         // Variables and literals are trivial; adding a let-declaration won't simplify anything.
2524         return false;
2525     }
2526     if (expr.type().isVector() && Analysis::IsConstantExpression(expr)) {
2527         // Compile-time constant vectors are also considered trivial; they're short and sweet.
2528         return false;
2529     }
2530     return true;
2531 }
2532 
binary_op_is_ambiguous_in_wgsl(Operator op)2533 static bool binary_op_is_ambiguous_in_wgsl(Operator op) {
2534     // WGSL always requires parentheses for some operators which are deemed to be ambiguous.
2535     // (8.19. Operator Precedence and Associativity)
2536     switch (op.kind()) {
2537         case OperatorKind::LOGICALOR:
2538         case OperatorKind::LOGICALAND:
2539         case OperatorKind::BITWISEOR:
2540         case OperatorKind::BITWISEAND:
2541         case OperatorKind::BITWISEXOR:
2542         case OperatorKind::SHL:
2543         case OperatorKind::SHR:
2544         case OperatorKind::LT:
2545         case OperatorKind::GT:
2546         case OperatorKind::LTEQ:
2547         case OperatorKind::GTEQ:
2548             return true;
2549 
2550         default:
2551             return false;
2552     }
2553 }
2554 
binaryOpNeedsComponentwiseMatrixPolyfill(const Type & left,const Type & right,Operator op)2555 bool WGSLCodeGenerator::binaryOpNeedsComponentwiseMatrixPolyfill(const Type& left,
2556                                                                  const Type& right,
2557                                                                  Operator op) {
2558     switch (op.kind()) {
2559         case OperatorKind::SLASH:
2560             // WGSL does not natively support componentwise matrix-op-matrix for division.
2561             if (left.isMatrix() && right.isMatrix()) {
2562                 return true;
2563             }
2564             [[fallthrough]];
2565 
2566         case OperatorKind::PLUS:
2567         case OperatorKind::MINUS:
2568             // WGSL does not natively support componentwise matrix-op-scalar or scalar-op-matrix for
2569             // addition, subtraction or division.
2570             return (left.isMatrix() && right.isScalar()) ||
2571                    (left.isScalar() && right.isMatrix());
2572 
2573         default:
2574             return false;
2575     }
2576 }
2577 
assembleBinaryExpression(const BinaryExpression & b,Precedence parentPrecedence)2578 std::string WGSLCodeGenerator::assembleBinaryExpression(const BinaryExpression& b,
2579                                                         Precedence parentPrecedence) {
2580     return this->assembleBinaryExpression(*b.left(), b.getOperator(), *b.right(), b.type(),
2581                                           parentPrecedence);
2582 }
2583 
assembleBinaryExpression(const Expression & left,Operator op,const Expression & right,const Type & resultType,Precedence parentPrecedence)2584 std::string WGSLCodeGenerator::assembleBinaryExpression(const Expression& left,
2585                                                         Operator op,
2586                                                         const Expression& right,
2587                                                         const Type& resultType,
2588                                                         Precedence parentPrecedence) {
2589     // If the operator is && or ||, we need to handle short-circuiting properly. Specifically, we
2590     // sometimes need to emit extra statements to paper over functionality that WGSL lacks, like
2591     // assignment in the middle of an expression. We need to guard those extra statements, to ensure
2592     // that they don't occur if the expression evaluation is short-circuited. Converting the
2593     // expression into an if-else block keeps the short-circuit property intact even when extra
2594     // statements are involved.
2595     // If the RHS doesn't have any side effects, then it's safe to just leave the expression as-is,
2596     // since we know that any possible extra statements are non-side-effecting.
2597     std::string expr;
2598     if (op.kind() == OperatorKind::LOGICALAND && Analysis::HasSideEffects(right)) {
2599         // Converts `left_expression && right_expression` into the following block:
2600 
2601         // var _skTemp1: bool;
2602         // [[ prepare left_expression ]]
2603         // if left_expression {
2604         //     [[ prepare right_expression ]]
2605         //     _skTemp1 = right_expression;
2606         // } else {
2607         //     _skTemp1 = false;
2608         // }
2609 
2610         expr = this->writeScratchVar(resultType);
2611 
2612         std::string leftExpr = this->assembleExpression(left, Precedence::kExpression);
2613         this->write("if ");
2614         this->write(leftExpr);
2615         this->writeLine(" {");
2616 
2617         ++fIndentation;
2618         std::string rightExpr = this->assembleExpression(right, Precedence::kAssignment);
2619         this->write(expr);
2620         this->write(" = ");
2621         this->write(rightExpr);
2622         this->writeLine(";");
2623         --fIndentation;
2624 
2625         this->writeLine("} else {");
2626 
2627         ++fIndentation;
2628         this->write(expr);
2629         this->writeLine(" = false;");
2630         --fIndentation;
2631 
2632         this->writeLine("}");
2633         return expr;
2634     }
2635 
2636     if (op.kind() == OperatorKind::LOGICALOR && Analysis::HasSideEffects(right)) {
2637         // Converts `left_expression || right_expression` into the following block:
2638 
2639         // var _skTemp1: bool;
2640         // [[ prepare left_expression ]]
2641         // if left_expression {
2642         //     _skTemp1 = true;
2643         // } else {
2644         //     [[ prepare right_expression ]]
2645         //     _skTemp1 = right_expression;
2646         // }
2647 
2648         expr = this->writeScratchVar(resultType);
2649 
2650         std::string leftExpr = this->assembleExpression(left, Precedence::kExpression);
2651         this->write("if ");
2652         this->write(leftExpr);
2653         this->writeLine(" {");
2654 
2655         ++fIndentation;
2656         this->write(expr);
2657         this->writeLine(" = true;");
2658         --fIndentation;
2659 
2660         this->writeLine("} else {");
2661 
2662         ++fIndentation;
2663         std::string rightExpr = this->assembleExpression(right, Precedence::kAssignment);
2664         this->write(expr);
2665         this->write(" = ");
2666         this->write(rightExpr);
2667         this->writeLine(";");
2668         --fIndentation;
2669 
2670         this->writeLine("}");
2671         return expr;
2672     }
2673 
2674     // Handle comma-expressions.
2675     if (op.kind() == OperatorKind::COMMA) {
2676         // The result from the left-expression is ignored, but its side effects must occur.
2677         this->assembleExpression(left, Precedence::kStatement);
2678 
2679         // Evaluate the right side normally.
2680         return this->assembleExpression(right, parentPrecedence);
2681     }
2682 
2683     // Handle assignment-expressions.
2684     if (op.isAssignment()) {
2685         std::unique_ptr<LValue> lvalue = this->makeLValue(left);
2686         if (!lvalue) {
2687             return "";
2688         }
2689 
2690         if (op.kind() == OperatorKind::EQ) {
2691             // Evaluate the right-hand side of simple assignment (`a = b` --> `b`).
2692             expr = this->assembleExpression(right, Precedence::kAssignment);
2693         } else {
2694             // Evaluate the right-hand side of compound-assignment (`a += b` --> `a + b`).
2695             op = op.removeAssignment();
2696 
2697             std::string lhs = lvalue->load();
2698             std::string rhs = this->assembleExpression(right, op.getBinaryPrecedence());
2699 
2700             if (this->binaryOpNeedsComponentwiseMatrixPolyfill(left.type(), right.type(), op)) {
2701                 if (is_nontrivial_expression(right)) {
2702                     rhs = this->writeScratchLet(rhs);
2703                 }
2704 
2705                 expr = this->assembleComponentwiseMatrixBinary(left.type(), right.type(),
2706                                                                lhs, rhs, op);
2707             } else {
2708                 expr = lhs + operator_name(op) + rhs;
2709             }
2710         }
2711 
2712         // Emit the assignment statement (`a = a + b`).
2713         this->writeLine(lvalue->store(expr));
2714 
2715         // Return the lvalue (`a`) as the result, since the value might be used by the caller.
2716         return lvalue->load();
2717     }
2718 
2719     if (op.isEquality()) {
2720         return this->assembleEqualityExpression(left, right, op, parentPrecedence);
2721     }
2722 
2723     Precedence precedence = op.getBinaryPrecedence();
2724     bool needParens = precedence >= parentPrecedence;
2725     if (binary_op_is_ambiguous_in_wgsl(op)) {
2726         precedence = Precedence::kParentheses;
2727     }
2728     if (needParens) {
2729         expr = "(";
2730     }
2731 
2732     // If we are emitting `constant + constant`, this generally indicates that the values could not
2733     // be constant-folded. This happens when the values overflow or become nan. WGSL will refuse to
2734     // compile such expressions, as WGSL 1.0 has no infinity/nan support. However, the WGSL
2735     // compile-time check can be dodged by putting one side into a let-variable. This technically
2736     // gives us an indeterminate result, but the vast majority of backends will just calculate an
2737     // infinity or nan here, as we would expect. (skia:14385)
2738     bool bothSidesConstant = ConstantFolder::GetConstantValueOrNull(left) &&
2739                              ConstantFolder::GetConstantValueOrNull(right);
2740 
2741     std::string lhs = this->assembleExpression(left, precedence);
2742     std::string rhs = this->assembleExpression(right, precedence);
2743 
2744     if (this->binaryOpNeedsComponentwiseMatrixPolyfill(left.type(), right.type(), op)) {
2745         if (bothSidesConstant || is_nontrivial_expression(left)) {
2746             lhs = this->writeScratchLet(lhs);
2747         }
2748         if (is_nontrivial_expression(right)) {
2749             rhs = this->writeScratchLet(rhs);
2750         }
2751 
2752         expr += this->assembleComponentwiseMatrixBinary(left.type(), right.type(), lhs, rhs, op);
2753     } else {
2754         if (bothSidesConstant) {
2755             lhs = this->writeScratchLet(lhs);
2756         }
2757 
2758         expr += lhs + operator_name(op) + rhs;
2759     }
2760 
2761     if (needParens) {
2762         expr += ')';
2763     }
2764 
2765     return expr;
2766 }
2767 
assembleFieldAccess(const FieldAccess & f)2768 std::string WGSLCodeGenerator::assembleFieldAccess(const FieldAccess& f) {
2769     const Field* field = &f.base()->type().fields()[f.fieldIndex()];
2770     std::string expr;
2771 
2772     if (FieldPolyfillInfo* polyfillInfo = fFieldPolyfillMap.find(field)) {
2773         // We found a matrix uniform. We are required to pass some matrix uniforms as array vectors,
2774         // since the std140 layout for a matrix assumes 4-column vectors for each row, and WGSL
2775         // tightly packs 2-column matrices. When emitting code, we replace the field-access
2776         // expression with a global variable which holds an unpacked version of the uniform.
2777         polyfillInfo->fWasAccessed = true;
2778 
2779         // The polyfill can either be based directly onto a uniform in an interface block, or it
2780         // might be based on an index-expression onto a uniform if the interface block is arrayed.
2781         const Expression* base = f.base().get();
2782         const IndexExpression* indexExpr = nullptr;
2783         if (base->is<IndexExpression>()) {
2784             indexExpr = &base->as<IndexExpression>();
2785             base = indexExpr->base().get();
2786         }
2787 
2788         SkASSERT(base->is<VariableReference>());
2789         expr = polyfillInfo->fReplacementName;
2790 
2791         // If we had an index expression, we must append the index.
2792         if (indexExpr) {
2793             expr += '[';
2794             expr += this->assembleExpression(*indexExpr->index(), Precedence::kSequence);
2795             expr += ']';
2796         }
2797         return expr;
2798     }
2799 
2800     switch (f.ownerKind()) {
2801         case FieldAccess::OwnerKind::kDefault:
2802             expr = this->assembleExpression(*f.base(), Precedence::kPostfix) + '.';
2803             break;
2804 
2805         case FieldAccess::OwnerKind::kAnonymousInterfaceBlock:
2806             if (f.base()->is<VariableReference>() &&
2807                 field->fLayout.fBuiltin != SK_POINTSIZE_BUILTIN) {
2808                 expr = this->variablePrefix(*f.base()->as<VariableReference>().variable());
2809             }
2810             break;
2811     }
2812 
2813     expr += this->assembleName(field->fName);
2814     return expr;
2815 }
2816 
all_arguments_constant(const ExpressionArray & arguments)2817 static bool all_arguments_constant(const ExpressionArray& arguments) {
2818     // Returns true if all arguments in the ExpressionArray are compile-time constants. If we are
2819     // calling an intrinsic and all of its inputs are constant, but we didn't constant-fold it, this
2820     // generally indicates that constant-folding resulted in an infinity or nan. The WGSL compiler
2821     // will reject such an expression with a compile-time error. We can dodge the error, taking on
2822     // the risk of indeterminate behavior instead, by replacing one of the constant values with a
2823     // scratch let-variable. (skia:14385)
2824     for (const std::unique_ptr<Expression>& arg : arguments) {
2825         if (!ConstantFolder::GetConstantValueOrNull(*arg)) {
2826             return false;
2827         }
2828     }
2829     return true;
2830 }
2831 
assembleSimpleIntrinsic(std::string_view intrinsicName,const FunctionCall & call)2832 std::string WGSLCodeGenerator::assembleSimpleIntrinsic(std::string_view intrinsicName,
2833                                                        const FunctionCall& call) {
2834     // Invoke the function, passing each function argument.
2835     std::string expr = std::string(intrinsicName);
2836     expr.push_back('(');
2837     const ExpressionArray& args = call.arguments();
2838     auto separator = SkSL::String::Separator();
2839     bool allConstant = all_arguments_constant(call.arguments());
2840     for (int index = 0; index < args.size(); ++index) {
2841         expr += separator();
2842 
2843         std::string argument = this->assembleExpression(*args[index], Precedence::kSequence);
2844         if (args[index]->type().isAtomic()) {
2845             // WGSL passes atomic values to intrinsics as pointers.
2846             expr += '&';
2847             expr += argument;
2848         } else if (allConstant && index == 0) {
2849             // We can use a scratch-let for argument 0 to dodge WGSL overflow errors. (skia:14385)
2850             expr += this->writeScratchLet(argument);
2851         } else {
2852             expr += argument;
2853         }
2854     }
2855     expr.push_back(')');
2856 
2857     if (call.type().isVoid()) {
2858         this->write(expr);
2859         this->writeLine(";");
2860         return std::string();
2861     } else {
2862         return this->writeScratchLet(expr);
2863     }
2864 }
2865 
assembleVectorizedIntrinsic(std::string_view intrinsicName,const FunctionCall & call)2866 std::string WGSLCodeGenerator::assembleVectorizedIntrinsic(std::string_view intrinsicName,
2867                                                            const FunctionCall& call) {
2868     SkASSERT(!call.type().isVoid());
2869 
2870     // Invoke the function, passing each function argument.
2871     std::string expr = std::string(intrinsicName);
2872     expr.push_back('(');
2873 
2874     auto separator = SkSL::String::Separator();
2875     const ExpressionArray& args = call.arguments();
2876     bool returnsVector = call.type().isVector();
2877     bool allConstant = all_arguments_constant(call.arguments());
2878     for (int index = 0; index < args.size(); ++index) {
2879         expr += separator();
2880 
2881         bool vectorize = returnsVector && args[index]->type().isScalar();
2882         if (vectorize) {
2883             expr += to_wgsl_type(fContext, call.type());
2884             expr.push_back('(');
2885         }
2886 
2887         // We can use a scratch-let for argument 0 to dodge WGSL overflow errors. (skia:14385)
2888         std::string argument = this->assembleExpression(*args[index], Precedence::kSequence);
2889         expr += (allConstant && index == 0) ? this->writeScratchLet(argument)
2890                                             : argument;
2891         if (vectorize) {
2892             expr.push_back(')');
2893         }
2894     }
2895     expr.push_back(')');
2896 
2897     return this->writeScratchLet(expr);
2898 }
2899 
assembleUnaryOpIntrinsic(Operator op,const FunctionCall & call,Precedence parentPrecedence)2900 std::string WGSLCodeGenerator::assembleUnaryOpIntrinsic(Operator op,
2901                                                         const FunctionCall& call,
2902                                                         Precedence parentPrecedence) {
2903     SkASSERT(!call.type().isVoid());
2904 
2905     bool needParens = Precedence::kPrefix >= parentPrecedence;
2906 
2907     std::string expr;
2908     if (needParens) {
2909         expr.push_back('(');
2910     }
2911 
2912     expr += operator_name(op);
2913     expr += this->assembleExpression(*call.arguments()[0], Precedence::kPrefix);
2914 
2915     if (needParens) {
2916         expr.push_back(')');
2917     }
2918 
2919     return expr;
2920 }
2921 
assembleBinaryOpIntrinsic(Operator op,const FunctionCall & call,Precedence parentPrecedence)2922 std::string WGSLCodeGenerator::assembleBinaryOpIntrinsic(Operator op,
2923                                                          const FunctionCall& call,
2924                                                          Precedence parentPrecedence) {
2925     SkASSERT(!call.type().isVoid());
2926 
2927     Precedence precedence = op.getBinaryPrecedence();
2928     bool needParens = precedence >= parentPrecedence ||
2929                       binary_op_is_ambiguous_in_wgsl(op);
2930     std::string expr;
2931     if (needParens) {
2932         expr.push_back('(');
2933     }
2934 
2935     // We can use a scratch-let for argument 0 to dodge WGSL overflow errors. (skia:14385)
2936     std::string argument = this->assembleExpression(*call.arguments()[0], precedence);
2937     expr += all_arguments_constant(call.arguments()) ? this->writeScratchLet(argument)
2938                                                      : argument;
2939     expr += operator_name(op);
2940     expr += this->assembleExpression(*call.arguments()[1], precedence);
2941 
2942     if (needParens) {
2943         expr.push_back(')');
2944     }
2945 
2946     return expr;
2947 }
2948 
2949 // Rewrite a WGSL intrinsic of the form "intrinsicName(in) -> struct" to the SkSL's
2950 // "intrinsicName(in, outField) -> returnField", where outField and returnField are the names of the
2951 // fields in the struct returned by the WGSL intrinsic.
assembleOutAssignedIntrinsic(std::string_view intrinsicName,std::string_view returnField,std::string_view outField,const FunctionCall & call)2952 std::string WGSLCodeGenerator::assembleOutAssignedIntrinsic(std::string_view intrinsicName,
2953                                                             std::string_view returnField,
2954                                                             std::string_view outField,
2955                                                             const FunctionCall& call) {
2956     SkASSERT(call.type().componentType().isNumber());
2957     SkASSERT(call.arguments().size() == 2);
2958     SkASSERT(call.function().parameters()[1]->modifierFlags() & ModifierFlag::kOut);
2959 
2960     std::string expr = std::string(intrinsicName);
2961     expr += "(";
2962 
2963     // Invoke the intrinsic with the first parameter. Use a scratch-let if argument is a constant
2964     // to dodge WGSL overflow errors. (skia:14385)
2965     std::string argument = this->assembleExpression(*call.arguments()[0], Precedence::kSequence);
2966     expr += ConstantFolder::GetConstantValueOrNull(*call.arguments()[0])
2967             ? this->writeScratchLet(argument) : argument;
2968     expr += ")";
2969     // In WGSL the intrinsic returns a struct; assign it to a local so that its fields can be
2970     // accessed multiple times.
2971     expr = this->writeScratchLet(expr);
2972     expr += ".";
2973 
2974     // Store the outField of `expr` to the intended "out" argument
2975     std::unique_ptr<LValue> lvalue = this->makeLValue(*call.arguments()[1]);
2976     if (!lvalue) {
2977         return "";
2978     }
2979     std::string outValue = expr;
2980     outValue += outField;
2981     this->writeLine(lvalue->store(outValue));
2982 
2983     // And return the expression accessing the returnField.
2984     expr += returnField;
2985     return expr;
2986 }
2987 
assemblePartialSampleCall(std::string_view functionName,const Expression & sampler,const Expression & coords)2988 std::string WGSLCodeGenerator::assemblePartialSampleCall(std::string_view functionName,
2989                                                          const Expression& sampler,
2990                                                          const Expression& coords) {
2991     // This function returns `functionName(inSampler_texture, inSampler_sampler, coords` without a
2992     // terminating comma or close-parenthesis. This allows the caller to add more arguments as
2993     // needed.
2994     SkASSERT(sampler.type().typeKind() == Type::TypeKind::kSampler);
2995     std::string expr = std::string(functionName) + '(';
2996     expr += this->assembleExpression(sampler, Precedence::kSequence);
2997     expr += kTextureSuffix;
2998     expr += ", ";
2999     expr += this->assembleExpression(sampler, Precedence::kSequence);
3000     expr += kSamplerSuffix;
3001     expr += ", ";
3002 
3003     // Compute the sample coordinates, dividing out the Z if a vec3 was provided.
3004     SkASSERT(coords.type().isVector());
3005     if (coords.type().columns() == 3) {
3006         // The coordinates were passed as a vec3, so we need to emit `coords.xy / coords.z`.
3007         std::string vec3Coords = this->writeScratchLet(coords, Precedence::kMultiplicative);
3008         expr += vec3Coords + ".xy / " + vec3Coords + ".z";
3009     } else {
3010         // The coordinates should be a plain vec2; emit the expression as-is.
3011         SkASSERT(coords.type().columns() == 2);
3012         expr += this->assembleExpression(coords, Precedence::kSequence);
3013     }
3014 
3015     return expr;
3016 }
3017 
assembleComponentwiseMatrixBinary(const Type & leftType,const Type & rightType,const std::string & left,const std::string & right,Operator op)3018 std::string WGSLCodeGenerator::assembleComponentwiseMatrixBinary(const Type& leftType,
3019                                                                  const Type& rightType,
3020                                                                  const std::string& left,
3021                                                                  const std::string& right,
3022                                                                  Operator op) {
3023     bool leftIsMatrix = leftType.isMatrix();
3024     bool rightIsMatrix = rightType.isMatrix();
3025     const Type& matrixType = leftIsMatrix ? leftType : rightType;
3026 
3027     std::string expr = to_wgsl_type(fContext, matrixType) + '(';
3028     auto separator = String::Separator();
3029     int columns = matrixType.columns();
3030     for (int c = 0; c < columns; ++c) {
3031         expr += separator();
3032         expr += left;
3033         if (leftIsMatrix) {
3034             expr += '[';
3035             expr += std::to_string(c);
3036             expr += ']';
3037         }
3038         expr += op.operatorName();
3039         expr += right;
3040         if (rightIsMatrix) {
3041             expr += '[';
3042             expr += std::to_string(c);
3043             expr += ']';
3044         }
3045     }
3046     return expr + ')';
3047 }
3048 
assembleIntrinsicCall(const FunctionCall & call,IntrinsicKind kind,Precedence parentPrecedence)3049 std::string WGSLCodeGenerator::assembleIntrinsicCall(const FunctionCall& call,
3050                                                      IntrinsicKind kind,
3051                                                      Precedence parentPrecedence) {
3052     // Be careful: WGSL 1.0 will reject any intrinsic calls which can be constant-evaluated to
3053     // infinity or nan with a compile error. If all arguments to an intrinsic are compile-time
3054     // constants (`all_arguments_constant`), it is safest to copy one argument into a scratch-let so
3055     // that the call will be seen as runtime-evaluated, which defuses the overflow checks.
3056     // Don't worry; a competent driver should still optimize it away.
3057 
3058     const ExpressionArray& arguments = call.arguments();
3059     switch (kind) {
3060         case k_atan_IntrinsicKind: {
3061             const char* name = (arguments.size() == 1) ? "atan" : "atan2";
3062             return this->assembleSimpleIntrinsic(name, call);
3063         }
3064         case k_dFdx_IntrinsicKind:
3065             return this->assembleSimpleIntrinsic("dpdx", call);
3066 
3067         case k_dFdy_IntrinsicKind:
3068             // TODO(b/294274678): apply RTFlip here
3069             return this->assembleSimpleIntrinsic("dpdy", call);
3070 
3071         case k_dot_IntrinsicKind: {
3072             if (arguments[0]->type().isScalar()) {
3073                 return this->assembleBinaryOpIntrinsic(OperatorKind::STAR, call, parentPrecedence);
3074             }
3075             return this->assembleSimpleIntrinsic("dot", call);
3076         }
3077         case k_equal_IntrinsicKind:
3078             return this->assembleBinaryOpIntrinsic(OperatorKind::EQEQ, call, parentPrecedence);
3079 
3080         case k_faceforward_IntrinsicKind: {
3081             if (arguments[0]->type().isScalar()) {
3082                 // select(-N, N, (I * Nref) < 0)
3083                 std::string N = this->writeNontrivialScratchLet(*arguments[0],
3084                                                                 Precedence::kAssignment);
3085                 return this->writeScratchLet(
3086                         "select(-" + N + ", " + N + ", " +
3087                         this->assembleBinaryExpression(*arguments[1],
3088                                                        OperatorKind::STAR,
3089                                                        *arguments[2],
3090                                                        arguments[1]->type(),
3091                                                        Precedence::kRelational) +
3092                         " < 0)");
3093             }
3094             return this->assembleSimpleIntrinsic("faceForward", call);
3095         }
3096         case k_frexp_IntrinsicKind:
3097             // SkSL frexp is "$genType fract = frexp($genType, out $genIType exp)" whereas WGSL
3098             // returns a struct with no out param: "let [fract, exp] = frexp($genType)".
3099             return this->assembleOutAssignedIntrinsic("frexp", "fract", "exp", call);
3100 
3101         case k_greaterThan_IntrinsicKind:
3102             return this->assembleBinaryOpIntrinsic(OperatorKind::GT, call, parentPrecedence);
3103 
3104         case k_greaterThanEqual_IntrinsicKind:
3105             return this->assembleBinaryOpIntrinsic(OperatorKind::GTEQ, call, parentPrecedence);
3106 
3107         case k_inverse_IntrinsicKind:
3108             return this->assembleInversePolyfill(call);
3109 
3110         case k_inversesqrt_IntrinsicKind:
3111             return this->assembleSimpleIntrinsic("inverseSqrt", call);
3112 
3113         case k_lessThan_IntrinsicKind:
3114             return this->assembleBinaryOpIntrinsic(OperatorKind::LT, call, parentPrecedence);
3115 
3116         case k_lessThanEqual_IntrinsicKind:
3117             return this->assembleBinaryOpIntrinsic(OperatorKind::LTEQ, call, parentPrecedence);
3118 
3119         case k_matrixCompMult_IntrinsicKind: {
3120             // We use a scratch-let for arg0 to avoid the potential for WGSL overflow. (skia:14385)
3121             std::string arg0 = all_arguments_constant(arguments)
3122                             ? this->writeScratchLet(*arguments[0], Precedence::kPostfix)
3123                             : this->writeNontrivialScratchLet(*arguments[0], Precedence::kPostfix);
3124             std::string arg1 = this->writeNontrivialScratchLet(*arguments[1], Precedence::kPostfix);
3125             return this->writeScratchLet(
3126                     this->assembleComponentwiseMatrixBinary(arguments[0]->type(),
3127                                                             arguments[1]->type(),
3128                                                             arg0,
3129                                                             arg1,
3130                                                             OperatorKind::STAR));
3131         }
3132         case k_mix_IntrinsicKind: {
3133             const char* name = arguments[2]->type().componentType().isBoolean() ? "select" : "mix";
3134             return this->assembleVectorizedIntrinsic(name, call);
3135         }
3136         case k_mod_IntrinsicKind: {
3137             // WGSL has no intrinsic equivalent to `mod`. Synthesize `x - y * floor(x / y)`.
3138             // We can use a scratch-let on one side to dodge WGSL overflow errors.  In practice, I
3139             // can't find any values of x or y which would overflow, but it can't hurt. (skia:14385)
3140             std::string arg0 = all_arguments_constant(arguments)
3141                             ? this->writeScratchLet(*arguments[0], Precedence::kAdditive)
3142                             : this->writeNontrivialScratchLet(*arguments[0], Precedence::kAdditive);
3143             std::string arg1 = this->writeNontrivialScratchLet(*arguments[1],
3144                                                                Precedence::kAdditive);
3145             return this->writeScratchLet(arg0 + " - " + arg1 + " * floor(" +
3146                                          arg0 + " / " + arg1 + ")");
3147         }
3148 
3149         case k_modf_IntrinsicKind:
3150             // SkSL modf is "$genType fract = modf($genType, out $genType whole)" whereas WGSL
3151             // returns a struct with no out param: "let [fract, whole] = modf($genType)".
3152             return this->assembleOutAssignedIntrinsic("modf", "fract", "whole", call);
3153 
3154         case k_normalize_IntrinsicKind: {
3155             const char* name = arguments[0]->type().isScalar() ? "sign" : "normalize";
3156             return this->assembleSimpleIntrinsic(name, call);
3157         }
3158         case k_not_IntrinsicKind:
3159             return this->assembleUnaryOpIntrinsic(OperatorKind::LOGICALNOT, call, parentPrecedence);
3160 
3161         case k_notEqual_IntrinsicKind:
3162             return this->assembleBinaryOpIntrinsic(OperatorKind::NEQ, call, parentPrecedence);
3163 
3164         case k_packHalf2x16_IntrinsicKind:
3165             return this->assembleSimpleIntrinsic("pack2x16float", call);
3166 
3167         case k_packSnorm2x16_IntrinsicKind:
3168             return this->assembleSimpleIntrinsic("pack2x16snorm", call);
3169 
3170         case k_packSnorm4x8_IntrinsicKind:
3171             return this->assembleSimpleIntrinsic("pack4x8snorm", call);
3172 
3173         case k_packUnorm2x16_IntrinsicKind:
3174             return this->assembleSimpleIntrinsic("pack2x16unorm", call);
3175 
3176         case k_packUnorm4x8_IntrinsicKind:
3177             return this->assembleSimpleIntrinsic("pack4x8unorm", call);
3178 
3179         case k_reflect_IntrinsicKind:
3180             if (arguments[0]->type().isScalar()) {
3181                 // I - 2 * N * I * N
3182                 // We can use a scratch-let for N to dodge WGSL overflow errors. (skia:14385)
3183                 std::string I = this->writeNontrivialScratchLet(*arguments[0],
3184                                                                 Precedence::kAdditive);
3185                 std::string N = all_arguments_constant(arguments)
3186                       ? this->writeScratchLet(*arguments[1], Precedence::kMultiplicative)
3187                       : this->writeNontrivialScratchLet(*arguments[1], Precedence::kMultiplicative);
3188                 return this->writeScratchLet(String::printf("%s - 2 * %s * %s * %s",
3189                                                             I.c_str(), N.c_str(),
3190                                                             I.c_str(), N.c_str()));
3191             }
3192             return this->assembleSimpleIntrinsic("reflect", call);
3193 
3194         case k_refract_IntrinsicKind:
3195             if (arguments[0]->type().isScalar()) {
3196                 // WGSL only implements refract for vectors; rather than reimplementing refract from
3197                 // scratch, we can replace the call with `refract(float2(I,0), float2(N,0), eta).x`.
3198                 std::string I = this->writeNontrivialScratchLet(*arguments[0],
3199                                                                 Precedence::kSequence);
3200                 std::string N = this->writeNontrivialScratchLet(*arguments[1],
3201                                                                 Precedence::kSequence);
3202                 // We can use a scratch-let for Eta to avoid WGSL overflow errors. (skia:14385)
3203                 std::string Eta = all_arguments_constant(arguments)
3204                       ? this->writeScratchLet(*arguments[2], Precedence::kSequence)
3205                       : this->writeNontrivialScratchLet(*arguments[2], Precedence::kSequence);
3206                 return this->writeScratchLet(
3207                         String::printf("refract(vec2<%s>(%s, 0), vec2<%s>(%s, 0), %s).x",
3208                                        to_wgsl_type(fContext, arguments[0]->type()).c_str(),
3209                                        I.c_str(),
3210                                        to_wgsl_type(fContext, arguments[1]->type()).c_str(),
3211                                        N.c_str(),
3212                                        Eta.c_str()));
3213             }
3214             return this->assembleSimpleIntrinsic("refract", call);
3215 
3216         case k_sample_IntrinsicKind: {
3217             // Determine if a bias argument was passed in.
3218             SkASSERT(arguments.size() == 2 || arguments.size() == 3);
3219             bool callIncludesBias = (arguments.size() == 3);
3220 
3221             if (fProgram.fConfig->fSettings.fSharpenTextures || callIncludesBias) {
3222                 // We need to supply a bias argument; this is a separate intrinsic in WGSL.
3223                 std::string expr = this->assemblePartialSampleCall("textureSampleBias",
3224                                                                    *arguments[0],
3225                                                                    *arguments[1]);
3226                 expr += ", ";
3227                 if (callIncludesBias) {
3228                     expr += this->assembleExpression(*arguments[2], Precedence::kAdditive) +
3229                             " + ";
3230                 }
3231                 expr += skstd::to_string(fProgram.fConfig->fSettings.fSharpenTextures
3232                                                  ? kSharpenTexturesBias
3233                                                  : 0.0f);
3234                 return expr + ')';
3235             }
3236 
3237             // No bias is necessary, so we can call `textureSample` directly.
3238             return this->assemblePartialSampleCall("textureSample",
3239                                                    *arguments[0],
3240                                                    *arguments[1]) + ')';
3241         }
3242         case k_sampleLod_IntrinsicKind: {
3243             std::string expr = this->assemblePartialSampleCall("textureSampleLevel",
3244                                                                *arguments[0],
3245                                                                *arguments[1]);
3246             expr += ", " + this->assembleExpression(*arguments[2], Precedence::kSequence);
3247             return expr + ')';
3248         }
3249         case k_sampleGrad_IntrinsicKind: {
3250             std::string expr = this->assemblePartialSampleCall("textureSampleGrad",
3251                                                                *arguments[0],
3252                                                                *arguments[1]);
3253             expr += ", " + this->assembleExpression(*arguments[2], Precedence::kSequence);
3254             expr += ", " + this->assembleExpression(*arguments[3], Precedence::kSequence);
3255             return expr + ')';
3256         }
3257         case k_textureHeight_IntrinsicKind:
3258             return this->assembleSimpleIntrinsic("textureDimensions", call) + ".y";
3259 
3260         case k_textureRead_IntrinsicKind: {
3261             // We need to inject an extra argument for the mip-level. We don't plan on using mipmaps
3262             // in our storage textures, so we can just pass zero.
3263             std::string tex = this->assembleExpression(*arguments[0], Precedence::kSequence);
3264             std::string pos = this->writeScratchLet(*arguments[1], Precedence::kSequence);
3265             return std::string("textureLoad(") + tex + ", " + pos + ", 0)";
3266         }
3267         case k_textureWidth_IntrinsicKind:
3268             return this->assembleSimpleIntrinsic("textureDimensions", call) + ".x";
3269 
3270         case k_textureWrite_IntrinsicKind:
3271             return this->assembleSimpleIntrinsic("textureStore", call);
3272 
3273         case k_unpackHalf2x16_IntrinsicKind:
3274             return this->assembleSimpleIntrinsic("unpack2x16float", call);
3275 
3276         case k_unpackSnorm2x16_IntrinsicKind:
3277             return this->assembleSimpleIntrinsic("unpack2x16snorm", call);
3278 
3279         case k_unpackSnorm4x8_IntrinsicKind:
3280             return this->assembleSimpleIntrinsic("unpack4x8snorm", call);
3281 
3282         case k_unpackUnorm2x16_IntrinsicKind:
3283             return this->assembleSimpleIntrinsic("unpack2x16unorm", call);
3284 
3285         case k_unpackUnorm4x8_IntrinsicKind:
3286             return this->assembleSimpleIntrinsic("unpack4x8unorm", call);
3287 
3288         case k_clamp_IntrinsicKind:
3289         case k_max_IntrinsicKind:
3290         case k_min_IntrinsicKind:
3291         case k_smoothstep_IntrinsicKind:
3292         case k_step_IntrinsicKind:
3293             return this->assembleVectorizedIntrinsic(call.function().name(), call);
3294 
3295         case k_abs_IntrinsicKind:
3296         case k_acos_IntrinsicKind:
3297         case k_all_IntrinsicKind:
3298         case k_any_IntrinsicKind:
3299         case k_asin_IntrinsicKind:
3300         case k_atomicAdd_IntrinsicKind:
3301         case k_atomicLoad_IntrinsicKind:
3302         case k_atomicStore_IntrinsicKind:
3303         case k_ceil_IntrinsicKind:
3304         case k_cos_IntrinsicKind:
3305         case k_cross_IntrinsicKind:
3306         case k_degrees_IntrinsicKind:
3307         case k_distance_IntrinsicKind:
3308         case k_exp_IntrinsicKind:
3309         case k_exp2_IntrinsicKind:
3310         case k_floor_IntrinsicKind:
3311         case k_fract_IntrinsicKind:
3312         case k_length_IntrinsicKind:
3313         case k_log_IntrinsicKind:
3314         case k_log2_IntrinsicKind:
3315         case k_radians_IntrinsicKind:
3316         case k_pow_IntrinsicKind:
3317         case k_saturate_IntrinsicKind:
3318         case k_sign_IntrinsicKind:
3319         case k_sin_IntrinsicKind:
3320         case k_sqrt_IntrinsicKind:
3321         case k_storageBarrier_IntrinsicKind:
3322         case k_tan_IntrinsicKind:
3323         case k_workgroupBarrier_IntrinsicKind:
3324         default:
3325             return this->assembleSimpleIntrinsic(call.function().name(), call);
3326     }
3327 }
3328 
// WGSL source for the matrix-inverse helper functions. Each line of WGSL is one C++ string literal
// ending in its own newline; the indentation lives outside the quotes, so the emitted WGSL text
// itself carries no leading whitespace.

// Inverse of a 2x2 matrix: a rearranged, partly negated copy of `m` scaled by 1/determinant.
static constexpr char kInverse2x2[] =
    "fn mat2_inverse(m: mat2x2<f32>) -> mat2x2<f32> {\n"
        "return mat2x2<f32>(m[1].y, -m[0].y, -m[1].x, m[0].x) * (1/determinant(m));\n"
    "}\n";

// Inverse of a 3x3 matrix. `det` is accumulated from the b01/b11/b21 terms, and the constructed
// matrix is scaled by 1/det.
static constexpr char kInverse3x3[] =
    "fn mat3_inverse(m: mat3x3<f32>) -> mat3x3<f32> {\n"
        "let a00 = m[0].x; let a01 = m[0].y; let a02 = m[0].z;\n"
        "let a10 = m[1].x; let a11 = m[1].y; let a12 = m[1].z;\n"
        "let a20 = m[2].x; let a21 = m[2].y; let a22 = m[2].z;\n"
        "let b01 =  a22*a11 - a12*a21;\n"
        "let b11 = -a22*a10 + a12*a20;\n"
        "let b21 =  a21*a10 - a11*a20;\n"
        "let det = a00*b01 + a01*b11 + a02*b21;\n"
        "return mat3x3<f32>(b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),\n"
                           "b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),\n"
                           "b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);\n"
    "}\n";

// Inverse of a 4x4 matrix. The twelve intermediate terms b00..b11 are shared between `det` and
// the entries of the constructed matrix, which is scaled by 1/det.
static constexpr char kInverse4x4[] =
    "fn mat4_inverse(m: mat4x4<f32>) -> mat4x4<f32>{\n"
        "let a00 = m[0].x; let a01 = m[0].y; let a02 = m[0].z; let a03 = m[0].w;\n"
        "let a10 = m[1].x; let a11 = m[1].y; let a12 = m[1].z; let a13 = m[1].w;\n"
        "let a20 = m[2].x; let a21 = m[2].y; let a22 = m[2].z; let a23 = m[2].w;\n"
        "let a30 = m[3].x; let a31 = m[3].y; let a32 = m[3].z; let a33 = m[3].w;\n"
        "let b00 = a00*a11 - a01*a10;\n"
        "let b01 = a00*a12 - a02*a10;\n"
        "let b02 = a00*a13 - a03*a10;\n"
        "let b03 = a01*a12 - a02*a11;\n"
        "let b04 = a01*a13 - a03*a11;\n"
        "let b05 = a02*a13 - a03*a12;\n"
        "let b06 = a20*a31 - a21*a30;\n"
        "let b07 = a20*a32 - a22*a30;\n"
        "let b08 = a20*a33 - a23*a30;\n"
        "let b09 = a21*a32 - a22*a31;\n"
        "let b10 = a21*a33 - a23*a31;\n"
        "let b11 = a22*a33 - a23*a32;\n"
        "let det = b00*b11 - b01*b10 + b02*b09 + b03*b08 - b04*b07 + b05*b06;\n"
        "return mat4x4<f32>(a11*b11 - a12*b10 + a13*b09,\n"
                           "a02*b10 - a01*b11 - a03*b09,\n"
                           "a31*b05 - a32*b04 + a33*b03,\n"
                           "a22*b04 - a21*b05 - a23*b03,\n"
                           "a12*b08 - a10*b11 - a13*b07,\n"
                           "a00*b11 - a02*b08 + a03*b07,\n"
                           "a32*b02 - a30*b05 - a33*b01,\n"
                           "a20*b05 - a22*b02 + a23*b01,\n"
                           "a10*b10 - a11*b08 + a13*b06,\n"
                           "a01*b08 - a00*b10 - a03*b06,\n"
                           "a30*b04 - a31*b02 + a33*b00,\n"
                           "a21*b02 - a20*b04 - a23*b00,\n"
                           "a11*b07 - a10*b09 - a12*b06,\n"
                           "a00*b09 - a01*b07 + a02*b06,\n"
                           "a31*b01 - a30*b03 - a32*b00,\n"
                           "a20*b03 - a21*b01 + a22*b00) * (1/det);\n"
    "}\n";
3387 
assembleInversePolyfill(const FunctionCall & call)3388 std::string WGSLCodeGenerator::assembleInversePolyfill(const FunctionCall& call) {
3389     const ExpressionArray& arguments = call.arguments();
3390     const Type& type = arguments.front()->type();
3391 
3392     // The `inverse` intrinsic should only accept a single-argument square matrix.
3393     // Once we implement f16 support, these polyfills will need to be updated to support `hmat`;
3394     // for the time being, all floats in WGSL are f32, so we don't need to worry about precision.
3395     SkASSERT(arguments.size() == 1);
3396     SkASSERT(type.isMatrix());
3397     SkASSERT(type.rows() == type.columns());
3398 
3399     switch (type.slotCount()) {
3400         case 4:
3401             if (!fWrittenInverse2) {
3402                 fWrittenInverse2 = true;
3403                 fHeader.writeText(kInverse2x2);
3404             }
3405             return this->assembleSimpleIntrinsic("mat2_inverse", call);
3406 
3407         case 9:
3408             if (!fWrittenInverse3) {
3409                 fWrittenInverse3 = true;
3410                 fHeader.writeText(kInverse3x3);
3411             }
3412             return this->assembleSimpleIntrinsic("mat3_inverse", call);
3413 
3414         case 16:
3415             if (!fWrittenInverse4) {
3416                 fWrittenInverse4 = true;
3417                 fHeader.writeText(kInverse4x4);
3418             }
3419             return this->assembleSimpleIntrinsic("mat4_inverse", call);
3420 
3421         default:
3422             // We only support square matrices.
3423             SkUNREACHABLE;
3424     }
3425 }
3426 
assembleFunctionCall(const FunctionCall & call,Precedence parentPrecedence)3427 std::string WGSLCodeGenerator::assembleFunctionCall(const FunctionCall& call,
3428                                                     Precedence parentPrecedence) {
3429     const FunctionDeclaration& func = call.function();
3430     std::string result;
3431 
3432     // Many intrinsics need to be rewritten in WGSL.
3433     if (func.isIntrinsic()) {
3434         return this->assembleIntrinsicCall(call, func.intrinsicKind(), parentPrecedence);
3435     }
3436 
3437     // We implement function out-parameters by declaring them as pointers. SkSL follows GLSL's
3438     // out-parameter semantics, in which out-parameters are only written back to the original
3439     // variable after the function's execution is complete (see
3440     // https://www.khronos.org/opengl/wiki/Core_Language_(GLSL)#Parameters).
3441     //
3442     // In addition, SkSL supports swizzles and array index expressions to be passed into
3443     // out-parameters; however, WGSL does not allow taking their address into a pointer.
3444     //
3445     // We support these by using LValues to create temporary copies and then pass pointers to the
3446     // copies. Once the function returns, we copy the values back to the LValue.
3447 
3448     // First detect which arguments are passed to out-parameters.
3449     // TODO: rewrite this method in terms of LValues.
3450     const ExpressionArray& args = call.arguments();
3451     SkSpan<Variable* const> params = func.parameters();
3452     SkASSERT(SkToSizeT(args.size()) == params.size());
3453 
3454     STArray<16, std::unique_ptr<LValue>> writeback;
3455     STArray<16, std::string> substituteArgument;
3456     writeback.reserve_exact(args.size());
3457     substituteArgument.reserve_exact(args.size());
3458 
3459     for (int index = 0; index < args.size(); ++index) {
3460         if (params[index]->modifierFlags() & ModifierFlag::kOut) {
3461             std::unique_ptr<LValue> lvalue = this->makeLValue(*args[index]);
3462             if (params[index]->modifierFlags() & ModifierFlag::kIn) {
3463                 // Load the lvalue's contents into the substitute argument.
3464                 substituteArgument.push_back(this->writeScratchVar(args[index]->type(),
3465                                                                    lvalue->load()));
3466             } else {
3467                 // Create a substitute argument, but leave it uninitialized.
3468                 substituteArgument.push_back(this->writeScratchVar(args[index]->type()));
3469             }
3470             writeback.push_back(std::move(lvalue));
3471         } else {
3472             substituteArgument.push_back(std::string());
3473             writeback.push_back(nullptr);
3474         }
3475     }
3476 
3477     std::string expr = this->assembleName(func.mangledName());
3478     expr.push_back('(');
3479     auto separator = SkSL::String::Separator();
3480 
3481     if (std::string funcDepArgs = this->functionDependencyArgs(func); !funcDepArgs.empty()) {
3482         expr += funcDepArgs;
3483         separator();
3484     }
3485 
3486     // Pass the function arguments, or any substitutes as needed.
3487     for (int index = 0; index < args.size(); ++index) {
3488         expr += separator();
3489         if (!substituteArgument[index].empty()) {
3490             // We need to take the address of the variable and pass it down as a pointer.
3491             expr += '&' + substituteArgument[index];
3492         } else if (args[index]->type().isSampler()) {
3493             // If the argument is a sampler, we need to pass the texture _and_ its associated
3494             // sampler. (Function parameter lists also convert sampler parameters into a matching
3495             // texture/sampler parameter pair.)
3496             expr += this->assembleExpression(*args[index], Precedence::kSequence);
3497             expr += kTextureSuffix;
3498             expr += ", ";
3499             expr += this->assembleExpression(*args[index], Precedence::kSequence);
3500             expr += kSamplerSuffix;
3501         } else {
3502             expr += this->assembleExpression(*args[index], Precedence::kSequence);
3503         }
3504     }
3505     expr += ')';
3506 
3507     if (call.type().isVoid()) {
3508         // Making function calls that result in `void` is only valid in on the left side of a
3509         // comma-sequence, or in a top-level statement. Emit the function call as a top-level
3510         // statement and return an empty string, as the result will not be used.
3511         SkASSERT(parentPrecedence >= Precedence::kSequence);
3512         this->write(expr);
3513         this->writeLine(";");
3514     } else {
3515         result = this->writeScratchLet(expr);
3516     }
3517 
3518     // Write the substitute arguments back into their lvalues.
3519     for (int index = 0; index < args.size(); ++index) {
3520         if (!substituteArgument[index].empty()) {
3521             this->writeLine(writeback[index]->store(substituteArgument[index]));
3522         }
3523     }
3524 
3525     // Return the result of invoking the function.
3526     return result;
3527 }
3528 
assembleIndexExpression(const IndexExpression & i)3529 std::string WGSLCodeGenerator::assembleIndexExpression(const IndexExpression& i) {
3530     // Put the index value into a let-expression.
3531     std::string idx = this->writeNontrivialScratchLet(*i.index(), Precedence::kExpression);
3532     return this->assembleExpression(*i.base(), Precedence::kPostfix) + "[" + idx + "]";
3533 }
3534 
assembleLiteral(const Literal & l)3535 std::string WGSLCodeGenerator::assembleLiteral(const Literal& l) {
3536     const Type& type = l.type();
3537     if (type.isFloat() || type.isBoolean()) {
3538         return l.description(OperatorPrecedence::kExpression);
3539     }
3540     SkASSERT(type.isInteger());
3541     if (type.matches(*fContext.fTypes.fUInt)) {
3542         return std::to_string(l.intValue() & 0xffffffff) + "u";
3543     } else if (type.matches(*fContext.fTypes.fUShort)) {
3544         return std::to_string(l.intValue() & 0xffff) + "u";
3545     } else {
3546         return std::to_string(l.intValue());
3547     }
3548 }
3549 
assembleIncrementExpr(const Type & type)3550 std::string WGSLCodeGenerator::assembleIncrementExpr(const Type& type) {
3551     // `type(`
3552     std::string expr = to_wgsl_type(fContext, type);
3553     expr.push_back('(');
3554 
3555     // `1, 1, 1...)`
3556     auto separator = SkSL::String::Separator();
3557     for (int slots = type.slotCount(); slots > 0; --slots) {
3558         expr += separator();
3559         expr += "1";
3560     }
3561     expr.push_back(')');
3562     return expr;
3563 }
3564 
assemblePrefixExpression(const PrefixExpression & p,Precedence parentPrecedence)3565 std::string WGSLCodeGenerator::assemblePrefixExpression(const PrefixExpression& p,
3566                                                         Precedence parentPrecedence) {
3567     // Unary + does nothing, so we omit it from the output.
3568     Operator op = p.getOperator();
3569     if (op.kind() == Operator::Kind::PLUS) {
3570         return this->assembleExpression(*p.operand(), Precedence::kPrefix);
3571     }
3572 
3573     // Pre-increment/decrement expressions have no direct equivalent in WGSL.
3574     if (op.kind() == Operator::Kind::PLUSPLUS || op.kind() == Operator::Kind::MINUSMINUS) {
3575         std::unique_ptr<LValue> lvalue = this->makeLValue(*p.operand());
3576         if (!lvalue) {
3577             return "";
3578         }
3579 
3580         // Generate the new value: `lvalue + type(1, 1, 1...)`.
3581         std::string newValue =
3582                 lvalue->load() +
3583                 (p.getOperator().kind() == Operator::Kind::PLUSPLUS ? " + " : " - ") +
3584                 this->assembleIncrementExpr(p.operand()->type());
3585         this->writeLine(lvalue->store(newValue));
3586         return lvalue->load();
3587     }
3588 
3589     // WGSL natively supports unary negation/not expressions (!,~,-).
3590     SkASSERT(op.kind() == OperatorKind::LOGICALNOT ||
3591              op.kind() == OperatorKind::BITWISENOT ||
3592              op.kind() == OperatorKind::MINUS);
3593 
3594     // The unary negation operator only applies to scalars and vectors. For other mathematical
3595     // objects (such as matrices) we can express it as a multiplication by -1.
3596     std::string expr;
3597     const bool needsNegation = op.kind() == Operator::Kind::MINUS &&
3598                                !p.operand()->type().isScalar() && !p.operand()->type().isVector();
3599     const bool needParens = needsNegation || Precedence::kPrefix >= parentPrecedence;
3600 
3601     if (needParens) {
3602         expr.push_back('(');
3603     }
3604 
3605     if (needsNegation) {
3606         expr += "-1.0 * ";
3607         expr += this->assembleExpression(*p.operand(), Precedence::kMultiplicative);
3608     } else {
3609         expr += p.getOperator().tightOperatorName();
3610         expr += this->assembleExpression(*p.operand(), Precedence::kPrefix);
3611     }
3612 
3613     if (needParens) {
3614         expr.push_back(')');
3615     }
3616 
3617     return expr;
3618 }
3619 
assemblePostfixExpression(const PostfixExpression & p,Precedence parentPrecedence)3620 std::string WGSLCodeGenerator::assemblePostfixExpression(const PostfixExpression& p,
3621                                                          Precedence parentPrecedence) {
3622     SkASSERT(p.getOperator().kind() == Operator::Kind::PLUSPLUS ||
3623              p.getOperator().kind() == Operator::Kind::MINUSMINUS);
3624 
3625     // Post-increment/decrement expressions have no direct equivalent in WGSL; they do exist as a
3626     // standalone statement for convenience, but these aren't the same as SkSL's post-increments.
3627     std::unique_ptr<LValue> lvalue = this->makeLValue(*p.operand());
3628     if (!lvalue) {
3629         return "";
3630     }
3631 
3632     // If the expression is used, create a let-copy of the original value.
3633     // (At statement-level precedence, we know the value is unused and can skip this let-copy.)
3634     std::string originalValue;
3635     if (parentPrecedence != Precedence::kStatement) {
3636         originalValue = this->writeScratchLet(lvalue->load());
3637     }
3638     // Generate the new value: `lvalue + type(1, 1, 1...)`.
3639     std::string newValue = lvalue->load() +
3640                            (p.getOperator().kind() == Operator::Kind::PLUSPLUS ? " + " : " - ") +
3641                            this->assembleIncrementExpr(p.operand()->type());
3642     this->writeLine(lvalue->store(newValue));
3643 
3644     return originalValue;
3645 }
3646 
assembleSwizzle(const Swizzle & swizzle)3647 std::string WGSLCodeGenerator::assembleSwizzle(const Swizzle& swizzle) {
3648     return this->assembleExpression(*swizzle.base(), Precedence::kPostfix) + "." +
3649            Swizzle::MaskString(swizzle.components());
3650 }
3651 
writeScratchVar(const Type & type,const std::string & value)3652 std::string WGSLCodeGenerator::writeScratchVar(const Type& type, const std::string& value) {
3653     std::string scratchVarName = "_skTemp" + std::to_string(fScratchCount++);
3654     this->write("var ");
3655     this->write(scratchVarName);
3656     this->write(": ");
3657     this->write(to_wgsl_type(fContext, type));
3658     if (!value.empty()) {
3659         this->write(" = ");
3660         this->write(value);
3661     }
3662     this->writeLine(";");
3663     return scratchVarName;
3664 }
3665 
writeScratchLet(const std::string & expr)3666 std::string WGSLCodeGenerator::writeScratchLet(const std::string& expr) {
3667     std::string scratchVarName = "_skTemp" + std::to_string(fScratchCount++);
3668     this->write(fAtFunctionScope ? "let " : "const ");
3669     this->write(scratchVarName);
3670     this->write(" = ");
3671     this->write(expr);
3672     this->writeLine(";");
3673     return scratchVarName;
3674 }
3675 
writeScratchLet(const Expression & expr,Precedence parentPrecedence)3676 std::string WGSLCodeGenerator::writeScratchLet(const Expression& expr,
3677                                                Precedence parentPrecedence) {
3678     return this->writeScratchLet(this->assembleExpression(expr, parentPrecedence));
3679 }
3680 
writeNontrivialScratchLet(const Expression & expr,Precedence parentPrecedence)3681 std::string WGSLCodeGenerator::writeNontrivialScratchLet(const Expression& expr,
3682                                                          Precedence parentPrecedence) {
3683     std::string result = this->assembleExpression(expr, parentPrecedence);
3684     return is_nontrivial_expression(expr) ? this->writeScratchLet(result)
3685                                           : result;
3686 }
3687 
assembleTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)3688 std::string WGSLCodeGenerator::assembleTernaryExpression(const TernaryExpression& t,
3689                                                          Precedence parentPrecedence) {
3690     std::string expr;
3691 
3692     // The trivial case is when neither branch has side effects and evaluate to a scalar or vector
3693     // type. This can be represented with a call to the WGSL `select` intrinsic. Select doesn't
3694     // support short-circuiting, so we should only use it when both the true- and false-expressions
3695     // are trivial to evaluate.
3696     if ((t.type().isScalar() || t.type().isVector()) &&
3697         !Analysis::HasSideEffects(*t.test()) &&
3698         Analysis::IsTrivialExpression(*t.ifTrue()) &&
3699         Analysis::IsTrivialExpression(*t.ifFalse())) {
3700 
3701         bool needParens = Precedence::kTernary >= parentPrecedence;
3702         if (needParens) {
3703             expr.push_back('(');
3704         }
3705         expr += "select(";
3706         expr += this->assembleExpression(*t.ifFalse(), Precedence::kSequence);
3707         expr += ", ";
3708         expr += this->assembleExpression(*t.ifTrue(), Precedence::kSequence);
3709         expr += ", ";
3710 
3711         bool isVector = t.type().isVector();
3712         if (isVector) {
3713             // Splat the condition expression into a vector.
3714             expr += String::printf("vec%d<bool>(", t.type().columns());
3715         }
3716         expr += this->assembleExpression(*t.test(), Precedence::kSequence);
3717         if (isVector) {
3718             expr.push_back(')');
3719         }
3720         expr.push_back(')');
3721         if (needParens) {
3722             expr.push_back(')');
3723         }
3724     } else {
3725         // WGSL does not support ternary expressions. Instead, we hoist the expression out into the
3726         // surrounding block, convert it into an if statement, and write the result to a synthesized
3727         // variable. Instead of the original expression, we return that variable.
3728         expr = this->writeScratchVar(t.ifTrue()->type());
3729 
3730         std::string testExpr = this->assembleExpression(*t.test(), Precedence::kExpression);
3731         this->write("if ");
3732         this->write(testExpr);
3733         this->writeLine(" {");
3734 
3735         ++fIndentation;
3736         std::string trueExpr = this->assembleExpression(*t.ifTrue(), Precedence::kAssignment);
3737         this->write(expr);
3738         this->write(" = ");
3739         this->write(trueExpr);
3740         this->writeLine(";");
3741         --fIndentation;
3742 
3743         this->writeLine("} else {");
3744 
3745         ++fIndentation;
3746         std::string falseExpr = this->assembleExpression(*t.ifFalse(), Precedence::kAssignment);
3747         this->write(expr);
3748         this->write(" = ");
3749         this->write(falseExpr);
3750         this->writeLine(";");
3751         --fIndentation;
3752 
3753         this->writeLine("}");
3754     }
3755     return expr;
3756 }
3757 
variablePrefix(const Variable & v)3758 std::string WGSLCodeGenerator::variablePrefix(const Variable& v) {
3759     if (v.storage() == Variable::Storage::kGlobal) {
3760         // If the field refers to a pipeline IO parameter, then we access it via the synthesized IO
3761         // structs. We make an explicit exception for `sk_PointSize` which we declare as a
3762         // placeholder variable in global scope as it is not supported by WebGPU as a pipeline IO
3763         // parameter (see comments in `writeStageOutputStruct`).
3764         if (v.modifierFlags() & ModifierFlag::kIn) {
3765             return "_stageIn.";
3766         }
3767         if (v.modifierFlags() & ModifierFlag::kOut) {
3768             return "(*_stageOut).";
3769         }
3770 
3771         // If the field refers to an anonymous-interface-block structure, access it via the
3772         // synthesized `_uniform0` or `_storage1` global.
3773         if (const InterfaceBlock* ib = v.interfaceBlock()) {
3774             const Type& ibType = ib->var()->type().componentType();
3775             if (const std::string* ibName = fInterfaceBlockNameMap.find(&ibType)) {
3776                 return *ibName + '.';
3777             }
3778         }
3779 
3780         // If the field refers to an top-level uniform, access it via the synthesized
3781         // `_globalUniforms` global. (Note that this should only occur in test code; Skia will
3782         // always put uniforms in an interface block.)
3783         if (is_in_global_uniforms(v)) {
3784             return "_globalUniforms.";
3785         }
3786     }
3787 
3788     return "";
3789 }
3790 
variableReferenceNameForLValue(const VariableReference & r)3791 std::string WGSLCodeGenerator::variableReferenceNameForLValue(const VariableReference& r) {
3792     const Variable& v = *r.variable();
3793 
3794     if ((v.storage() == Variable::Storage::kParameter &&
3795          v.modifierFlags() & ModifierFlag::kOut)) {
3796         // This is an out-parameter; it's pointer-typed, so we need to dereference it. We wrap the
3797         // dereference in parentheses, in case the value is used in an access expression later.
3798         return "(*" + this->assembleName(v.mangledName()) + ')';
3799     }
3800 
3801     return this->variablePrefix(v) + this->assembleName(v.mangledName());
3802 }
3803 
assembleVariableReference(const VariableReference & r)3804 std::string WGSLCodeGenerator::assembleVariableReference(const VariableReference& r) {
3805     // TODO(b/294274678): Correctly handle RTFlip for built-ins.
3806     const Variable& v = *r.variable();
3807 
3808     // Insert a conversion expression if this is a built-in variable whose type differs from the
3809     // SkSL.
3810     std::string expr;
3811     std::optional<std::string_view> conversion = needs_builtin_type_conversion(v);
3812     if (conversion.has_value()) {
3813         expr += *conversion;
3814         expr.push_back('(');
3815     }
3816 
3817     expr += this->variableReferenceNameForLValue(r);
3818 
3819     if (conversion.has_value()) {
3820         expr.push_back(')');
3821     }
3822 
3823     return expr;
3824 }
3825 
assembleAnyConstructor(const AnyConstructor & c)3826 std::string WGSLCodeGenerator::assembleAnyConstructor(const AnyConstructor& c) {
3827     std::string expr = to_wgsl_type(fContext, c.type());
3828     expr.push_back('(');
3829     auto separator = SkSL::String::Separator();
3830     for (const auto& e : c.argumentSpan()) {
3831         expr += separator();
3832         expr += this->assembleExpression(*e, Precedence::kSequence);
3833     }
3834     expr.push_back(')');
3835     return expr;
3836 }
3837 
assembleConstructorCompound(const ConstructorCompound & c)3838 std::string WGSLCodeGenerator::assembleConstructorCompound(const ConstructorCompound& c) {
3839     if (c.type().isVector()) {
3840         return this->assembleConstructorCompoundVector(c);
3841     } else if (c.type().isMatrix()) {
3842         return this->assembleConstructorCompoundMatrix(c);
3843     } else {
3844         fContext.fErrors->error(c.fPosition, "unsupported compound constructor");
3845         return {};
3846     }
3847 }
3848 
assembleConstructorCompoundVector(const ConstructorCompound & c)3849 std::string WGSLCodeGenerator::assembleConstructorCompoundVector(const ConstructorCompound& c) {
3850     // WGSL supports constructing vectors from a mix of scalars and vectors but
3851     // not matrices (see https://www.w3.org/TR/WGSL/#type-constructor-expr).
3852     //
3853     // SkSL supports vec4(mat2x2) which we handle specially.
3854     if (c.type().columns() == 4 && c.argumentSpan().size() == 1) {
3855         const Expression& arg = *c.argumentSpan().front();
3856         if (arg.type().isMatrix()) {
3857             SkASSERT(arg.type().columns() == 2);
3858             SkASSERT(arg.type().rows() == 2);
3859 
3860             std::string matrix = this->writeNontrivialScratchLet(arg, Precedence::kPostfix);
3861             return String::printf("%s(%s[0], %s[1])", to_wgsl_type(fContext, c.type()).c_str(),
3862                                                       matrix.c_str(),
3863                                                       matrix.c_str());
3864         }
3865     }
3866     return this->assembleAnyConstructor(c);
3867 }
3868 
assembleConstructorCompoundMatrix(const ConstructorCompound & ctor)3869 std::string WGSLCodeGenerator::assembleConstructorCompoundMatrix(const ConstructorCompound& ctor) {
3870     SkASSERT(ctor.type().isMatrix());
3871 
3872     std::string expr = to_wgsl_type(fContext, ctor.type()) + '(';
3873     auto separator = String::Separator();
3874     for (const std::unique_ptr<Expression>& arg : ctor.arguments()) {
3875         SkASSERT(arg->type().isScalar() || arg->type().isVector());
3876 
3877         if (arg->type().isScalar()) {
3878             expr += separator();
3879             expr += this->assembleExpression(*arg, Precedence::kSequence);
3880         } else {
3881             std::string inner = this->writeNontrivialScratchLet(*arg, Precedence::kSequence);
3882             int numSlots = arg->type().slotCount();
3883             for (int slot = 0; slot < numSlots; ++slot) {
3884                 String::appendf(&expr, "%s%s[%d]", separator().c_str(), inner.c_str(), slot);
3885             }
3886         }
3887     }
3888     return expr + ')';
3889 }
3890 
assembleConstructorDiagonalMatrix(const ConstructorDiagonalMatrix & c)3891 std::string WGSLCodeGenerator::assembleConstructorDiagonalMatrix(
3892         const ConstructorDiagonalMatrix& c) {
3893     const Type& type = c.type();
3894     SkASSERT(type.isMatrix());
3895     SkASSERT(c.argument()->type().isScalar());
3896 
3897     // Evaluate the inner-expression, creating a scratch variable if necessary.
3898     std::string inner = this->writeNontrivialScratchLet(*c.argument(), Precedence::kAssignment);
3899 
3900     // Assemble a diagonal-matrix expression.
3901     std::string expr = to_wgsl_type(fContext, type) + '(';
3902     auto separator = String::Separator();
3903     for (int col = 0; col < type.columns(); ++col) {
3904         for (int row = 0; row < type.rows(); ++row) {
3905             expr += separator();
3906             if (col == row) {
3907                 expr += inner;
3908             } else {
3909                 expr += "0.0";
3910             }
3911         }
3912     }
3913     return expr + ')';
3914 }
3915 
assembleConstructorMatrixResize(const ConstructorMatrixResize & ctor)3916 std::string WGSLCodeGenerator::assembleConstructorMatrixResize(
3917         const ConstructorMatrixResize& ctor) {
3918     std::string source = this->writeScratchLet(this->assembleExpression(*ctor.argument(),
3919                                                                         Precedence::kSequence));
3920     int columns = ctor.type().columns();
3921     int rows = ctor.type().rows();
3922     int sourceColumns = ctor.argument()->type().columns();
3923     int sourceRows = ctor.argument()->type().rows();
3924     auto separator = String::Separator();
3925     std::string expr = to_wgsl_type(fContext, ctor.type()) + '(';
3926 
3927     for (int c = 0; c < columns; ++c) {
3928         for (int r = 0; r < rows; ++r) {
3929             expr += separator();
3930             if (c < sourceColumns && r < sourceRows) {
3931                 String::appendf(&expr, "%s[%d][%d]", source.c_str(), c, r);
3932             } else if (r == c) {
3933                 expr += "1.0";
3934             } else {
3935                 expr += "0.0";
3936             }
3937         }
3938     }
3939 
3940     return expr + ')';
3941 }
3942 
assembleEqualityExpression(const Type & left,const std::string & leftName,const Type & right,const std::string & rightName,Operator op,Precedence parentPrecedence)3943 std::string WGSLCodeGenerator::assembleEqualityExpression(const Type& left,
3944                                                           const std::string& leftName,
3945                                                           const Type& right,
3946                                                           const std::string& rightName,
3947                                                           Operator op,
3948                                                           Precedence parentPrecedence) {
3949     SkASSERT(op.kind() == OperatorKind::EQEQ || op.kind() == OperatorKind::NEQ);
3950 
3951     std::string expr;
3952     bool isEqual = (op.kind() == Operator::Kind::EQEQ);
3953     const char* const combiner = isEqual ? " && " : " || ";
3954 
3955     if (left.isMatrix()) {
3956         // Each matrix column must be compared as if it were an individual vector.
3957         SkASSERT(right.isMatrix());
3958         SkASSERT(left.rows() == right.rows());
3959         SkASSERT(left.columns() == right.columns());
3960         int columns = left.columns();
3961         const Type& vecType = left.columnType(fContext);
3962         const char* separator = "(";
3963         for (int index = 0; index < columns; ++index) {
3964             expr += separator;
3965             std::string suffix = '[' + std::to_string(index) + ']';
3966             expr += this->assembleEqualityExpression(vecType, leftName + suffix,
3967                                                      vecType, rightName + suffix,
3968                                                      op, Precedence::kParentheses);
3969             separator = combiner;
3970         }
3971         return expr + ')';
3972     }
3973 
3974     if (left.isArray()) {
3975         SkASSERT(right.matches(left));
3976         const Type& indexedType = left.componentType();
3977         const char* separator = "(";
3978         for (int index = 0; index < left.columns(); ++index) {
3979             expr += separator;
3980             std::string suffix = '[' + std::to_string(index) + ']';
3981             expr += this->assembleEqualityExpression(indexedType, leftName + suffix,
3982                                                      indexedType, rightName + suffix,
3983                                                      op, Precedence::kParentheses);
3984             separator = combiner;
3985         }
3986         return expr + ')';
3987     }
3988 
3989     if (left.isStruct()) {
3990         // Recursively compare every field in the struct.
3991         SkASSERT(right.matches(left));
3992         SkSpan<const Field> fields = left.fields();
3993 
3994         const char* separator = "(";
3995         for (const Field& field : fields) {
3996             expr += separator;
3997             expr += this->assembleEqualityExpression(
3998                             *field.fType, leftName + '.' + this->assembleName(field.fName),
3999                             *field.fType, rightName + '.' + this->assembleName(field.fName),
4000                             op, Precedence::kParentheses);
4001             separator = combiner;
4002         }
4003         return expr + ')';
4004     }
4005 
4006     if (left.isVector()) {
4007         // Compare vectors via `all(x == y)` or `any(x != y)`.
4008         SkASSERT(right.isVector());
4009         SkASSERT(left.slotCount() == right.slotCount());
4010 
4011         expr += isEqual ? "all(" : "any(";
4012         expr += leftName;
4013         expr += operator_name(op);
4014         expr += rightName;
4015         return expr + ')';
4016     }
4017 
4018     // Compare scalars via `x == y`.
4019     SkASSERT(right.isScalar());
4020     if (parentPrecedence < Precedence::kSequence) {
4021         expr = '(';
4022     }
4023     expr += leftName;
4024     expr += operator_name(op);
4025     expr += rightName;
4026     if (parentPrecedence < Precedence::kSequence) {
4027         expr += ')';
4028     }
4029     return expr;
4030 }
4031 
assembleEqualityExpression(const Expression & left,const Expression & right,Operator op,Precedence parentPrecedence)4032 std::string WGSLCodeGenerator::assembleEqualityExpression(const Expression& left,
4033                                                           const Expression& right,
4034                                                           Operator op,
4035                                                           Precedence parentPrecedence) {
4036     std::string leftName, rightName;
4037     if (left.type().isScalar() || left.type().isVector()) {
4038         // WGSL supports scalar and vector comparisons natively. We know the expressions will only
4039         // be emitted once, so there isn't a benefit to creating a let-declaration.
4040         leftName = this->assembleExpression(left, Precedence::kParentheses);
4041         rightName = this->assembleExpression(right, Precedence::kParentheses);
4042     } else {
4043         leftName = this->writeNontrivialScratchLet(left, Precedence::kAssignment);
4044         rightName = this->writeNontrivialScratchLet(right, Precedence::kAssignment);
4045     }
4046     return this->assembleEqualityExpression(left.type(), leftName, right.type(), rightName,
4047                                             op, parentPrecedence);
4048 }
4049 
writeProgramElement(const ProgramElement & e)4050 void WGSLCodeGenerator::writeProgramElement(const ProgramElement& e) {
4051     switch (e.kind()) {
4052         case ProgramElement::Kind::kExtension:
4053             // TODO(skia:13092): WGSL supports extensions via the "enable" directive
4054             // (https://www.w3.org/TR/WGSL/#enable-extensions-sec ). While we could easily emit this
4055             // directive, we should first ensure that all possible SkSL extension names are
4056             // converted to their appropriate WGSL extension.
4057             break;
4058         case ProgramElement::Kind::kGlobalVar:
4059             this->writeGlobalVarDeclaration(e.as<GlobalVarDeclaration>());
4060             break;
4061         case ProgramElement::Kind::kInterfaceBlock:
4062             // All interface block declarations are handled explicitly as the "program header" in
4063             // generateCode().
4064             break;
4065         case ProgramElement::Kind::kStructDefinition:
4066             this->writeStructDefinition(e.as<StructDefinition>());
4067             break;
4068         case ProgramElement::Kind::kFunctionPrototype:
4069             // A WGSL function declaration must contain its body and the function name is in scope
4070             // for the entire program (see https://www.w3.org/TR/WGSL/#function-declaration and
4071             // https://www.w3.org/TR/WGSL/#declaration-and-scope).
4072             //
4073             // As such, we don't emit function prototypes.
4074             break;
4075         case ProgramElement::Kind::kFunction:
4076             this->writeFunction(e.as<FunctionDefinition>());
4077             break;
4078         case ProgramElement::Kind::kModifiers:
4079             this->writeModifiersDeclaration(e.as<ModifiersDeclaration>());
4080             break;
4081         default:
4082             SkDEBUGFAILF("unsupported program element: %s\n", e.description().c_str());
4083             break;
4084     }
4085 }
4086 
writeTextureOrSampler(const Variable & var,int bindingLocation,std::string_view suffix,std::string_view wgslType)4087 void WGSLCodeGenerator::writeTextureOrSampler(const Variable& var,
4088                                               int bindingLocation,
4089                                               std::string_view suffix,
4090                                               std::string_view wgslType) {
4091     if (var.type().dimensions() != SpvDim2D) {
4092         // Skia currently only uses 2D textures.
4093         fContext.fErrors->error(var.varDeclaration()->position(), "unsupported texture dimensions");
4094         return;
4095     }
4096 
4097     this->write("@group(");
4098     this->write(std::to_string(std::max(0, var.layout().fSet)));
4099     this->write(") @binding(");
4100     this->write(std::to_string(bindingLocation));
4101     this->write(") var ");
4102     this->write(this->assembleName(var.mangledName()));
4103     this->write(suffix);
4104     this->write(": ");
4105     this->write(wgslType);
4106     this->writeLine(";");
4107 }
4108 
writeGlobalVarDeclaration(const GlobalVarDeclaration & d)4109 void WGSLCodeGenerator::writeGlobalVarDeclaration(const GlobalVarDeclaration& d) {
4110     const VarDeclaration& decl = d.varDeclaration();
4111     const Variable& var = *decl.var();
4112     if ((var.modifierFlags() & (ModifierFlag::kIn | ModifierFlag::kOut)) ||
4113         is_in_global_uniforms(var)) {
4114         // Pipeline stage I/O parameters and top-level (non-block) uniforms are handled specially
4115         // in generateCode().
4116         return;
4117     }
4118 
4119     const Type::TypeKind varKind = var.type().typeKind();
4120     if (varKind == Type::TypeKind::kSampler) {
4121         // If the sampler binding was unassigned, provide a scratch value; this will make
4122         // golden-output tests pass, but will not actually be usable for drawing.
4123         int samplerLocation = var.layout().fSampler >= 0 ? var.layout().fSampler
4124                                                          : 10000 + fScratchCount++;
4125         this->writeTextureOrSampler(var, samplerLocation, kSamplerSuffix, "sampler");
4126 
4127         // If the texture binding was unassigned, provide a scratch value (for golden-output tests).
4128         int textureLocation = var.layout().fTexture >= 0 ? var.layout().fTexture
4129                                                          : 10000 + fScratchCount++;
4130         this->writeTextureOrSampler(var, textureLocation, kTextureSuffix, "texture_2d<f32>");
4131         return;
4132     }
4133 
4134     if (varKind == Type::TypeKind::kTexture) {
4135         // If a binding location was unassigned, provide a scratch value (for golden-output tests).
4136         int textureLocation = var.layout().fBinding >= 0 ? var.layout().fBinding
4137                                                          : 10000 + fScratchCount++;
4138         // For a texture without an associated sampler, we don't apply a suffix.
4139         this->writeTextureOrSampler(var, textureLocation, /*suffix=*/"",
4140                                     to_wgsl_type(fContext, var.type(), &var.layout()));
4141         return;
4142     }
4143 
4144     std::string initializer;
4145     if (decl.value()) {
4146         // We assume here that the initial-value expression will not emit any helper statements.
4147         // Initial-value expressions are required to pass IsConstantExpression, which limits the
4148         // blast radius to constructors, literals, and other constant values/variables.
4149         initializer += " = ";
4150         initializer += this->assembleExpression(*decl.value(), Precedence::kAssignment);
4151     }
4152 
4153     if (var.modifierFlags().isConst()) {
4154         this->write("const ");
4155     } else if (var.modifierFlags().isWorkgroup()) {
4156         this->write("var<workgroup> ");
4157     } else if (var.modifierFlags().isPixelLocal()) {
4158         this->write("var<pixel_local> ");
4159     } else {
4160         this->write("var<private> ");
4161     }
4162     this->write(this->assembleName(var.mangledName()));
4163     this->write(": " + to_wgsl_type(fContext, var.type(), &var.layout()));
4164     this->write(initializer);
4165     this->writeLine(";");
4166 }
4167 
writeStructDefinition(const StructDefinition & s)4168 void WGSLCodeGenerator::writeStructDefinition(const StructDefinition& s) {
4169     const Type& type = s.type();
4170     this->writeLine("struct " + type.displayName() + " {");
4171     this->writeFields(type.fields(), /*memoryLayout=*/nullptr);
4172     this->writeLine("};");
4173 }
4174 
writeModifiersDeclaration(const ModifiersDeclaration & modifiers)4175 void WGSLCodeGenerator::writeModifiersDeclaration(const ModifiersDeclaration& modifiers) {
4176     LayoutFlags flags = modifiers.layout().fFlags;
4177     flags &= ~(LayoutFlag::kLocalSizeX | LayoutFlag::kLocalSizeY | LayoutFlag::kLocalSizeZ);
4178     if (flags != LayoutFlag::kNone) {
4179         fContext.fErrors->error(modifiers.position(), "unsupported declaration");
4180         return;
4181     }
4182 
4183     if (modifiers.layout().fLocalSizeX >= 0) {
4184         fLocalSizeX = modifiers.layout().fLocalSizeX;
4185     }
4186     if (modifiers.layout().fLocalSizeY >= 0) {
4187         fLocalSizeY = modifiers.layout().fLocalSizeY;
4188     }
4189     if (modifiers.layout().fLocalSizeZ >= 0) {
4190         fLocalSizeZ = modifiers.layout().fLocalSizeZ;
4191     }
4192 }
4193 
writeFields(SkSpan<const Field> fields,const MemoryLayout * memoryLayout)4194 void WGSLCodeGenerator::writeFields(SkSpan<const Field> fields, const MemoryLayout* memoryLayout) {
4195     fIndentation++;
4196 
4197     // TODO(skia:14370): array uniforms may need manual fixup for std140 padding. (Those uniforms
4198     // will also need special handling when they are accessed, or passed to functions.)
4199     for (size_t index = 0; index < fields.size(); ++index) {
4200         const Field& field = fields[index];
4201         if (memoryLayout && !memoryLayout->isSupported(*field.fType)) {
4202             // Reject types that aren't supported by the memory layout.
4203             fContext.fErrors->error(field.fPosition, "type '" + std::string(field.fType->name()) +
4204                                                      "' is not permitted here");
4205             return;
4206         }
4207 
4208         // Prepend @size(n) to enforce the offsets from the SkSL layout. (This is effectively
4209         // a gadget that we can use to insert padding between elements.)
4210         if (index < fields.size() - 1) {
4211             int thisFieldOffset = field.fLayout.fOffset;
4212             int nextFieldOffset = fields[index + 1].fLayout.fOffset;
4213             if (index == 0 && thisFieldOffset > 0) {
4214                 fContext.fErrors->error(field.fPosition, "field must have an offset of zero");
4215                 return;
4216             }
4217             if (thisFieldOffset >= 0 && nextFieldOffset > thisFieldOffset) {
4218                 this->write("@size(");
4219                 this->write(std::to_string(nextFieldOffset - thisFieldOffset));
4220                 this->write(") ");
4221             }
4222         }
4223 
4224         this->write(this->assembleName(field.fName));
4225         this->write(": ");
4226         if (const FieldPolyfillInfo* info = fFieldPolyfillMap.find(&field)) {
4227             if (info->fIsArray) {
4228                 // This properly handles arrays of matrices, as well as arrays of other primitives.
4229                 SkASSERT(field.fType->isArray());
4230                 this->write("array<_skArrayElement_");
4231                 this->write(field.fType->abbreviatedName());
4232                 this->write(", ");
4233                 this->write(std::to_string(field.fType->columns()));
4234                 this->write(">");
4235             } else if (info->fIsMatrix) {
4236                 this->write("_skMatrix");
4237                 this->write(std::to_string(field.fType->columns()));
4238                 this->write(std::to_string(field.fType->rows()));
4239             } else {
4240                 SkDEBUGFAILF("need polyfill for %s", info->fReplacementName.c_str());
4241             }
4242         } else {
4243             this->write(to_wgsl_type(fContext, *field.fType, &field.fLayout));
4244         }
4245         this->writeLine(",");
4246     }
4247 
4248     fIndentation--;
4249 }
4250 
writeEnables()4251 void WGSLCodeGenerator::writeEnables() {
4252     this->writeLine("diagnostic(off, derivative_uniformity);");
4253     this->writeLine("diagnostic(off, chromium.unreachable_code);");
4254 
4255     if (fRequirements.fPixelLocalExtension) {
4256         this->writeLine("enable chromium_experimental_pixel_local;");
4257     }
4258     if (fProgram.fInterface.fUseLastFragColor) {
4259         this->writeLine("enable chromium_experimental_framebuffer_fetch;");
4260     }
4261     if (fProgram.fInterface.fOutputSecondaryColor) {
4262         this->writeLine("enable dual_source_blending;");
4263     }
4264 }
4265 
needsStageInputStruct() const4266 bool WGSLCodeGenerator::needsStageInputStruct() const {
4267     // It is illegal to declare a struct with no members; we can't emit a placeholder empty stage
4268     // input struct.
4269     return !fPipelineInputs.empty();
4270 }
4271 
writeStageInputStruct()4272 void WGSLCodeGenerator::writeStageInputStruct() {
4273     if (!this->needsStageInputStruct()) {
4274         return;
4275     }
4276 
4277     std::string_view structNamePrefix = pipeline_struct_prefix(fProgram.fConfig->fKind);
4278     SkASSERT(!structNamePrefix.empty());
4279 
4280     this->write("struct ");
4281     this->write(structNamePrefix);
4282     this->writeLine("In {");
4283     fIndentation++;
4284 
4285     for (const Variable* v : fPipelineInputs) {
4286         if (v->type().isInterfaceBlock()) {
4287             for (const Field& f : v->type().fields()) {
4288                 this->writePipelineIODeclaration(f.fLayout, *f.fType, f.fName, Delimiter::kComma);
4289             }
4290         } else {
4291             this->writePipelineIODeclaration(v->layout(), v->type(), v->mangledName(),
4292                                              Delimiter::kComma);
4293         }
4294     }
4295 
4296     fIndentation--;
4297     this->writeLine("};");
4298 }
4299 
needsStageOutputStruct() const4300 bool WGSLCodeGenerator::needsStageOutputStruct() const {
4301     // It is illegal to declare a struct with no members. However, vertex programs will _always_
4302     // have an output stage in WGSL, because the spec requires them to emit `@builtin(position)`.
4303     // So we always synthesize a reference to `sk_Position` even if the program doesn't need it.
4304     return !fPipelineOutputs.empty() || ProgramConfig::IsVertex(fProgram.fConfig->fKind);
4305 }
4306 
writeStageOutputStruct()4307 void WGSLCodeGenerator::writeStageOutputStruct() {
4308     if (!this->needsStageOutputStruct()) {
4309         return;
4310     }
4311 
4312     std::string_view structNamePrefix = pipeline_struct_prefix(fProgram.fConfig->fKind);
4313     SkASSERT(!structNamePrefix.empty());
4314 
4315     this->write("struct ");
4316     this->write(structNamePrefix);
4317     this->writeLine("Out {");
4318     fIndentation++;
4319 
4320     bool declaredPositionBuiltin = false;
4321     bool requiresPointSizeBuiltin = false;
4322     for (const Variable* v : fPipelineOutputs) {
4323         if (v->type().isInterfaceBlock()) {
4324             for (const auto& f : v->type().fields()) {
4325                 this->writePipelineIODeclaration(f.fLayout, *f.fType, f.fName, Delimiter::kComma);
4326                 if (f.fLayout.fBuiltin == SK_POSITION_BUILTIN) {
4327                     declaredPositionBuiltin = true;
4328                 } else if (f.fLayout.fBuiltin == SK_POINTSIZE_BUILTIN) {
4329                     // sk_PointSize is explicitly not supported by `builtin_from_sksl_name` so
4330                     // writePipelineIODeclaration will never write it. We mark it here if the
4331                     // declaration is needed so we can synthesize it below.
4332                     requiresPointSizeBuiltin = true;
4333                 }
4334             }
4335         } else {
4336             this->writePipelineIODeclaration(v->layout(), v->type(), v->mangledName(),
4337                                              Delimiter::kComma);
4338         }
4339     }
4340 
4341     // A vertex program must include the `position` builtin in its entrypoint's return type.
4342     const bool positionBuiltinRequired = ProgramConfig::IsVertex(fProgram.fConfig->fKind);
4343     if (positionBuiltinRequired && !declaredPositionBuiltin) {
4344         this->writeLine("@builtin(position) sk_Position: vec4<f32>,");
4345     }
4346 
4347     fIndentation--;
4348     this->writeLine("};");
4349 
4350     // In WebGPU/WGSL, the vertex stage does not support a point-size output and the size
4351     // of a point primitive is always 1 pixel (see https://github.com/gpuweb/gpuweb/issues/332).
4352     //
4353     // There isn't anything we can do to emulate this correctly at this stage so we synthesize a
4354     // placeholder global variable that has no effect. Programs should not rely on sk_PointSize when
4355     // using the Dawn backend.
4356     if (ProgramConfig::IsVertex(fProgram.fConfig->fKind) && requiresPointSizeBuiltin) {
4357         this->writeLine("/* unsupported */ var<private> sk_PointSize: f32;");
4358     }
4359 }
4360 
prepareUniformPolyfillsForInterfaceBlock(const InterfaceBlock * interfaceBlock,std::string_view instanceName,MemoryLayout::Standard nativeLayout)4361 void WGSLCodeGenerator::prepareUniformPolyfillsForInterfaceBlock(
4362         const InterfaceBlock* interfaceBlock,
4363         std::string_view instanceName,
4364         MemoryLayout::Standard nativeLayout) {
4365     SkSL::MemoryLayout std140(MemoryLayout::Standard::k140);
4366     SkSL::MemoryLayout native(nativeLayout);
4367 
4368     const Type& structType = interfaceBlock->var()->type().componentType();
4369     for (const Field& field : structType.fields()) {
4370         const Type* type = field.fType;
4371         bool needsArrayPolyfill = false;
4372         bool needsMatrixPolyfill = false;
4373 
4374         auto isPolyfillableMatrixType = [&](const Type* type) {
4375             return type->isMatrix() && std140.stride(*type) != native.stride(*type);
4376         };
4377 
4378         if (isPolyfillableMatrixType(type)) {
4379             // Matrices will be represented as 16-byte aligned arrays in std140, and reconstituted
4380             // into proper matrices as they are later accessed. We need to synthesize polyfill.
4381             needsMatrixPolyfill = true;
4382         } else if (type->isArray() && !type->isUnsizedArray() &&
4383                    !type->componentType().isOpaque()) {
4384             const Type* innerType = &type->componentType();
4385             if (isPolyfillableMatrixType(innerType)) {
4386                 // Use a polyfill when the array contains a matrix that requires polyfill.
4387                 needsArrayPolyfill = true;
4388                 needsMatrixPolyfill = true;
4389             } else if (native.size(*innerType) < 16) {
4390                 // Use a polyfill when the array elements are smaller than 16 bytes, since std140
4391                 // will pad elements to a 16-byte stride.
4392                 needsArrayPolyfill = true;
4393             }
4394         }
4395 
4396         if (needsArrayPolyfill || needsMatrixPolyfill) {
4397             // Add a polyfill for this matrix type.
4398             FieldPolyfillInfo info;
4399             info.fInterfaceBlock = interfaceBlock;
4400             info.fReplacementName = "_skUnpacked_" + std::string(instanceName) + '_' +
4401                                     this->assembleName(field.fName);
4402             info.fIsArray = needsArrayPolyfill;
4403             info.fIsMatrix = needsMatrixPolyfill;
4404             fFieldPolyfillMap.set(&field, info);
4405         }
4406     }
4407 }
4408 
writeUniformsAndBuffers()4409 void WGSLCodeGenerator::writeUniformsAndBuffers() {
4410     for (const ProgramElement* e : fProgram.elements()) {
4411         // Iterate through the interface blocks.
4412         if (!e->is<InterfaceBlock>()) {
4413             continue;
4414         }
4415         const InterfaceBlock& ib = e->as<InterfaceBlock>();
4416 
4417         // Determine if this interface block holds uniforms, buffers, or something else (skip it).
4418         std::string_view addressSpace;
4419         std::string_view accessMode;
4420         MemoryLayout::Standard nativeLayout;
4421         if (ib.var()->modifierFlags().isUniform()) {
4422             addressSpace = "uniform";
4423             nativeLayout = MemoryLayout::Standard::kWGSLUniform_Base;
4424         } else if (ib.var()->modifierFlags().isBuffer()) {
4425             addressSpace = "storage";
4426             nativeLayout = MemoryLayout::Standard::kWGSLStorage_Base;
4427             accessMode = ib.var()->modifierFlags().isReadOnly() ? ", read" : ", read_write";
4428         } else {
4429             continue;
4430         }
4431 
4432         // If we have an anonymous interface block, assign a name like `_uniform0` or `_storage1`.
4433         std::string instanceName;
4434         if (ib.instanceName().empty()) {
4435             instanceName = "_" + std::string(addressSpace) + std::to_string(fScratchCount++);
4436             fInterfaceBlockNameMap[&ib.var()->type().componentType()] = instanceName;
4437         } else {
4438             instanceName = std::string(ib.instanceName());
4439         }
4440 
4441         this->prepareUniformPolyfillsForInterfaceBlock(&ib, instanceName, nativeLayout);
4442 
4443         // Create a struct to hold all of the fields from this InterfaceBlock.
4444         SkASSERT(!ib.typeName().empty());
4445         this->write("struct ");
4446         this->write(ib.typeName());
4447         this->writeLine(" {");
4448 
4449         // Find the struct type and fields used by this interface block.
4450         const Type& ibType = ib.var()->type().componentType();
4451         SkASSERT(ibType.isStruct());
4452 
4453         SkSpan<const Field> ibFields = ibType.fields();
4454         SkASSERT(!ibFields.empty());
4455 
4456         MemoryLayout layout(MemoryLayout::Standard::k140);
4457         this->writeFields(ibFields, &layout);
4458         this->writeLine("};");
4459         this->write("@group(");
4460         this->write(std::to_string(std::max(0, ib.var()->layout().fSet)));
4461         this->write(") @binding(");
4462         this->write(std::to_string(std::max(0, ib.var()->layout().fBinding)));
4463         this->write(") var<");
4464         this->write(addressSpace);
4465         this->write(accessMode);
4466         this->write("> ");
4467         this->write(instanceName);
4468         this->write(" : ");
4469         this->write(to_wgsl_type(fContext, ib.var()->type(), &ib.var()->layout()));
4470         this->writeLine(";");
4471     }
4472 }
4473 
writeNonBlockUniformsForTests()4474 void WGSLCodeGenerator::writeNonBlockUniformsForTests() {
4475     bool declaredUniformsStruct = false;
4476 
4477     for (const ProgramElement* e : fProgram.elements()) {
4478         if (e->is<GlobalVarDeclaration>()) {
4479             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
4480             const Variable& var = *decls.varDeclaration().var();
4481             if (is_in_global_uniforms(var)) {
4482                 if (!declaredUniformsStruct) {
4483                     this->write("struct _GlobalUniforms {\n");
4484                     declaredUniformsStruct = true;
4485                 }
4486                 this->write("  ");
4487                 this->writeVariableDecl(var.layout(), var.type(), var.mangledName(),
4488                                         Delimiter::kComma);
4489             }
4490         }
4491     }
4492     if (declaredUniformsStruct) {
4493         int binding = fProgram.fConfig->fSettings.fDefaultUniformBinding;
4494         int set = fProgram.fConfig->fSettings.fDefaultUniformSet;
4495         this->write("};\n");
4496         this->write("@binding(" + std::to_string(binding) + ") ");
4497         this->write("@group(" + std::to_string(set) + ") ");
4498         this->writeLine("var<uniform> _globalUniforms: _GlobalUniforms;");
4499     }
4500 }
4501 
functionDependencyArgs(const FunctionDeclaration & f)4502 std::string WGSLCodeGenerator::functionDependencyArgs(const FunctionDeclaration& f) {
4503     WGSLFunctionDependencies* deps = fRequirements.fDependencies.find(&f);
4504     std::string args;
4505     if (deps && *deps) {
4506         const char* separator = "";
4507         if (*deps & WGSLFunctionDependency::kPipelineInputs) {
4508             args += "_stageIn";
4509             separator = ", ";
4510         }
4511         if (*deps & WGSLFunctionDependency::kPipelineOutputs) {
4512             args += separator;
4513             args += "_stageOut";
4514         }
4515     }
4516     return args;
4517 }
4518 
writeFunctionDependencyParams(const FunctionDeclaration & f)4519 bool WGSLCodeGenerator::writeFunctionDependencyParams(const FunctionDeclaration& f) {
4520     WGSLFunctionDependencies* deps = fRequirements.fDependencies.find(&f);
4521     if (!deps || !*deps) {
4522         return false;
4523     }
4524 
4525     std::string_view structNamePrefix = pipeline_struct_prefix(fProgram.fConfig->fKind);
4526     if (structNamePrefix.empty()) {
4527         return false;
4528     }
4529     const char* separator = "";
4530     if (*deps & WGSLFunctionDependency::kPipelineInputs) {
4531         this->write("_stageIn: ");
4532         separator = ", ";
4533         this->write(structNamePrefix);
4534         this->write("In");
4535     }
4536     if (*deps & WGSLFunctionDependency::kPipelineOutputs) {
4537         this->write(separator);
4538         this->write("_stageOut: ptr<function, ");
4539         this->write(structNamePrefix);
4540         this->write("Out>");
4541     }
4542     return true;
4543 }
4544 
4545 #if defined(SK_ENABLE_WGSL_VALIDATION)
validate_wgsl(ErrorReporter & reporter,const std::string & wgsl,std::string * warnings)4546 static bool validate_wgsl(ErrorReporter& reporter, const std::string& wgsl, std::string* warnings) {
4547     // Enable the WGSL optional features that Skia might rely on.
4548     tint::wgsl::reader::Options options;
4549     for (auto extension : {tint::wgsl::Extension::kChromiumExperimentalPixelLocal,
4550                            tint::wgsl::Extension::kDualSourceBlending}) {
4551         options.allowed_features.extensions.insert(extension);
4552     }
4553 
4554     // Verify that the WGSL we produced is valid.
4555     tint::Source::File srcFile("", wgsl);
4556     tint::Program program(tint::wgsl::reader::Parse(&srcFile, options));
4557 
4558     if (program.Diagnostics().ContainsErrors()) {
4559         // The program isn't valid WGSL.
4560 #if defined(SKSL_STANDALONE)
4561         reporter.error(Position(), std::string("Tint compilation failed.\n\n") + wgsl);
4562 #else
4563         // In debug, report the error via SkDEBUGFAIL. We also append the generated program for
4564         // ease of debugging.
4565         tint::diag::Formatter diagFormatter;
4566         std::string diagOutput = diagFormatter.Format(program.Diagnostics()).Plain();
4567         diagOutput += "\n";
4568         diagOutput += wgsl;
4569         SkDEBUGFAILF("%s", diagOutput.c_str());
4570 #endif
4571         return false;
4572     }
4573 
4574     if (!program.Diagnostics().empty()) {
4575         // The program contains warnings. Report them as-is.
4576         tint::diag::Formatter diagFormatter;
4577         *warnings = diagFormatter.Format(program.Diagnostics()).Plain();
4578     }
4579     return true;
4580 }
4581 #endif  // defined(SK_ENABLE_WGSL_VALIDATION)
4582 
ToWGSL(Program & program,const ShaderCaps * caps,OutputStream & out)4583 bool ToWGSL(Program& program, const ShaderCaps* caps, OutputStream& out) {
4584     TRACE_EVENT0("skia.shaders", "SkSL::ToWGSL");
4585     SkASSERT(caps != nullptr);
4586 
4587     program.fContext->fErrors->setSource(*program.fSource);
4588 #ifdef SK_ENABLE_WGSL_VALIDATION
4589     StringStream wgsl;
4590     WGSLCodeGenerator cg(program.fContext.get(), caps, &program, &wgsl);
4591     bool result = cg.generateCode();
4592     if (result) {
4593         std::string wgslString = wgsl.str();
4594         std::string warnings;
4595         result = validate_wgsl(*program.fContext->fErrors, wgslString, &warnings);
4596         if (!warnings.empty()) {
4597             out.writeText("/* Tint reported warnings. */\n\n");
4598         }
4599         out.writeString(wgslString);
4600     }
4601 #else
4602     WGSLCodeGenerator cg(program.fContext.get(), caps, &program, &out);
4603     bool result = cg.generateCode();
4604 #endif
4605     program.fContext->fErrors->setSource(std::string_view());
4606 
4607     return result;
4608 }
4609 
ToWGSL(Program & program,const ShaderCaps * caps,std::string * out)4610 bool ToWGSL(Program& program, const ShaderCaps* caps, std::string* out) {
4611     StringStream buffer;
4612     if (!ToWGSL(program, caps, buffer)) {
4613         return false;
4614     }
4615     *out = buffer.str();
4616     return true;
4617 }
4618 
4619 }  // namespace SkSL
4620