1 /*
2 * Copyright 2022 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/codegen/SkSLWGSLCodeGenerator.h"
9
10 #include "include/core/SkSpan.h"
11 #include "include/core/SkTypes.h"
12 #include "include/private/base/SkTArray.h"
13 #include "include/private/base/SkTo.h"
14 #include "src/base/SkEnumBitMask.h"
15 #include "src/base/SkStringView.h"
16 #include "src/core/SkTHash.h"
17 #include "src/core/SkTraceEvent.h"
18 #include "src/sksl/SkSLAnalysis.h"
19 #include "src/sksl/SkSLBuiltinTypes.h"
20 #include "src/sksl/SkSLCompiler.h"
21 #include "src/sksl/SkSLConstantFolder.h"
22 #include "src/sksl/SkSLContext.h"
23 #include "src/sksl/SkSLDefines.h"
24 #include "src/sksl/SkSLErrorReporter.h"
25 #include "src/sksl/SkSLIntrinsicList.h"
26 #include "src/sksl/SkSLMemoryLayout.h"
27 #include "src/sksl/SkSLOperator.h"
28 #include "src/sksl/SkSLOutputStream.h"
29 #include "src/sksl/SkSLPosition.h"
30 #include "src/sksl/SkSLProgramSettings.h"
31 #include "src/sksl/SkSLString.h"
32 #include "src/sksl/SkSLStringStream.h"
33 #include "src/sksl/SkSLUtil.h"
34 #include "src/sksl/analysis/SkSLProgramUsage.h"
35 #include "src/sksl/analysis/SkSLProgramVisitor.h"
36 #include "src/sksl/codegen/SkSLCodeGenerator.h"
37 #include "src/sksl/ir/SkSLBinaryExpression.h"
38 #include "src/sksl/ir/SkSLBlock.h"
39 #include "src/sksl/ir/SkSLConstructor.h"
40 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
41 #include "src/sksl/ir/SkSLConstructorCompound.h"
42 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
43 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
44 #include "src/sksl/ir/SkSLDoStatement.h"
45 #include "src/sksl/ir/SkSLExpression.h"
46 #include "src/sksl/ir/SkSLExpressionStatement.h"
47 #include "src/sksl/ir/SkSLFieldAccess.h"
48 #include "src/sksl/ir/SkSLForStatement.h"
49 #include "src/sksl/ir/SkSLFunctionCall.h"
50 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
51 #include "src/sksl/ir/SkSLFunctionDefinition.h"
52 #include "src/sksl/ir/SkSLIRNode.h"
53 #include "src/sksl/ir/SkSLIfStatement.h"
54 #include "src/sksl/ir/SkSLIndexExpression.h"
55 #include "src/sksl/ir/SkSLInterfaceBlock.h"
56 #include "src/sksl/ir/SkSLLayout.h"
57 #include "src/sksl/ir/SkSLLiteral.h"
58 #include "src/sksl/ir/SkSLModifierFlags.h"
59 #include "src/sksl/ir/SkSLModifiersDeclaration.h"
60 #include "src/sksl/ir/SkSLPostfixExpression.h"
61 #include "src/sksl/ir/SkSLPrefixExpression.h"
62 #include "src/sksl/ir/SkSLProgram.h"
63 #include "src/sksl/ir/SkSLProgramElement.h"
64 #include "src/sksl/ir/SkSLReturnStatement.h"
65 #include "src/sksl/ir/SkSLSetting.h"
66 #include "src/sksl/ir/SkSLStatement.h"
67 #include "src/sksl/ir/SkSLStructDefinition.h"
68 #include "src/sksl/ir/SkSLSwitchCase.h"
69 #include "src/sksl/ir/SkSLSwitchStatement.h"
70 #include "src/sksl/ir/SkSLSwizzle.h"
71 #include "src/sksl/ir/SkSLTernaryExpression.h"
72 #include "src/sksl/ir/SkSLType.h"
73 #include "src/sksl/ir/SkSLVarDeclarations.h"
74 #include "src/sksl/ir/SkSLVariable.h"
75 #include "src/sksl/ir/SkSLVariableReference.h"
76 #include "src/sksl/spirv.h"
77 #include "src/sksl/transform/SkSLTransform.h"
78
79 #include <algorithm>
80 #include <cstddef>
81 #include <cstdint>
82 #include <initializer_list>
83 #include <iterator>
84 #include <memory>
85 #include <optional>
86 #include <string>
87 #include <string_view>
88 #include <utility>
89
90 #ifdef SK_ENABLE_WGSL_VALIDATION
91 #include "tint/tint.h"
92 #include "src/tint/lang/wgsl/reader/options.h"
93 #include "src/tint/lang/wgsl/extension.h"
94 #endif
95
96 using namespace skia_private;
97
98 namespace {
99
// Bitmask describing which non-global resources a function depends on. Pipeline stage inputs and
// outputs are not reachable from global scope, so a function that touches them has to receive
// them as arguments.
//
// (This enum would naturally be nested in `class WGSLCodeGenerator`, but doing so causes build
// errors in MSVC.)
enum class WGSLFunctionDependency : uint8_t {
    kNone            = 0x0,
    kPipelineInputs  = 0x1,
    kPipelineOutputs = 0x2,
};
110 using WGSLFunctionDependencies = SkEnumBitMask<WGSLFunctionDependency>;
111
112 SK_MAKE_BITMASK_OPS(WGSLFunctionDependency)
113
114 } // namespace
115
116 namespace SkSL {
117
118 class WGSLCodeGenerator : public CodeGenerator {
119 public:
120 // See https://www.w3.org/TR/WGSL/#builtin-values
121 enum class Builtin {
122 // Vertex stage:
123 kVertexIndex, // input
124 kInstanceIndex, // input
125 kPosition, // output, fragment stage input
126
127 // Fragment stage:
128 kLastFragColor, // input
129 kFrontFacing, // input
130 kSampleIndex, // input
131 kFragDepth, // output
132 kSampleMaskIn, // input
133 kSampleMask, // output
134
135 // Compute stage:
136 kLocalInvocationId, // input
137 kLocalInvocationIndex, // input
138 kGlobalInvocationId, // input
139 kWorkgroupId, // input
140 kNumWorkgroups, // input
141 };
142
143 // Variable declarations can be terminated by:
144 // - comma (","), e.g. in struct member declarations or function parameters
145 // - semicolon (";"), e.g. in function scope variables
146 // A "none" option is provided to skip the delimiter when not needed, e.g. at the end of a list
147 // of declarations.
148 enum class Delimiter {
149 kComma,
150 kSemicolon,
151 kNone,
152 };
153
154 struct ProgramRequirements {
155 using DepsMap = skia_private::THashMap<const FunctionDeclaration*,
156 WGSLFunctionDependencies>;
157
158 // Mappings used to synthesize function parameters according to dependencies on pipeline
159 // input/output variables.
160 DepsMap fDependencies;
161
162 // These flags track extensions that will need to be enabled.
163 bool fPixelLocalExtension = false;
164 };
165
WGSLCodeGenerator(const Context * context,const ShaderCaps * caps,const Program * program,OutputStream * out)166 WGSLCodeGenerator(const Context* context,
167 const ShaderCaps* caps,
168 const Program* program,
169 OutputStream* out)
170 : INHERITED(context, caps, program, out) {}
171
172 bool generateCode() override;
173
174 private:
175 using INHERITED = CodeGenerator;
176 using Precedence = OperatorPrecedence;
177
178 // Called by generateCode() as the first step.
179 void preprocessProgram();
180
181 // Write output content while correctly handling indentation.
182 void write(std::string_view s);
183 void writeLine(std::string_view s = std::string_view());
184 void finishLine();
185
186 // Helpers to declare a pipeline stage IO parameter declaration.
187 void writePipelineIODeclaration(const Layout& layout,
188 const Type& type,
189 std::string_view name,
190 Delimiter delimiter);
191 void writeUserDefinedIODecl(const Layout& layout,
192 const Type& type,
193 std::string_view name,
194 Delimiter delimiter);
195 void writeBuiltinIODecl(const Type& type,
196 std::string_view name,
197 Builtin builtin,
198 Delimiter delimiter);
199 void writeVariableDecl(const Layout& layout,
200 const Type& type,
201 std::string_view name,
202 Delimiter delimiter);
203
204 // Write a function definition.
205 void writeFunction(const FunctionDefinition& f);
206 void writeFunctionDeclaration(const FunctionDeclaration& f,
207 SkSpan<const bool> paramNeedsDedicatedStorage);
208
209 // Write the program entry point.
210 void writeEntryPoint(const FunctionDefinition& f);
211
212 // Writers for supported statement types.
213 void writeStatement(const Statement& s);
214 void writeStatements(const StatementArray& statements);
215 void writeBlock(const Block& b);
216 void writeDoStatement(const DoStatement& expr);
217 void writeExpressionStatement(const Expression& expr);
218 void writeForStatement(const ForStatement& s);
219 void writeIfStatement(const IfStatement& s);
220 void writeReturnStatement(const ReturnStatement& s);
221 void writeSwitchStatement(const SwitchStatement& s);
222 void writeSwitchCases(SkSpan<const SwitchCase* const> cases);
223 void writeEmulatedSwitchFallthroughCases(SkSpan<const SwitchCase* const> cases,
224 std::string_view switchValue);
225 void writeSwitchCaseList(SkSpan<const SwitchCase* const> cases);
226 void writeVarDeclaration(const VarDeclaration& varDecl);
227
228 // Synthesizes an LValue for an expression.
229 class LValue;
230 class PointerLValue;
231 class SwizzleLValue;
232 class VectorComponentLValue;
233 std::unique_ptr<LValue> makeLValue(const Expression& e);
234
235 std::string variableReferenceNameForLValue(const VariableReference& r);
236 std::string variablePrefix(const Variable& v);
237
238 bool binaryOpNeedsComponentwiseMatrixPolyfill(const Type& left, const Type& right, Operator op);
239
240 // Writers for expressions. These return the final expression text as a string, and emit any
241 // necessary setup code directly into the program as necessary. The returned expression may be
242 // a `let`-alias that cannot be assigned-into; use `makeLValue` for an assignable expression.
243 std::string assembleExpression(const Expression& e, Precedence parentPrecedence);
244 std::string assembleBinaryExpression(const BinaryExpression& b, Precedence parentPrecedence);
245 std::string assembleBinaryExpression(const Expression& left,
246 Operator op,
247 const Expression& right,
248 const Type& resultType,
249 Precedence parentPrecedence);
250 std::string assembleFieldAccess(const FieldAccess& f);
251 std::string assembleFunctionCall(const FunctionCall& call, Precedence parentPrecedence);
252 std::string assembleIndexExpression(const IndexExpression& i);
253 std::string assembleLiteral(const Literal& l);
254 std::string assemblePostfixExpression(const PostfixExpression& p, Precedence parentPrecedence);
255 std::string assemblePrefixExpression(const PrefixExpression& p, Precedence parentPrecedence);
256 std::string assembleSwizzle(const Swizzle& swizzle);
257 std::string assembleTernaryExpression(const TernaryExpression& t, Precedence parentPrecedence);
258 std::string assembleVariableReference(const VariableReference& r);
259 std::string assembleName(std::string_view name);
260
261 std::string assembleIncrementExpr(const Type& type);
262
263 // Intrinsic helper functions.
264 std::string assembleIntrinsicCall(const FunctionCall& call,
265 IntrinsicKind kind,
266 Precedence parentPrecedence);
267 std::string assembleSimpleIntrinsic(std::string_view intrinsicName, const FunctionCall& call);
268 std::string assembleUnaryOpIntrinsic(Operator op,
269 const FunctionCall& call,
270 Precedence parentPrecedence);
271 std::string assembleBinaryOpIntrinsic(Operator op,
272 const FunctionCall& call,
273 Precedence parentPrecedence);
274 std::string assembleVectorizedIntrinsic(std::string_view intrinsicName,
275 const FunctionCall& call);
276 std::string assembleOutAssignedIntrinsic(std::string_view intrinsicName,
277 std::string_view returnFieldName,
278 std::string_view outFieldName,
279 const FunctionCall& call);
280 std::string assemblePartialSampleCall(std::string_view intrinsicName,
281 const Expression& sampler,
282 const Expression& coords);
283 std::string assembleInversePolyfill(const FunctionCall& call);
284 std::string assembleComponentwiseMatrixBinary(const Type& leftType,
285 const Type& rightType,
286 const std::string& left,
287 const std::string& right,
288 Operator op);
289
290 // Constructor expressions
291 std::string assembleAnyConstructor(const AnyConstructor& c);
292 std::string assembleConstructorCompound(const ConstructorCompound& c);
293 std::string assembleConstructorCompoundVector(const ConstructorCompound& c);
294 std::string assembleConstructorCompoundMatrix(const ConstructorCompound& c);
295 std::string assembleConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c);
296 std::string assembleConstructorMatrixResize(const ConstructorMatrixResize& ctor);
297
298 // Synthesized helper functions for comparison operators that are not supported by WGSL.
299 std::string assembleEqualityExpression(const Type& left,
300 const std::string& leftName,
301 const Type& right,
302 const std::string& rightName,
303 Operator op,
304 Precedence parentPrecedence);
305 std::string assembleEqualityExpression(const Expression& left,
306 const Expression& right,
307 Operator op,
308 Precedence parentPrecedence);
309
310 // Writes a scratch variable into the program and returns its name (e.g. `_skTemp123`).
311 std::string writeScratchVar(const Type& type, const std::string& value = "");
312
313 // Writes a scratch let-variable into the program, gives it the value of `expr`, and returns its
314 // name (e.g. `_skTemp123`).
315 std::string writeScratchLet(const std::string& expr);
316 std::string writeScratchLet(const Expression& expr, Precedence parentPrecedence);
317
318 // Converts `expr` into a string and returns a scratch let-variable associated with the
319 // expression. Compile-time constants and plain variable references will return the expression
320 // directly and omit the let-variable.
321 std::string writeNontrivialScratchLet(const Expression& expr, Precedence parentPrecedence);
322
323 // Generic recursive ProgramElement visitor.
324 void writeProgramElement(const ProgramElement& e);
325 void writeGlobalVarDeclaration(const GlobalVarDeclaration& d);
326 void writeStructDefinition(const StructDefinition& s);
327 void writeModifiersDeclaration(const ModifiersDeclaration&);
328
329 // Writes the WGSL struct fields for SkSL structs and interface blocks. Enforces WGSL address
330 // space layout constraints
331 // (https://www.w3.org/TR/WGSL/#address-space-layout-constraints) if a `layout` is
332 // provided. A struct that does not need to be host-shareable does not require a `layout`.
333 void writeFields(SkSpan<const Field> fields, const MemoryLayout* memoryLayout = nullptr);
334
335 // We bundle uniforms, and all varying pipeline stage inputs and outputs, into separate structs.
336 bool needsStageInputStruct() const;
337 void writeStageInputStruct();
338 bool needsStageOutputStruct() const;
339 void writeStageOutputStruct();
340 void writeUniformsAndBuffers();
341 void prepareUniformPolyfillsForInterfaceBlock(const InterfaceBlock* interfaceBlock,
342 std::string_view instanceName,
343 MemoryLayout::Standard nativeLayout);
344 void writeEnables();
345 void writeUniformPolyfills();
346
347 void writeTextureOrSampler(const Variable& var,
348 int bindingLocation,
349 std::string_view suffix,
350 std::string_view wgslType);
351
352 // Writes all top-level non-opaque global uniform declarations (i.e. not part of an interface
353 // block) into a single uniform block binding.
354 //
355 // In complete fragment/vertex/compute programs, uniforms will be declared only as interface
356 // blocks and global opaque types (like textures and samplers) which we expect to be declared
357 // with a unique binding and descriptor set index. However, test files that are declared as RTE
358 // programs may contain OpenGL-style global uniform declarations with no clear binding index to
359 // use for the containing synthesized block.
360 //
361 // Since we are handling these variables only to generate gold files from RTEs and never run
362 // them, we always declare them at the default bind group and binding index.
363 void writeNonBlockUniformsForTests();
364
365 // For a given function declaration, writes out any implicitly required pipeline stage arguments
366 // based on the function's pre-determined dependencies. These are expected to be written out as
367 // the first parameters for a function that requires them. Returns true if any arguments were
368 // written.
369 std::string functionDependencyArgs(const FunctionDeclaration&);
370 bool writeFunctionDependencyParams(const FunctionDeclaration&);
371
372 // Code in the header appears before the main body of code.
373 StringStream fHeader;
374
375 // We assign unique names to anonymous interface blocks based on the type.
376 skia_private::THashMap<const Type*, std::string> fInterfaceBlockNameMap;
377
378 // Stores the functions which use stage inputs/outputs as well as required WGSL extensions.
379 ProgramRequirements fRequirements;
380 skia_private::TArray<const Variable*> fPipelineInputs;
381 skia_private::TArray<const Variable*> fPipelineOutputs;
382
383 // These fields track whether we have written the polyfill for `inverse()` for a given matrix
384 // type.
385 bool fWrittenInverse2 = false;
386 bool fWrittenInverse3 = false;
387 bool fWrittenInverse4 = false;
388
389 // These fields control uniform polyfill support in cases where WGSL and std140 disagree.
390 // In std140 layout, matrices need to be represented as arrays of @size(16)-aligned vectors, and
391 // array elements are wrapped in a struct containing a single @size(16)-aligned element. Arrays
392 // of matrices combine both wrappers. These wrapper structs are unpacked into natively-typed
393 // globals at the shader entrypoint.
394 struct FieldPolyfillInfo {
395 const InterfaceBlock* fInterfaceBlock;
396 std::string fReplacementName;
397 bool fIsArray = false;
398 bool fIsMatrix = false;
399 bool fWasAccessed = false;
400 };
401 using FieldPolyfillMap = skia_private::THashMap<const Field*, FieldPolyfillInfo>;
402 FieldPolyfillMap fFieldPolyfillMap;
403
404 // Output processing state.
405 int fIndentation = 0;
406 bool fAtLineStart = false;
407 bool fHasUnconditionalReturn = false;
408 bool fAtFunctionScope = false;
409 int fConditionalScopeDepth = 0;
410 int fLocalSizeX = 1;
411 int fLocalSizeY = 1;
412 int fLocalSizeZ = 1;
413
414 int fScratchCount = 0;
415 };
416
417 enum class ProgramKind : int8_t;
418
419 namespace {
420
// Name suffixes for samplers and textures.
constexpr char kSamplerSuffix[] = "_Sampler";
constexpr char kTextureSuffix[] = "_Texture";

// Address spaces that a WGSL pointer type can name.
// See https://www.w3.org/TR/WGSL/#memory-view-types
enum class PtrAddressSpace {
    kFunction = 0,
    kPrivate  = 1,
    kStorage  = 2,
};
430
operator_name(Operator op)431 const char* operator_name(Operator op) {
432 switch (op.kind()) {
433 case Operator::Kind::LOGICALXOR: return " != ";
434 default: return op.operatorName();
435 }
436 }
437
is_reserved_word(std::string_view word)438 bool is_reserved_word(std::string_view word) {
439 static const THashSet<std::string_view> kReservedWords{
440 // Used by SkSL:
441 "FSIn",
442 "FSOut",
443 "VSIn",
444 "VSOut",
445 "CSIn",
446 "_globalUniforms",
447 "_GlobalUniforms",
448 "_return",
449 "_stageIn",
450 "_stageOut",
451 // Keywords: https://www.w3.org/TR/WGSL/#keyword-summary
452 "alias",
453 "break",
454 "case",
455 "const",
456 "const_assert",
457 "continue",
458 "continuing",
459 "default",
460 "diagnostic",
461 "discard",
462 "else",
463 "enable",
464 "false",
465 "fn",
466 "for",
467 "if",
468 "let",
469 "loop",
470 "override",
471 "requires",
472 "return",
473 "struct",
474 "switch",
475 "true",
476 "var",
477 "while",
478 // Pre-declared types: https://www.w3.org/TR/WGSL/#predeclared-types
479 "bool",
480 "f16",
481 "f32",
482 "i32",
483 "u32",
484 // ... and pre-declared type generators:
485 "array",
486 "atomic",
487 "mat2x2",
488 "mat2x3",
489 "mat2x4",
490 "mat3x2",
491 "mat3x3",
492 "mat3x4",
493 "mat4x2",
494 "mat4x3",
495 "mat4x4",
496 "ptr",
497 "texture_1d",
498 "texture_2d",
499 "texture_2d_array",
500 "texture_3d",
501 "texture_cube",
502 "texture_cube_array",
503 "texture_multisampled_2d",
504 "texture_storage_1d",
505 "texture_storage_2d",
506 "texture_storage_2d_array",
507 "texture_storage_3d",
508 "vec2",
509 "vec3",
510 "vec4",
511 // Pre-declared enumerants: https://www.w3.org/TR/WGSL/#predeclared-enumerants
512 "read",
513 "write",
514 "read_write",
515 "function",
516 "private",
517 "workgroup",
518 "uniform",
519 "storage",
520 "perspective",
521 "linear",
522 "flat",
523 "center",
524 "centroid",
525 "sample",
526 "vertex_index",
527 "instance_index",
528 "position",
529 "front_facing",
530 "frag_depth",
531 "local_invocation_id",
532 "local_invocation_index",
533 "global_invocation_id",
534 "workgroup_id",
535 "num_workgroups",
536 "sample_index",
537 "sample_mask",
538 "rgba8unorm",
539 "rgba8snorm",
540 "rgba8uint",
541 "rgba8sint",
542 "rgba16uint",
543 "rgba16sint",
544 "rgba16float",
545 "r32uint",
546 "r32sint",
547 "r32float",
548 "rg32uint",
549 "rg32sint",
550 "rg32float",
551 "rgba32uint",
552 "rgba32sint",
553 "rgba32float",
554 "bgra8unorm",
555 // Reserved words: https://www.w3.org/TR/WGSL/#reserved-words
556 "_",
557 "NULL",
558 "Self",
559 "abstract",
560 "active",
561 "alignas",
562 "alignof",
563 "as",
564 "asm",
565 "asm_fragment",
566 "async",
567 "attribute",
568 "auto",
569 "await",
570 "become",
571 "binding_array",
572 "cast",
573 "catch",
574 "class",
575 "co_await",
576 "co_return",
577 "co_yield",
578 "coherent",
579 "column_major",
580 "common",
581 "compile",
582 "compile_fragment",
583 "concept",
584 "const_cast",
585 "consteval",
586 "constexpr",
587 "constinit",
588 "crate",
589 "debugger",
590 "decltype",
591 "delete",
592 "demote",
593 "demote_to_helper",
594 "do",
595 "dynamic_cast",
596 "enum",
597 "explicit",
598 "export",
599 "extends",
600 "extern",
601 "external",
602 "fallthrough",
603 "filter",
604 "final",
605 "finally",
606 "friend",
607 "from",
608 "fxgroup",
609 "get",
610 "goto",
611 "groupshared",
612 "highp",
613 "impl",
614 "implements",
615 "import",
616 "inline",
617 "instanceof",
618 "interface",
619 "layout",
620 "lowp",
621 "macro",
622 "macro_rules",
623 "match",
624 "mediump",
625 "meta",
626 "mod",
627 "module",
628 "move",
629 "mut",
630 "mutable",
631 "namespace",
632 "new",
633 "nil",
634 "noexcept",
635 "noinline",
636 "nointerpolation",
637 "noperspective",
638 "null",
639 "nullptr",
640 "of",
641 "operator",
642 "package",
643 "packoffset",
644 "partition",
645 "pass",
646 "patch",
647 "pixelfragment",
648 "precise",
649 "precision",
650 "premerge",
651 "priv",
652 "protected",
653 "pub",
654 "public",
655 "readonly",
656 "ref",
657 "regardless",
658 "register",
659 "reinterpret_cast",
660 "require",
661 "resource",
662 "restrict",
663 "self",
664 "set",
665 "shared",
666 "sizeof",
667 "smooth",
668 "snorm",
669 "static",
670 "static_assert",
671 "static_cast",
672 "std",
673 "subroutine",
674 "super",
675 "target",
676 "template",
677 "this",
678 "thread_local",
679 "throw",
680 "trait",
681 "try",
682 "type",
683 "typedef",
684 "typeid",
685 "typename",
686 "typeof",
687 "union",
688 "unless",
689 "unorm",
690 "unsafe",
691 "unsized",
692 "use",
693 "using",
694 "varying",
695 "virtual",
696 "volatile",
697 "wgsl",
698 "where",
699 "with",
700 "writeonly",
701 "yield",
702 };
703
704 return kReservedWords.contains(word);
705 }
706
pipeline_struct_prefix(ProgramKind kind)707 std::string_view pipeline_struct_prefix(ProgramKind kind) {
708 if (ProgramConfig::IsVertex(kind)) {
709 return "VS";
710 }
711 if (ProgramConfig::IsFragment(kind)) {
712 return "FS";
713 }
714 if (ProgramConfig::IsCompute(kind)) {
715 return "CS";
716 }
717 // Compute programs don't have stage-in/stage-out pipeline structs.
718 return "";
719 }
720
address_space_to_str(PtrAddressSpace addressSpace)721 std::string_view address_space_to_str(PtrAddressSpace addressSpace) {
722 switch (addressSpace) {
723 case PtrAddressSpace::kFunction:
724 return "function";
725 case PtrAddressSpace::kPrivate:
726 return "private";
727 case PtrAddressSpace::kStorage:
728 return "storage";
729 }
730 SkDEBUGFAIL("unsupported ptr address space");
731 return "unsupported";
732 }
733
to_scalar_type(const Type & type)734 std::string_view to_scalar_type(const Type& type) {
735 SkASSERT(type.typeKind() == Type::TypeKind::kScalar);
736 switch (type.numberKind()) {
737 // Floating-point numbers in WebGPU currently always have 32-bit footprint and
738 // relaxed-precision is not supported without extensions. f32 is the only floating-point
739 // number type in WGSL (see the discussion on https://github.com/gpuweb/gpuweb/issues/658).
740 case Type::NumberKind::kFloat:
741 return "f32";
742 case Type::NumberKind::kSigned:
743 return "i32";
744 case Type::NumberKind::kUnsigned:
745 return "u32";
746 case Type::NumberKind::kBoolean:
747 return "bool";
748 case Type::NumberKind::kNonnumeric:
749 [[fallthrough]];
750 default:
751 break;
752 }
753 return type.name();
754 }
755
756 // Convert a SkSL type to a WGSL type. Handles all plain types except structure types
757 // (see https://www.w3.org/TR/WGSL/#plain-types-section).
to_wgsl_type(const Context & context,const Type & raw,const Layout * layout=nullptr)758 std::string to_wgsl_type(const Context& context, const Type& raw, const Layout* layout = nullptr) {
759 const Type& type = raw.resolve().scalarTypeForLiteral();
760 switch (type.typeKind()) {
761 case Type::TypeKind::kScalar:
762 return std::string(to_scalar_type(type));
763
764 case Type::TypeKind::kAtomic:
765 SkASSERT(type.matches(*context.fTypes.fAtomicUInt));
766 return "atomic<u32>";
767
768 case Type::TypeKind::kVector: {
769 std::string_view ct = to_scalar_type(type.componentType());
770 return String::printf("vec%d<%.*s>", type.columns(), (int)ct.length(), ct.data());
771 }
772 case Type::TypeKind::kMatrix: {
773 std::string_view ct = to_scalar_type(type.componentType());
774 return String::printf("mat%dx%d<%.*s>",
775 type.columns(), type.rows(), (int)ct.length(), ct.data());
776 }
777 case Type::TypeKind::kArray: {
778 std::string result = "array<" + to_wgsl_type(context, type.componentType(), layout);
779 if (!type.isUnsizedArray()) {
780 result += ", ";
781 result += std::to_string(type.columns());
782 }
783 return result + '>';
784 }
785 case Type::TypeKind::kTexture: {
786 if (type.matches(*context.fTypes.fWriteOnlyTexture2D)) {
787 std::string result = "texture_storage_2d<";
788 // Write-only storage texture types require a pixel format, which is in the layout.
789 SkASSERT(layout);
790 LayoutFlags pixelFormat = layout->fFlags & LayoutFlag::kAllPixelFormats;
791 switch (pixelFormat.value()) {
792 case (int)LayoutFlag::kRGBA8:
793 return result + "rgba8unorm, write>";
794
795 case (int)LayoutFlag::kRGBA32F:
796 return result + "rgba32float, write>";
797
798 case (int)LayoutFlag::kR32F:
799 return result + "r32float, write>";
800
801 default:
802 // The front-end should have rejected this.
803 return result + "write>";
804 }
805 }
806 if (type.matches(*context.fTypes.fReadOnlyTexture2D)) {
807 return "texture_2d<f32>";
808 }
809 break;
810 }
811 default:
812 break;
813 }
814 return std::string(type.name());
815 }
816
to_ptr_type(const Context & context,const Type & type,const Layout * layout,PtrAddressSpace addressSpace=PtrAddressSpace::kFunction)817 std::string to_ptr_type(const Context& context,
818 const Type& type,
819 const Layout* layout,
820 PtrAddressSpace addressSpace = PtrAddressSpace::kFunction) {
821 return "ptr<" + std::string(address_space_to_str(addressSpace)) + ", " +
822 to_wgsl_type(context, type, layout) + '>';
823 }
824
wgsl_builtin_name(WGSLCodeGenerator::Builtin builtin)825 std::string_view wgsl_builtin_name(WGSLCodeGenerator::Builtin builtin) {
826 using Builtin = WGSLCodeGenerator::Builtin;
827 switch (builtin) {
828 case Builtin::kVertexIndex:
829 return "@builtin(vertex_index)";
830 case Builtin::kInstanceIndex:
831 return "@builtin(instance_index)";
832 case Builtin::kPosition:
833 return "@builtin(position)";
834 case Builtin::kLastFragColor:
835 return "@color(0)";
836 case Builtin::kFrontFacing:
837 return "@builtin(front_facing)";
838 case Builtin::kSampleIndex:
839 return "@builtin(sample_index)";
840 case Builtin::kFragDepth:
841 return "@builtin(frag_depth)";
842 case Builtin::kSampleMask:
843 case Builtin::kSampleMaskIn:
844 return "@builtin(sample_mask)";
845 case Builtin::kLocalInvocationId:
846 return "@builtin(local_invocation_id)";
847 case Builtin::kLocalInvocationIndex:
848 return "@builtin(local_invocation_index)";
849 case Builtin::kGlobalInvocationId:
850 return "@builtin(global_invocation_id)";
851 case Builtin::kWorkgroupId:
852 return "@builtin(workgroup_id)";
853 case Builtin::kNumWorkgroups:
854 return "@builtin(num_workgroups)";
855 default:
856 break;
857 }
858
859 SkDEBUGFAIL("unsupported builtin");
860 return "unsupported";
861 }
862
wgsl_builtin_type(WGSLCodeGenerator::Builtin builtin)863 std::string_view wgsl_builtin_type(WGSLCodeGenerator::Builtin builtin) {
864 using Builtin = WGSLCodeGenerator::Builtin;
865 switch (builtin) {
866 case Builtin::kVertexIndex:
867 return "u32";
868 case Builtin::kInstanceIndex:
869 return "u32";
870 case Builtin::kPosition:
871 return "vec4<f32>";
872 case Builtin::kLastFragColor:
873 return "vec4<f32>";
874 case Builtin::kFrontFacing:
875 return "bool";
876 case Builtin::kSampleIndex:
877 return "u32";
878 case Builtin::kFragDepth:
879 return "f32";
880 case Builtin::kSampleMask:
881 return "u32";
882 case Builtin::kSampleMaskIn:
883 return "u32";
884 case Builtin::kLocalInvocationId:
885 return "vec3<u32>";
886 case Builtin::kLocalInvocationIndex:
887 return "u32";
888 case Builtin::kGlobalInvocationId:
889 return "vec3<u32>";
890 case Builtin::kWorkgroupId:
891 return "vec3<u32>";
892 case Builtin::kNumWorkgroups:
893 return "vec3<u32>";
894 default:
895 break;
896 }
897
898 SkDEBUGFAIL("unsupported builtin");
899 return "unsupported";
900 }
901
902 // Some built-in variables have a type that differs from their SkSL counterpart (e.g. signed vs
903 // unsigned integer). We handle these cases with an explicit type conversion during a variable
904 // reference. Returns the WGSL type of the conversion target if conversion is needed, otherwise
905 // returns std::nullopt.
needs_builtin_type_conversion(const Variable & v)906 std::optional<std::string_view> needs_builtin_type_conversion(const Variable& v) {
907 switch (v.layout().fBuiltin) {
908 case SK_VERTEXID_BUILTIN:
909 case SK_INSTANCEID_BUILTIN:
910 return {"i32"};
911 default:
912 break;
913 }
914 return std::nullopt;
915 }
916
917 // Map a SkSL builtin flag to a WGSL builtin kind. Returns std::nullopt if `builtin` is not
918 // not supported for WGSL.
919 //
920 // Also see //src/sksl/sksl_vert.sksl and //src/sksl/sksl_frag.sksl for supported built-ins.
builtin_from_sksl_name(int builtin)921 std::optional<WGSLCodeGenerator::Builtin> builtin_from_sksl_name(int builtin) {
922 using Builtin = WGSLCodeGenerator::Builtin;
923 switch (builtin) {
924 case SK_POSITION_BUILTIN:
925 [[fallthrough]];
926 case SK_FRAGCOORD_BUILTIN:
927 return Builtin::kPosition;
928 case SK_VERTEXID_BUILTIN:
929 return Builtin::kVertexIndex;
930 case SK_INSTANCEID_BUILTIN:
931 return Builtin::kInstanceIndex;
932 case SK_LASTFRAGCOLOR_BUILTIN:
933 return Builtin::kLastFragColor;
934 case SK_CLOCKWISE_BUILTIN:
935 // TODO(skia:13092): While `front_facing` is the corresponding built-in, it does not
936 // imply a particular winding order. We correctly compute the face orientation based
937 // on how Skia configured the render pipeline for all references to this built-in
938 // variable (see `SkSL::Program::Interface::fRTFlipUniform`).
939 return Builtin::kFrontFacing;
940 case SK_SAMPLEMASKIN_BUILTIN:
941 return Builtin::kSampleMaskIn;
942 case SK_SAMPLEMASK_BUILTIN:
943 return Builtin::kSampleMask;
944 case SK_NUMWORKGROUPS_BUILTIN:
945 return Builtin::kNumWorkgroups;
946 case SK_WORKGROUPID_BUILTIN:
947 return Builtin::kWorkgroupId;
948 case SK_LOCALINVOCATIONID_BUILTIN:
949 return Builtin::kLocalInvocationId;
950 case SK_GLOBALINVOCATIONID_BUILTIN:
951 return Builtin::kGlobalInvocationId;
952 case SK_LOCALINVOCATIONINDEX_BUILTIN:
953 return Builtin::kLocalInvocationIndex;
954 default:
955 break;
956 }
957 return std::nullopt;
958 }
959
delimiter_to_str(WGSLCodeGenerator::Delimiter delimiter)960 const char* delimiter_to_str(WGSLCodeGenerator::Delimiter delimiter) {
961 using Delim = WGSLCodeGenerator::Delimiter;
962 switch (delimiter) {
963 case Delim::kComma:
964 return ",";
965 case Delim::kSemicolon:
966 return ";";
967 case Delim::kNone:
968 default:
969 break;
970 }
971 return "";
972 }
973
974 // FunctionDependencyResolver visits the IR tree rooted at a particular function definition and
975 // computes that function's dependencies on pipeline stage IO parameters. These are later used to
976 // synthesize arguments when writing out function definitions.
977 class FunctionDependencyResolver : public ProgramVisitor {
978 public:
979 using Deps = WGSLFunctionDependencies;
980 using DepsMap = WGSLCodeGenerator::ProgramRequirements::DepsMap;
981
FunctionDependencyResolver(const Program * p,const FunctionDeclaration * f,DepsMap * programDependencyMap)982 FunctionDependencyResolver(const Program* p,
983 const FunctionDeclaration* f,
984 DepsMap* programDependencyMap)
985 : fProgram(p), fFunction(f), fDependencyMap(programDependencyMap) {}
986
resolve()987 Deps resolve() {
988 fDeps = WGSLFunctionDependency::kNone;
989 this->visit(*fProgram);
990 return fDeps;
991 }
992
993 private:
visitProgramElement(const ProgramElement & p)994 bool visitProgramElement(const ProgramElement& p) override {
995 // Only visit the program that matches the requested function.
996 if (p.is<FunctionDefinition>() && &p.as<FunctionDefinition>().declaration() == fFunction) {
997 return INHERITED::visitProgramElement(p);
998 }
999 // Continue visiting other program elements.
1000 return false;
1001 }
1002
visitExpression(const Expression & e)1003 bool visitExpression(const Expression& e) override {
1004 if (e.is<VariableReference>()) {
1005 const VariableReference& v = e.as<VariableReference>();
1006 if (v.variable()->storage() == Variable::Storage::kGlobal) {
1007 ModifierFlags flags = v.variable()->modifierFlags();
1008 if (flags & ModifierFlag::kIn) {
1009 fDeps |= WGSLFunctionDependency::kPipelineInputs;
1010 }
1011 if (flags & ModifierFlag::kOut) {
1012 fDeps |= WGSLFunctionDependency::kPipelineOutputs;
1013 }
1014 }
1015 } else if (e.is<FunctionCall>()) {
1016 // The current function that we're processing (`fFunction`) inherits the dependencies of
1017 // functions that it makes calls to, because the pipeline stage IO parameters need to be
1018 // passed down as an argument.
1019 const FunctionCall& callee = e.as<FunctionCall>();
1020
1021 // Don't process a function again if we have already resolved it.
1022 Deps* found = fDependencyMap->find(&callee.function());
1023 if (found) {
1024 fDeps |= *found;
1025 } else {
1026 // Store the dependencies that have been discovered for the current function so far.
1027 // If `callee` directly or indirectly calls the current function, then this value
1028 // will prevent an infinite recursion.
1029 fDependencyMap->set(fFunction, fDeps);
1030
1031 // Separately traverse the called function's definition and determine its
1032 // dependencies.
1033 FunctionDependencyResolver resolver(fProgram, &callee.function(), fDependencyMap);
1034 Deps calleeDeps = resolver.resolve();
1035
1036 // Store the callee's dependencies in the global map to avoid processing
1037 // the function again for future calls.
1038 fDependencyMap->set(&callee.function(), calleeDeps);
1039
1040 // Add to the current function's dependencies.
1041 fDeps |= calleeDeps;
1042 }
1043 }
1044 return INHERITED::visitExpression(e);
1045 }
1046
1047 const Program* const fProgram;
1048 const FunctionDeclaration* const fFunction;
1049 DepsMap* const fDependencyMap;
1050 Deps fDeps = WGSLFunctionDependency::kNone;
1051
1052 using INHERITED = ProgramVisitor;
1053 };
1054
resolve_program_requirements(const Program * program)1055 WGSLCodeGenerator::ProgramRequirements resolve_program_requirements(const Program* program) {
1056 WGSLCodeGenerator::ProgramRequirements requirements;
1057
1058 for (const ProgramElement* e : program->elements()) {
1059 switch (e->kind()) {
1060 case ProgramElement::Kind::kFunction: {
1061 const FunctionDeclaration& decl = e->as<FunctionDefinition>().declaration();
1062
1063 FunctionDependencyResolver resolver(program, &decl, &requirements.fDependencies);
1064 requirements.fDependencies.set(&decl, resolver.resolve());
1065 break;
1066 }
1067 case ProgramElement::Kind::kGlobalVar: {
1068 const GlobalVarDeclaration& decl = e->as<GlobalVarDeclaration>();
1069 if (decl.varDeclaration().var()->modifierFlags().isPixelLocal()) {
1070 requirements.fPixelLocalExtension = true;
1071 }
1072 break;
1073 }
1074 default:
1075 break;
1076 }
1077 }
1078
1079 return requirements;
1080 }
1081
collect_pipeline_io_vars(const Program * program,TArray<const Variable * > * ioVars,ModifierFlag ioType)1082 void collect_pipeline_io_vars(const Program* program,
1083 TArray<const Variable*>* ioVars,
1084 ModifierFlag ioType) {
1085 for (const ProgramElement* e : program->elements()) {
1086 if (e->is<GlobalVarDeclaration>()) {
1087 const Variable* v = e->as<GlobalVarDeclaration>().varDeclaration().var();
1088 if (v->modifierFlags() & ioType) {
1089 ioVars->push_back(v);
1090 }
1091 } else if (e->is<InterfaceBlock>()) {
1092 const Variable* v = e->as<InterfaceBlock>().var();
1093 if (v->modifierFlags() & ioType) {
1094 ioVars->push_back(v);
1095 }
1096 }
1097 }
1098 }
1099
is_in_global_uniforms(const Variable & var)1100 bool is_in_global_uniforms(const Variable& var) {
1101 SkASSERT(var.storage() == VariableStorage::kGlobal);
1102 return var.modifierFlags().isUniform() &&
1103 !var.type().isOpaque() &&
1104 !var.interfaceBlock();
1105 }
1106
1107 } // namespace
1108
// Abstract interface for an assignable WGSL location. Implementations produce WGSL source text
// for reading from, and writing to, the location they represent.
class WGSLCodeGenerator::LValue {
public:
    virtual ~LValue() = default;

    // Returns a WGSL expression that loads from the lvalue with no side effects.
    // (e.g. `array[index].field`)
    virtual std::string load() = 0;

    // Returns a WGSL statement that stores into the lvalue with no side effects.
    // (e.g. `array[index].field = the_passed_in_value_string;`)
    // `value` is the WGSL expression text to be stored.
    virtual std::string store(const std::string& value) = 0;
};
1121
1122 class WGSLCodeGenerator::PointerLValue : public WGSLCodeGenerator::LValue {
1123 public:
1124 // `name` must be a WGSL expression with no side-effects, which we can safely take the address
1125 // of. (e.g. `array[index].field` would be valid, but `array[Func()]` or `vector.x` are not.)
PointerLValue(std::string name)1126 PointerLValue(std::string name) : fName(std::move(name)) {}
1127
load()1128 std::string load() override {
1129 return fName;
1130 }
1131
store(const std::string & value)1132 std::string store(const std::string& value) override {
1133 return fName + " = " + value + ";";
1134 }
1135
1136 private:
1137 std::string fName;
1138 };
1139
1140 class WGSLCodeGenerator::VectorComponentLValue : public WGSLCodeGenerator::LValue {
1141 public:
1142 // `name` must be a WGSL expression with no side-effects that points to a single component of a
1143 // WGSL vector.
VectorComponentLValue(std::string name)1144 VectorComponentLValue(std::string name) : fName(std::move(name)) {}
1145
load()1146 std::string load() override {
1147 return fName;
1148 }
1149
store(const std::string & value)1150 std::string store(const std::string& value) override {
1151 return fName + " = " + value + ";";
1152 }
1153
1154 private:
1155 std::string fName;
1156 };
1157
// An lvalue that refers to a swizzle of a WGSL vector (e.g. `name.zx`). Loads read the swizzle
// directly; stores rebuild the whole vector from the new value plus the untouched components and
// assign it back to `name`.
class WGSLCodeGenerator::SwizzleLValue : public WGSLCodeGenerator::LValue {
public:
    // `name` must be a WGSL expression with no side-effects that points to a WGSL vector.
    // `t` is the type of the vector being written into, and `c` lists the swizzled components.
    SwizzleLValue(const Context& ctx, std::string name, const Type& t, const ComponentArray& c)
            : fContext(ctx)
            , fName(std::move(name))
            , fType(t)
            , fComponents(c) {
        // If the component array doesn't cover the entire value, we need to create masks for
        // writing back into the lvalue. For example, if the type is vec4 and the component array
        // holds `zx`, a GLSL assignment would look like:
        //     name.zx = new_value;
        //
        // The equivalent WGSL assignment statement would look like:
        //     name = vec4<f32>(new_value, name.yw).yzxw;
        //
        // This replaces name.zx with new_value.xy, and leaves name.yw at their original values.
        // Whether the new value or the original values come first in the constructor is decided
        // below (see `fReintegrateNewValueFirst`); it might be possible to find better
        // arrangements which simplify the assignment overall, but we don't attempt this.
        int fullSlotCount = fType.slotCount();
        SkASSERT(fullSlotCount <= 4);

        // First, see which components are used.
        // The assignment swizzle must not reuse components.
        bool used[4] = {};
        for (int8_t component : fComponents) {
            SkASSERT(!used[component]);
            used[component] = true;
        }

        // Any untouched components will need to be fetched from the original value.
        for (int index = 0; index < fullSlotCount; ++index) {
            if (!used[index]) {
                fUntouchedComponents.push_back(index);
            }
        }

        // The reintegration swizzle needs to move the components back into their proper slots.
        // Entry `i` holds the index, within the constructed vector, of the value that belongs in
        // slot `i`; `reintegrateIndex` counts constructor elements as they are assigned.
        fReintegrationSwizzle.resize(fullSlotCount);
        int reintegrateIndex = 0;

        // This refills the untouched slots with the original values.
        auto refillUntouchedSlots = [&] {
            for (int index = 0; index < fullSlotCount; ++index) {
                if (!used[index]) {
                    fReintegrationSwizzle[index] = reintegrateIndex++;
                }
            }
        };

        // This places the new-value components into the proper slots.
        auto insertNewValuesIntoSlots = [&] {
            for (int index = 0; index < fComponents.size(); ++index) {
                fReintegrationSwizzle[fComponents[index]] = reintegrateIndex++;
            }
        };

        // When reintegrating the untouched and new values, if the `x` slot is overwritten, we
        // reintegrate the new value first. Otherwise, we reintegrate the original value first.
        // This increases our odds of getting an identity swizzle for the reintegration.
        // The two lambdas share `reintegrateIndex`, so the order of these calls matters.
        if (used[0]) {
            fReintegrateNewValueFirst = true;
            insertNewValuesIntoSlots();
            refillUntouchedSlots();
        } else {
            fReintegrateNewValueFirst = false;
            refillUntouchedSlots();
            insertNewValuesIntoSlots();
        }
    }

    // Returns `<name>.<swizzle mask>`.
    std::string load() override {
        return fName + "." + Swizzle::MaskString(fComponents);
    }

    // Returns a statement that assigns the whole vector: the new value is combined with any
    // untouched components in a constructor call, and a trailing swizzle is appended when the
    // constructed order differs from slot order.
    std::string store(const std::string& value) override {
        // `variable = `
        std::string result = fName;
        result += " = ";

        if (fUntouchedComponents.empty()) {
            // Every component is overwritten, so no constructor is needed.
            // `(new_value)`
            result += '(';
            result += value;
            result += ")";
        } else if (fReintegrateNewValueFirst) {
            // `vec4<f32>((new_value), `
            result += to_wgsl_type(fContext, fType);
            result += "((";
            result += value;
            result += "), ";

            // `variable.yz)`
            result += fName;
            result += '.';
            result += Swizzle::MaskString(fUntouchedComponents);
            result += ')';
        } else {
            // `vec4<f32>(variable.yz`
            result += to_wgsl_type(fContext, fType);
            result += '(';
            result += fName;
            result += '.';
            result += Swizzle::MaskString(fUntouchedComponents);

            // `, (new_value))`
            result += ", (";
            result += value;
            result += "))";
        }

        if (!Swizzle::IsIdentity(fReintegrationSwizzle)) {
            // `.wzyx`
            result += '.';
            result += Swizzle::MaskString(fReintegrationSwizzle);
        }

        return result + ';';
    }

private:
    const Context& fContext;
    std::string fName;
    const Type& fType;
    ComponentArray fComponents;            // components named by the swizzle
    ComponentArray fUntouchedComponents;   // components of the vector not named by the swizzle
    ComponentArray fReintegrationSwizzle;  // maps constructed elements back to their slots
    bool fReintegrateNewValueFirst = false;
};
1288
generateCode()1289 bool WGSLCodeGenerator::generateCode() {
1290 // The resources of a WGSL program are structured in the following way:
1291 // - Stage attribute inputs and outputs are bundled inside synthetic structs called
1292 // VSIn/VSOut/FSIn/FSOut/CSIn.
1293 // - All uniform and storage type resources are declared in global scope.
1294 this->preprocessProgram();
1295
1296 {
1297 AutoOutputStream outputToHeader(this, &fHeader, &fIndentation);
1298 this->writeEnables();
1299 this->writeStageInputStruct();
1300 this->writeStageOutputStruct();
1301 this->writeUniformsAndBuffers();
1302 this->writeNonBlockUniformsForTests();
1303 }
1304 StringStream body;
1305 {
1306 // Emit the program body.
1307 AutoOutputStream outputToBody(this, &body, &fIndentation);
1308 const FunctionDefinition* mainFunc = nullptr;
1309 for (const ProgramElement* e : fProgram.elements()) {
1310 this->writeProgramElement(*e);
1311
1312 if (e->is<FunctionDefinition>()) {
1313 const FunctionDefinition& func = e->as<FunctionDefinition>();
1314 if (func.declaration().isMain()) {
1315 mainFunc = &func;
1316 }
1317 }
1318 }
1319
1320 // At the bottom of the program body, emit the entrypoint function.
1321 // The entrypoint relies on state that has been collected while we emitted the rest of the
1322 // program, so it's important to do it last to make sure we don't miss anything.
1323 if (mainFunc) {
1324 this->writeEntryPoint(*mainFunc);
1325 }
1326 }
1327
1328 write_stringstream(fHeader, *fOut);
1329 write_stringstream(body, *fOut);
1330
1331 this->writeUniformPolyfills();
1332
1333 return fContext.fErrors->errorCount() == 0;
1334 }
1335
// Emits the support code for every uniform field recorded in `fFieldPolyfillMap`:
//   - `_skArrayElement_*`, `_skRow*` and `_skMatrix*` helper structs (each written at most once),
//   - a `var<private>` replacement global for each field that was actually accessed, and
//   - `_skInitializePolyfilledUniforms()`, which assigns each replacement global from the
//     corresponding uniform field.
// Nothing is written if the map is empty; the function is omitted if no field was accessed.
void WGSLCodeGenerator::writeUniformPolyfills() {
    // If we didn't encounter any uniforms that need polyfilling, there is nothing to do.
    if (fFieldPolyfillMap.empty()) {
        return;
    }

    // We store the list of polyfilled fields as pointers in a hash-map, so the order can be
    // inconsistent across runs. For determinism, we sort the polyfilled objects by name here.
    TArray<const FieldPolyfillMap::Pair*> orderedFields;
    orderedFields.reserve_exact(fFieldPolyfillMap.count());

    fFieldPolyfillMap.foreach([&](const FieldPolyfillMap::Pair& pair) {
        orderedFields.push_back(&pair);
    });

    std::sort(orderedFields.begin(),
              orderedFields.end(),
              [](const FieldPolyfillMap::Pair* a, const FieldPolyfillMap::Pair* b) {
                  return a->second.fReplacementName < b->second.fReplacementName;
              });

    // Track which helper structs have already been emitted so each is declared only once.
    THashSet<const Type*> writtenArrayElementPolyfill;
    bool writtenUniformMatrixPolyfill[5][5] = {};  // m[column][row] for each matrix type
    bool writtenUniformRowPolyfill[5] = {};        // for each matrix row-size
    bool anyFieldAccessed = false;
    for (const FieldPolyfillMap::Pair* pair : orderedFields) {
        const auto& [field, info] = *pair;
        const Type* fieldType = field->fType;
        const Layout* fieldLayout = &field->fLayout;

        if (info.fIsArray) {
            // From here on, `fieldType` refers to the array's element type.
            fieldType = &fieldType->componentType();
            if (!writtenArrayElementPolyfill.contains(fieldType)) {
                writtenArrayElementPolyfill.add(fieldType);
                this->write("struct _skArrayElement_");
                this->write(fieldType->abbreviatedName());
                this->writeLine(" {");

                if (info.fIsMatrix) {
                    // Create a struct representing the array containing std140-padded matrices.
                    this->write(" e : _skMatrix");
                    this->write(std::to_string(fieldType->columns()));
                    this->writeLine(std::to_string(fieldType->rows()));
                } else {
                    // Create a struct representing the array with extra padding between elements.
                    this->write(" @size(16) e : ");
                    this->writeLine(to_wgsl_type(fContext, *fieldType, fieldLayout));
                }
                this->writeLine("};");
            }
        }

        if (info.fIsMatrix) {
            // Create structs representing the matrix as an array of vectors, whether or not the
            // matrix is ever accessed by the SkSL. (The struct itself is mentioned in the list of
            // uniforms.)
            int c = fieldType->columns();
            int r = fieldType->rows();
            if (!writtenUniformRowPolyfill[r]) {
                writtenUniformRowPolyfill[r] = true;

                // `struct _skRow<r> { @size(16) r : vec<r><component type> };`
                this->write("struct _skRow");
                this->write(std::to_string(r));
                this->writeLine(" {");
                this->write(" @size(16) r : vec");
                this->write(std::to_string(r));
                this->write("<");
                this->write(to_wgsl_type(fContext, fieldType->componentType(), fieldLayout));
                this->writeLine(">");
                this->writeLine("};");
            }

            if (!writtenUniformMatrixPolyfill[c][r]) {
                writtenUniformMatrixPolyfill[c][r] = true;

                // `struct _skMatrix<c><r> { c : array<_skRow<r>, <c>> };`
                this->write("struct _skMatrix");
                this->write(std::to_string(c));
                this->write(std::to_string(r));
                this->writeLine(" {");
                this->write(" c : array<_skRow");
                this->write(std::to_string(r));
                this->write(", ");
                this->write(std::to_string(c));
                this->writeLine(">");
                this->writeLine("};");
            }
        }

        // We create a polyfill variable only if the uniform was actually accessed.
        if (!info.fWasAccessed) {
            continue;
        }
        anyFieldAccessed = true;
        this->write("var<private> ");
        this->write(info.fReplacementName);
        this->write(": ");

        // The replacement uses the field's original (unpadded) type. If the interface block
        // itself is arrayed, the replacement is an array with one entry per block element.
        const Type& interfaceBlockType = info.fInterfaceBlock->var()->type();
        if (interfaceBlockType.isArray()) {
            this->write("array<");
            this->write(to_wgsl_type(fContext, *field->fType, fieldLayout));
            this->write(", ");
            this->write(std::to_string(interfaceBlockType.columns()));
            this->write(">");
        } else {
            this->write(to_wgsl_type(fContext, *field->fType, fieldLayout));
        }
        this->writeLine(";");
    }

    // If no fields were actually accessed, _skInitializePolyfilledUniforms will not be called and
    // we can avoid emitting an empty, dead function.
    if (!anyFieldAccessed) {
        return;
    }

    this->writeLine("fn _skInitializePolyfilledUniforms() {");
    ++fIndentation;

    for (const FieldPolyfillMap::Pair* pair : orderedFields) {
        // Only initialize a polyfill global if the uniform was actually accessed.
        const auto& [field, info] = *pair;
        if (!info.fWasAccessed) {
            continue;
        }

        // Synthesize the name of this uniform variable. If the interface block has no instance
        // name, fall back to the name recorded in `fInterfaceBlockNameMap`.
        std::string_view instanceName = info.fInterfaceBlock->instanceName();
        const Type& interfaceBlockType = info.fInterfaceBlock->var()->type();
        if (instanceName.empty()) {
            instanceName = fInterfaceBlockNameMap[&interfaceBlockType.componentType()];
        }

        // Initialize the global variable associated with this uniform.
        // If the interface block is arrayed, the associated global will be arrayed as well.
        int numIBElements = interfaceBlockType.isArray() ? interfaceBlockType.columns() : 1;
        for (int ibIdx = 0; ibIdx < numIBElements; ++ibIdx) {
            // `<replacement>[<ibIdx>] = `
            this->write(info.fReplacementName);
            if (interfaceBlockType.isArray()) {
                this->write("[");
                this->write(std::to_string(ibIdx));
                this->write("]");
            }
            this->write(" = ");

            const Type* fieldType = field->fType;
            const Layout* fieldLayout = &field->fLayout;

            // For array fields, open an array constructor and switch `fieldType` to the element
            // type; non-array fields are treated as a single element.
            int numArrayElements;
            if (info.fIsArray) {
                this->write(to_wgsl_type(fContext, *fieldType, fieldLayout));
                this->write("(");
                numArrayElements = fieldType->columns();
                fieldType = &fieldType->componentType();
            } else {
                numArrayElements = 1;
            }

            auto arraySeparator = String::Separator();
            for (int arrayIdx = 0; arrayIdx < numArrayElements; arrayIdx++) {
                this->write(arraySeparator());

                // Build the path to the source value: `<instance>[<ibIdx>].<field>[<arrayIdx>].e`
                std::string fieldName{instanceName};
                if (interfaceBlockType.isArray()) {
                    fieldName += '[';
                    fieldName += std::to_string(ibIdx);
                    fieldName += ']';
                }
                fieldName += '.';
                fieldName += this->assembleName(field->fName);

                if (info.fIsArray) {
                    fieldName += '[';
                    fieldName += std::to_string(arrayIdx);
                    fieldName += "].e";
                }

                if (info.fIsMatrix) {
                    // Rebuild the matrix from the `.c[<column>].r` vectors of the padded struct.
                    this->write(to_wgsl_type(fContext, *fieldType, fieldLayout));
                    this->write("(");
                    int numColumns = fieldType->columns();
                    auto matrixSeparator = String::Separator();
                    for (int column = 0; column < numColumns; column++) {
                        this->write(matrixSeparator());
                        this->write(fieldName);
                        this->write(".c[");
                        this->write(std::to_string(column));
                        this->write("].r");
                    }
                    this->write(")");
                } else {
                    this->write(fieldName);
                }
            }

            // Close the array constructor opened above.
            if (info.fIsArray) {
                this->write(")");
            }

            this->writeLine(";");
        }
    }

    --fIndentation;
    this->writeLine("}");
}
1542
1543
// Gathers program-wide information before any code is written: the program requirements
// (`fRequirements`) and the lists of pipeline input and output variables.
void WGSLCodeGenerator::preprocessProgram() {
    fRequirements = resolve_program_requirements(&fProgram);
    // Variables flagged `in` go to fPipelineInputs; variables flagged `out` go to
    // fPipelineOutputs.
    collect_pipeline_io_vars(&fProgram, &fPipelineInputs, ModifierFlag::kIn);
    collect_pipeline_io_vars(&fProgram, &fPipelineOutputs, ModifierFlag::kOut);
}
1549
write(std::string_view s)1550 void WGSLCodeGenerator::write(std::string_view s) {
1551 if (s.empty()) {
1552 return;
1553 }
1554 #if defined(SK_DEBUG) || defined(SKSL_STANDALONE)
1555 if (fAtLineStart) {
1556 for (int i = 0; i < fIndentation; i++) {
1557 fOut->writeText(" ");
1558 }
1559 }
1560 #endif
1561 fOut->writeText(std::string(s).c_str());
1562 fAtLineStart = false;
1563 }
1564
// Writes `s` followed by a newline, and records that the next write begins a new line.
void WGSLCodeGenerator::writeLine(std::string_view s) {
    this->write(s);
    // The newline bypasses write() so that it is never preceded by indentation.
    fOut->writeText("\n");
    fAtLineStart = true;
}
1570
finishLine()1571 void WGSLCodeGenerator::finishLine() {
1572 if (!fAtLineStart) {
1573 this->writeLine();
1574 }
1575 }
1576
assembleName(std::string_view name)1577 std::string WGSLCodeGenerator::assembleName(std::string_view name) {
1578 if (name.empty()) {
1579 // WGSL doesn't allow anonymous function parameters.
1580 return "_skAnonymous" + std::to_string(fScratchCount++);
1581 }
1582 // Add `R_` before reserved names to avoid any potential reserved-word conflict.
1583 return (skstd::starts_with(name, "_sk") ||
1584 skstd::starts_with(name, "R_") ||
1585 is_reserved_word(name))
1586 ? std::string("R_") + std::string(name)
1587 : std::string(name);
1588 }
1589
writeVariableDecl(const Layout & layout,const Type & type,std::string_view name,Delimiter delimiter)1590 void WGSLCodeGenerator::writeVariableDecl(const Layout& layout,
1591 const Type& type,
1592 std::string_view name,
1593 Delimiter delimiter) {
1594 this->write(this->assembleName(name));
1595 this->write(": " + to_wgsl_type(fContext, type, &layout));
1596 this->writeLine(delimiter_to_str(delimiter));
1597 }
1598
writePipelineIODeclaration(const Layout & layout,const Type & type,std::string_view name,Delimiter delimiter)1599 void WGSLCodeGenerator::writePipelineIODeclaration(const Layout& layout,
1600 const Type& type,
1601 std::string_view name,
1602 Delimiter delimiter) {
1603 // In WGSL, an entry-point IO parameter is "one of either a built-in value or assigned a
1604 // location". However, some SkSL declarations, specifically sk_FragColor, can contain both a
1605 // location and a builtin modifier. In addition, WGSL doesn't have a built-in equivalent for
1606 // sk_FragColor as it relies on the user-defined location for a render target.
1607 //
1608 // Instead of special-casing sk_FragColor, we just give higher precedence to a location modifier
1609 // if a declaration happens to both have a location and it's a built-in.
1610 //
1611 // Also see:
1612 // https://www.w3.org/TR/WGSL/#input-output-locations
1613 // https://www.w3.org/TR/WGSL/#attribute-location
1614 // https://www.w3.org/TR/WGSL/#builtin-inputs-outputs
1615 if (layout.fLocation >= 0) {
1616 this->writeUserDefinedIODecl(layout, type, name, delimiter);
1617 return;
1618 }
1619 if (layout.fBuiltin >= 0) {
1620 if (layout.fBuiltin == SK_POINTSIZE_BUILTIN) {
1621 // WebGPU does not support the point-size builtin, but we silently replace it with a
1622 // global variable when it is used, instead of reporting an error.
1623 return;
1624 }
1625 auto builtin = builtin_from_sksl_name(layout.fBuiltin);
1626 if (builtin.has_value()) {
1627 this->writeBuiltinIODecl(type, name, *builtin, delimiter);
1628 return;
1629 }
1630 }
1631 fContext.fErrors->error(Position(), "declaration '" + std::string(name) + "' is not supported");
1632 }
1633
writeUserDefinedIODecl(const Layout & layout,const Type & type,std::string_view name,Delimiter delimiter)1634 void WGSLCodeGenerator::writeUserDefinedIODecl(const Layout& layout,
1635 const Type& type,
1636 std::string_view name,
1637 Delimiter delimiter) {
1638 this->write("@location(" + std::to_string(layout.fLocation) + ") ");
1639
1640 // @blend_src is only allowed when doing dual-source blending, and only on color attachment 0.
1641 if (layout.fLocation == 0 && layout.fIndex >= 0 && fProgram.fInterface.fOutputSecondaryColor) {
1642 this->write("@blend_src(" + std::to_string(layout.fIndex) + ") ");
1643 }
1644
1645 // "User-defined IO of scalar or vector integer type must always be specified as
1646 // @interpolate(flat)" (see https://www.w3.org/TR/WGSL/#interpolation)
1647 if (type.isInteger() || (type.isVector() && type.componentType().isInteger())) {
1648 this->write("@interpolate(flat) ");
1649 }
1650
1651 this->writeVariableDecl(layout, type, name, delimiter);
1652 }
1653
// Writes a builtin IO declaration of the form
// `<wgsl_builtin_name(builtin)> <name>: <wgsl_builtin_type(builtin)><delimiter>` followed by a
// newline. The declared type comes from the builtin; the `type` parameter is not used here.
void WGSLCodeGenerator::writeBuiltinIODecl(const Type& type,
                                           std::string_view name,
                                           Builtin builtin,
                                           Delimiter delimiter) {
    // NOTE(review): wgsl_builtin_name() presumably yields the `@builtin(...)` attribute text;
    // confirm against its definition.
    this->write(wgsl_builtin_name(builtin));
    this->write(" ");
    this->write(this->assembleName(name));
    this->write(": ");
    this->write(wgsl_builtin_type(builtin));
    this->writeLine(delimiter_to_str(delimiter));
}
1665
writeFunction(const FunctionDefinition & f)1666 void WGSLCodeGenerator::writeFunction(const FunctionDefinition& f) {
1667 const FunctionDeclaration& decl = f.declaration();
1668 fHasUnconditionalReturn = false;
1669 fConditionalScopeDepth = 0;
1670
1671 SkASSERT(!fAtFunctionScope);
1672 fAtFunctionScope = true;
1673
1674 // WGSL parameters are immutable and are considered as taking no storage, but SkSL parameters
1675 // are real variables. To work around this, we make var-based copies of parameters. It's
1676 // wasteful to make a copy of every single parameter--even if the compiler can eventually
1677 // optimize them all away, that takes time and generates bloated code. So, we only make
1678 // parameter copies if the variable is actually written-to.
1679 STArray<32, bool> paramNeedsDedicatedStorage;
1680 paramNeedsDedicatedStorage.push_back_n(decl.parameters().size(), true);
1681
1682 for (size_t index = 0; index < decl.parameters().size(); ++index) {
1683 const Variable& param = *decl.parameters()[index];
1684 if (param.type().isOpaque() || param.name().empty()) {
1685 // Opaque-typed or anonymous parameters don't need dedicated storage.
1686 paramNeedsDedicatedStorage[index] = false;
1687 continue;
1688 }
1689
1690 const ProgramUsage::VariableCounts counts = fProgram.fUsage->get(param);
1691 if ((param.modifierFlags() & ModifierFlag::kOut) || counts.fWrite == 0) {
1692 // Variables which are never written-to don't need dedicated storage.
1693 // Out-parameters are passed as pointers; the pointer itself is never modified, so
1694 // it doesn't need dedicated storage.
1695 paramNeedsDedicatedStorage[index] = false;
1696 }
1697 }
1698
1699 this->writeFunctionDeclaration(decl, paramNeedsDedicatedStorage);
1700 this->writeLine(" {");
1701 ++fIndentation;
1702
1703 // The parameters were given generic names like `_skParam1`, because WGSL parameters don't have
1704 // storage and are immutable. If mutability is required, we create variables here; otherwise, we
1705 // create properly-named `let` aliases.
1706 for (size_t index = 0; index < decl.parameters().size(); ++index) {
1707 if (paramNeedsDedicatedStorage[index]) {
1708 const Variable& param = *decl.parameters()[index];
1709 this->write("var ");
1710 this->write(this->assembleName(param.mangledName()));
1711 this->write(" = _skParam");
1712 this->write(std::to_string(index));
1713 this->writeLine(";");
1714 }
1715 }
1716
1717 this->writeBlock(f.body()->as<Block>());
1718
1719 // If fConditionalScopeDepth isn't zero, we have an unbalanced +1 or -1 when updating the depth.
1720 SkASSERT(fConditionalScopeDepth == 0);
1721 if (!fHasUnconditionalReturn && !f.declaration().returnType().isVoid()) {
1722 this->write("return ");
1723 this->write(to_wgsl_type(fContext, f.declaration().returnType()));
1724 this->writeLine("();");
1725 }
1726
1727 --fIndentation;
1728 this->writeLine("}");
1729
1730 SkASSERT(fAtFunctionScope);
1731 fAtFunctionScope = false;
1732 }
1733
writeFunctionDeclaration(const FunctionDeclaration & decl,SkSpan<const bool> paramNeedsDedicatedStorage)1734 void WGSLCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& decl,
1735 SkSpan<const bool> paramNeedsDedicatedStorage) {
1736 this->write("fn ");
1737 if (decl.isMain()) {
1738 this->write("_skslMain(");
1739 } else {
1740 this->write(this->assembleName(decl.mangledName()));
1741 this->write("(");
1742 }
1743 auto separator = SkSL::String::Separator();
1744 if (this->writeFunctionDependencyParams(decl)) {
1745 separator(); // update the separator as parameters have been written
1746 }
1747 for (size_t index = 0; index < decl.parameters().size(); ++index) {
1748 this->write(separator());
1749
1750 const Variable& param = *decl.parameters()[index];
1751 if (param.type().isOpaque()) {
1752 SkASSERT(!paramNeedsDedicatedStorage[index]);
1753 if (param.type().isSampler()) {
1754 // Create parameters for both the texture and associated sampler.
1755 this->write(param.name());
1756 this->write(kTextureSuffix);
1757 this->write(": texture_2d<f32>, ");
1758 this->write(param.name());
1759 this->write(kSamplerSuffix);
1760 this->write(": sampler");
1761 } else {
1762 // Create a parameter for the opaque object.
1763 this->write(param.name());
1764 this->write(": ");
1765 this->write(to_wgsl_type(fContext, param.type(), ¶m.layout()));
1766 }
1767 } else {
1768 if (paramNeedsDedicatedStorage[index] || param.name().empty()) {
1769 // Create an unnamed parameter. If the parameter needs dedicated storage, it will
1770 // later be assigned a `var` in the function body. (If it's anonymous, a var isn't
1771 // needed.)
1772 this->write("_skParam");
1773 this->write(std::to_string(index));
1774 } else {
1775 // Use the name directly from the SkSL program.
1776 this->write(this->assembleName(param.name()));
1777 }
1778 this->write(": ");
1779 if (param.modifierFlags() & ModifierFlag::kOut) {
1780 // Declare an "out" function parameter as a pointer.
1781 this->write(to_ptr_type(fContext, param.type(), ¶m.layout()));
1782 } else {
1783 this->write(to_wgsl_type(fContext, param.type(), ¶m.layout()));
1784 }
1785 }
1786 }
1787 this->write(")");
1788 if (!decl.returnType().isVoid()) {
1789 this->write(" -> ");
1790 this->write(to_wgsl_type(fContext, decl.returnType()));
1791 }
1792 }
1793
writeEntryPoint(const FunctionDefinition & main)1794 void WGSLCodeGenerator::writeEntryPoint(const FunctionDefinition& main) {
1795 SkASSERT(main.declaration().isMain());
1796 const ProgramKind programKind = fProgram.fConfig->fKind;
1797
1798 #if defined(SKSL_STANDALONE)
1799 if (ProgramConfig::IsRuntimeShader(programKind)) {
1800 // Synthesize a basic entrypoint which just calls straight through to main.
1801 // This is only used by skslc and just needs to pass the WGSL validator; Skia won't ever
1802 // emit functions like this.
1803 this->writeLine("@fragment fn main(@location(0) _coords: vec2<f32>) -> "
1804 "@location(0) vec4<f32> {");
1805 ++fIndentation;
1806 this->writeLine("return _skslMain(_coords);");
1807 --fIndentation;
1808 this->writeLine("}");
1809 return;
1810 }
1811 #endif
1812
1813 // The input and output parameters for a vertex/fragment stage entry point function have the
1814 // FSIn/FSOut/VSIn/VSOut/CSIn struct types that have been synthesized in generateCode(). An
1815 // entrypoint always has a predictable signature and acts as a trampoline to the user-defined
1816 // main function.
1817 if (ProgramConfig::IsVertex(programKind)) {
1818 this->write("@vertex");
1819 } else if (ProgramConfig::IsFragment(programKind)) {
1820 this->write("@fragment");
1821 } else if (ProgramConfig::IsCompute(programKind)) {
1822 this->write("@compute @workgroup_size(");
1823 this->write(std::to_string(fLocalSizeX));
1824 this->write(", ");
1825 this->write(std::to_string(fLocalSizeY));
1826 this->write(", ");
1827 this->write(std::to_string(fLocalSizeZ));
1828 this->write(")");
1829 } else {
1830 fContext.fErrors->error(Position(), "program kind not supported");
1831 return;
1832 }
1833
1834 this->write(" fn main(");
1835 // The stage input struct is a parameter passed to main().
1836 if (this->needsStageInputStruct()) {
1837 this->write("_stageIn: ");
1838 this->write(pipeline_struct_prefix(programKind));
1839 this->write("In");
1840 }
1841 // The stage output struct is returned from main().
1842 if (this->needsStageOutputStruct()) {
1843 this->write(") -> ");
1844 this->write(pipeline_struct_prefix(programKind));
1845 this->writeLine("Out {");
1846 } else {
1847 this->writeLine(") {");
1848 }
1849 // Initialize polyfilled matrix uniforms if any were used.
1850 fIndentation++;
1851 for (const auto& [field, info] : fFieldPolyfillMap) {
1852 if (info.fWasAccessed) {
1853 this->writeLine("_skInitializePolyfilledUniforms();");
1854 break;
1855 }
1856 }
1857 // Declare the stage output struct.
1858 if (this->needsStageOutputStruct()) {
1859 this->write("var _stageOut: ");
1860 this->write(pipeline_struct_prefix(programKind));
1861 this->writeLine("Out;");
1862 }
1863
1864 #if defined(SKSL_STANDALONE)
1865 // We are compiling a Runtime Effect as a fragment shader, for testing purposes. We assign the
1866 // result from _skslMain into sk_FragColor if the user-defined main returns a color. This
1867 // doesn't actually matter, but it is more indicative of what a real program would do.
1868 // `addImplicitFragColorWrite` from Transform::FindAndDeclareBuiltinVariables has already
1869 // injected sk_FragColor into our stage outputs even if it wasn't explicitly referenced.
1870 if (ProgramConfig::IsFragment(programKind)) {
1871 if (main.declaration().returnType().matches(*fContext.fTypes.fHalf4)) {
1872 this->write("_stageOut.sk_FragColor = ");
1873 }
1874 }
1875 #endif
1876
1877 // Generate a function call to the user-defined main.
1878 this->write("_skslMain(");
1879 auto separator = SkSL::String::Separator();
1880 WGSLFunctionDependencies* deps = fRequirements.fDependencies.find(&main.declaration());
1881 if (deps) {
1882 if (*deps & WGSLFunctionDependency::kPipelineInputs) {
1883 this->write(separator());
1884 this->write("_stageIn");
1885 }
1886 if (*deps & WGSLFunctionDependency::kPipelineOutputs) {
1887 this->write(separator());
1888 this->write("&_stageOut");
1889 }
1890 }
1891
1892 #if defined(SKSL_STANDALONE)
1893 if (const Variable* v = main.declaration().getMainCoordsParameter()) {
1894 // We are compiling a Runtime Effect as a fragment shader, for testing purposes.
1895 // We need to synthesize a coordinates parameter, but the coordinates don't matter.
1896 SkASSERT(ProgramConfig::IsFragment(programKind));
1897 const Type& type = v->type();
1898 if (!type.matches(*fContext.fTypes.fFloat2)) {
1899 fContext.fErrors->error(main.fPosition, "main function has unsupported parameter: " +
1900 type.description());
1901 return;
1902 }
1903 this->write(separator());
1904 this->write("/*fragcoord*/ vec2<f32>()");
1905 }
1906 #endif
1907
1908 this->writeLine(");");
1909
1910 if (this->needsStageOutputStruct()) {
1911 // Return the stage output struct.
1912 this->writeLine("return _stageOut;");
1913 }
1914
1915 fIndentation--;
1916 this->writeLine("}");
1917 }
1918
writeStatement(const Statement & s)1919 void WGSLCodeGenerator::writeStatement(const Statement& s) {
1920 switch (s.kind()) {
1921 case Statement::Kind::kBlock:
1922 this->writeBlock(s.as<Block>());
1923 break;
1924 case Statement::Kind::kBreak:
1925 this->writeLine("break;");
1926 break;
1927 case Statement::Kind::kContinue:
1928 this->writeLine("continue;");
1929 break;
1930 case Statement::Kind::kDiscard:
1931 this->writeLine("discard;");
1932 break;
1933 case Statement::Kind::kDo:
1934 this->writeDoStatement(s.as<DoStatement>());
1935 break;
1936 case Statement::Kind::kExpression:
1937 this->writeExpressionStatement(*s.as<ExpressionStatement>().expression());
1938 break;
1939 case Statement::Kind::kFor:
1940 this->writeForStatement(s.as<ForStatement>());
1941 break;
1942 case Statement::Kind::kIf:
1943 this->writeIfStatement(s.as<IfStatement>());
1944 break;
1945 case Statement::Kind::kNop:
1946 this->writeLine(";");
1947 break;
1948 case Statement::Kind::kReturn:
1949 this->writeReturnStatement(s.as<ReturnStatement>());
1950 break;
1951 case Statement::Kind::kSwitch:
1952 this->writeSwitchStatement(s.as<SwitchStatement>());
1953 break;
1954 case Statement::Kind::kSwitchCase:
1955 SkDEBUGFAIL("switch-case statements should only be present inside a switch");
1956 break;
1957 case Statement::Kind::kVarDeclaration:
1958 this->writeVarDeclaration(s.as<VarDeclaration>());
1959 break;
1960 }
1961 }
1962
writeStatements(const StatementArray & statements)1963 void WGSLCodeGenerator::writeStatements(const StatementArray& statements) {
1964 for (const auto& s : statements) {
1965 if (!s->isEmpty()) {
1966 this->writeStatement(*s);
1967 this->finishLine();
1968 }
1969 }
1970 }
1971
writeBlock(const Block & b)1972 void WGSLCodeGenerator::writeBlock(const Block& b) {
1973 // Write scope markers if this block is a scope, or if the block is empty (since we need to emit
1974 // something here to make the code valid).
1975 bool isScope = b.isScope() || b.isEmpty();
1976 if (isScope) {
1977 this->writeLine("{");
1978 fIndentation++;
1979 }
1980 this->writeStatements(b.children());
1981 if (isScope) {
1982 fIndentation--;
1983 this->writeLine("}");
1984 }
1985 }
1986
writeExpressionStatement(const Expression & expr)1987 void WGSLCodeGenerator::writeExpressionStatement(const Expression& expr) {
1988 // Any expression-related side effects must be emitted as separate statements when
1989 // `assembleExpression` is called.
1990 // The final result of the expression will be a variable, let-reference, or an expression with
1991 // no side effects (`foo + bar`). Discarding this result is safe, as the program never uses it.
1992 (void)this->assembleExpression(expr, Precedence::kStatement);
1993 }
1994
writeDoStatement(const DoStatement & s)1995 void WGSLCodeGenerator::writeDoStatement(const DoStatement& s) {
1996 // Generate a loop structure like this:
1997 // loop {
1998 // body-statement;
1999 // continuing {
2000 // break if inverted-test-expression;
2001 // }
2002 // }
2003
2004 ++fConditionalScopeDepth;
2005
2006 std::unique_ptr<Expression> invertedTestExpr = PrefixExpression::Make(
2007 fContext, s.test()->fPosition, OperatorKind::LOGICALNOT, s.test()->clone());
2008
2009 this->writeLine("loop {");
2010 fIndentation++;
2011 this->writeStatement(*s.statement());
2012 this->finishLine();
2013
2014 this->writeLine("continuing {");
2015 fIndentation++;
2016 std::string breakIfExpr = this->assembleExpression(*invertedTestExpr, Precedence::kExpression);
2017 this->write("break if ");
2018 this->write(breakIfExpr);
2019 this->writeLine(";");
2020 fIndentation--;
2021 this->writeLine("}");
2022 fIndentation--;
2023 this->writeLine("}");
2024
2025 --fConditionalScopeDepth;
2026 }
2027
writeForStatement(const ForStatement & s)2028 void WGSLCodeGenerator::writeForStatement(const ForStatement& s) {
2029 // Generate a loop structure wrapped in an extra scope:
2030 // {
2031 // initializer-statement;
2032 // loop;
2033 // }
2034 // The outer scope is necessary to prevent the initializer-variable from leaking out into the
2035 // rest of the code. In practice, the generated code actually tends to be scoped even more
2036 // deeply, as the body-statement almost always contributes an extra block.
2037
2038 ++fConditionalScopeDepth;
2039
2040 if (s.initializer()) {
2041 this->writeLine("{");
2042 fIndentation++;
2043 this->writeStatement(*s.initializer());
2044 this->writeLine();
2045 }
2046
2047 this->writeLine("loop {");
2048 fIndentation++;
2049
2050 if (s.unrollInfo()) {
2051 if (s.unrollInfo()->fCount <= 0) {
2052 // Loops which are known to never execute don't need to be emitted at all.
2053 // (However, the front end should have already replaced this loop with a Nop.)
2054 } else {
2055 // Loops which are known to execute at least once can use this form:
2056 //
2057 // loop {
2058 // body-statement;
2059 // continuing {
2060 // next-expression;
2061 // break if inverted-test-expression;
2062 // }
2063 // }
2064
2065 this->writeStatement(*s.statement());
2066 this->finishLine();
2067 this->writeLine("continuing {");
2068 ++fIndentation;
2069
2070 if (s.next()) {
2071 this->writeExpressionStatement(*s.next());
2072 this->finishLine();
2073 }
2074
2075 if (s.test()) {
2076 std::unique_ptr<Expression> invertedTestExpr = PrefixExpression::Make(
2077 fContext, s.test()->fPosition, OperatorKind::LOGICALNOT, s.test()->clone());
2078
2079 std::string breakIfExpr =
2080 this->assembleExpression(*invertedTestExpr, Precedence::kExpression);
2081 this->write("break if ");
2082 this->write(breakIfExpr);
2083 this->writeLine(";");
2084 }
2085
2086 --fIndentation;
2087 this->writeLine("}");
2088 }
2089 } else {
2090 // Loops without a known execution count are emitted in this form:
2091 //
2092 // loop {
2093 // if test-expression {
2094 // body-statement;
2095 // } else {
2096 // break;
2097 // }
2098 // continuing {
2099 // next-expression;
2100 // }
2101 // }
2102
2103 if (s.test()) {
2104 std::string testExpr = this->assembleExpression(*s.test(), Precedence::kExpression);
2105 this->write("if ");
2106 this->write(testExpr);
2107 this->writeLine(" {");
2108
2109 fIndentation++;
2110 this->writeStatement(*s.statement());
2111 this->finishLine();
2112 fIndentation--;
2113
2114 this->writeLine("} else {");
2115
2116 fIndentation++;
2117 this->writeLine("break;");
2118 fIndentation--;
2119
2120 this->writeLine("}");
2121 }
2122 else {
2123 this->writeStatement(*s.statement());
2124 this->finishLine();
2125 }
2126
2127 if (s.next()) {
2128 this->writeLine("continuing {");
2129 fIndentation++;
2130 this->writeExpressionStatement(*s.next());
2131 this->finishLine();
2132 fIndentation--;
2133 this->writeLine("}");
2134 }
2135 }
2136
2137 // This matches an open-brace at the top of the loop.
2138 fIndentation--;
2139 this->writeLine("}");
2140
2141 if (s.initializer()) {
2142 // This matches an open-brace before the initializer-statement.
2143 fIndentation--;
2144 this->writeLine("}");
2145 }
2146
2147 --fConditionalScopeDepth;
2148 }
2149
writeIfStatement(const IfStatement & s)2150 void WGSLCodeGenerator::writeIfStatement(const IfStatement& s) {
2151 ++fConditionalScopeDepth;
2152
2153 std::string testExpr = this->assembleExpression(*s.test(), Precedence::kExpression);
2154 this->write("if ");
2155 this->write(testExpr);
2156 this->writeLine(" {");
2157 fIndentation++;
2158 this->writeStatement(*s.ifTrue());
2159 this->finishLine();
2160 fIndentation--;
2161 if (s.ifFalse()) {
2162 this->writeLine("} else {");
2163 fIndentation++;
2164 this->writeStatement(*s.ifFalse());
2165 this->finishLine();
2166 fIndentation--;
2167 }
2168 this->writeLine("}");
2169
2170 --fConditionalScopeDepth;
2171 }
2172
writeReturnStatement(const ReturnStatement & s)2173 void WGSLCodeGenerator::writeReturnStatement(const ReturnStatement& s) {
2174 fHasUnconditionalReturn |= (fConditionalScopeDepth == 0);
2175
2176 std::string expr = s.expression()
2177 ? this->assembleExpression(*s.expression(), Precedence::kExpression)
2178 : std::string();
2179 this->write("return ");
2180 this->write(expr);
2181 this->write(";");
2182 }
2183
writeSwitchCaseList(SkSpan<const SwitchCase * const> cases)2184 void WGSLCodeGenerator::writeSwitchCaseList(SkSpan<const SwitchCase* const> cases) {
2185 auto separator = SkSL::String::Separator();
2186 for (const SwitchCase* const sc : cases) {
2187 this->write(separator());
2188 if (sc->isDefault()) {
2189 this->write("default");
2190 } else {
2191 this->write(std::to_string(sc->value()));
2192 }
2193 }
2194 }
2195
writeSwitchCases(SkSpan<const SwitchCase * const> cases)2196 void WGSLCodeGenerator::writeSwitchCases(SkSpan<const SwitchCase* const> cases) {
2197 if (!cases.empty()) {
2198 // Only the last switch-case should have a non-empty statement.
2199 SkASSERT(std::all_of(cases.begin(), std::prev(cases.end()), [](const SwitchCase* sc) {
2200 return sc->statement()->isEmpty();
2201 }));
2202
2203 // Emit the cases in a comma-separated list.
2204 this->write("case ");
2205 this->writeSwitchCaseList(cases);
2206 this->writeLine(" {");
2207 ++fIndentation;
2208
2209 // Emit the switch-case body.
2210 this->writeStatement(*cases.back()->statement());
2211 this->finishLine();
2212
2213 --fIndentation;
2214 this->writeLine("}");
2215 }
2216 }
2217
writeEmulatedSwitchFallthroughCases(SkSpan<const SwitchCase * const> cases,std::string_view switchValue)2218 void WGSLCodeGenerator::writeEmulatedSwitchFallthroughCases(SkSpan<const SwitchCase* const> cases,
2219 std::string_view switchValue) {
2220 // There's no need for fallthrough handling unless we actually have multiple case blocks.
2221 if (cases.size() < 2) {
2222 this->writeSwitchCases(cases);
2223 return;
2224 }
2225
2226 // Match against the entire case group.
2227 this->write("case ");
2228 this->writeSwitchCaseList(cases);
2229 this->writeLine(" {");
2230 ++fIndentation;
2231
2232 std::string fallthroughVar = this->writeScratchVar(*fContext.fTypes.fBool, "false");
2233 const size_t secondToLastCaseIndex = cases.size() - 2;
2234 const size_t lastCaseIndex = cases.size() - 1;
2235
2236 for (size_t index = 0; index < cases.size(); ++index) {
2237 const SwitchCase& sc = *cases[index];
2238 if (index < lastCaseIndex) {
2239 // The default case must come last in SkSL, and this case isn't the last one, so it
2240 // can't possibly be the default.
2241 SkASSERT(!sc.isDefault());
2242
2243 this->write("if ");
2244 if (index > 0) {
2245 this->write(fallthroughVar);
2246 this->write(" || ");
2247 }
2248 this->write(switchValue);
2249 this->write(" == ");
2250 this->write(std::to_string(sc.value()));
2251 this->writeLine(" {");
2252 fIndentation++;
2253
2254 // We write the entire case-block statement here, and then set `switchFallthrough`
2255 // to 1. If the case-block had a break statement in it, we break out of the outer
2256 // for-loop entirely, meaning the `switchFallthrough` assignment never occurs, nor
2257 // does any code after it inside the switch. We've forbidden `continue` statements
2258 // inside switch case-blocks entirely, so we don't need to consider their effect on
2259 // control flow; see the Finalizer in FunctionDefinition::Convert.
2260 this->writeStatement(*sc.statement());
2261 this->finishLine();
2262
2263 if (index < secondToLastCaseIndex) {
2264 // Set a variable to indicate falling through to the next block. The very last
2265 // case-block is reached by process of elimination and doesn't need this
2266 // variable, so we don't actually need to set it if we are on the second-to-last
2267 // case block.
2268 this->write(fallthroughVar);
2269 this->write(" = true; ");
2270 }
2271 this->writeLine("// fallthrough");
2272
2273 fIndentation--;
2274 this->writeLine("}");
2275 } else {
2276 // This is the final case. Since it's always last, we can just dump in the code.
2277 // (If we didn't match any of the other values, we must have matched this one by
2278 // process of elimination. If we did match one of the other values, we either hit a
2279 // `break` statement earlier--and won't get this far--or we're falling through.)
2280 this->writeStatement(*sc.statement());
2281 this->finishLine();
2282 }
2283 }
2284
2285 --fIndentation;
2286 this->writeLine("}");
2287 }
2288
writeSwitchStatement(const SwitchStatement & s)2289 void WGSLCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
2290 // WGSL supports the `switch` statement in a limited capacity. A default case must always be
2291 // specified. Each switch-case must be scoped inside braces. Fallthrough is not supported; a
2292 // trailing break is implied at the end of each switch-case block. (Explicit breaks are also
2293 // allowed.) One minor improvement over a traditional switch is that switch-cases take a list
2294 // of values to match, instead of a single value:
2295 // case 1, 2 { foo(); }
2296 // case 3, default { bar(); }
2297 //
2298 // We will use the native WGSL switch statement for any switch-cases in the SkSL which can be
2299 // made to conform to these limitations. The remaining cases which cannot conform will be
2300 // emulated with if-else blocks (similar to our GLSL ES2 switch-statement emulation path). This
2301 // should give us good performance in the common case, since most switches naturally conform.
2302
2303 // First, let's emit the switch itself.
2304 std::string valueExpr = this->writeNontrivialScratchLet(*s.value(), Precedence::kExpression);
2305 this->write("switch ");
2306 this->write(valueExpr);
2307 this->writeLine(" {");
2308 ++fIndentation;
2309
2310 // Now let's go through the switch-cases, and emit the ones that don't fall through.
2311 TArray<const SwitchCase*> nativeCases;
2312 TArray<const SwitchCase*> fallthroughCases;
2313 bool previousCaseFellThrough = false;
2314 bool foundNativeDefault = false;
2315 [[maybe_unused]] bool foundFallthroughDefault = false;
2316
2317 const int lastSwitchCaseIdx = s.cases().size() - 1;
2318 for (int index = 0; index <= lastSwitchCaseIdx; ++index) {
2319 const SwitchCase& sc = s.cases()[index]->as<SwitchCase>();
2320
2321 if (sc.statement()->isEmpty()) {
2322 // This is a `case X:` that immediately falls through to the next case.
2323 // If we aren't already falling through, we can handle this via a comma-separated list.
2324 if (previousCaseFellThrough) {
2325 fallthroughCases.push_back(&sc);
2326 foundFallthroughDefault |= sc.isDefault();
2327 } else {
2328 nativeCases.push_back(&sc);
2329 foundNativeDefault |= sc.isDefault();
2330 }
2331 continue;
2332 }
2333
2334 if (index == lastSwitchCaseIdx || Analysis::SwitchCaseContainsUnconditionalExit(sc)) {
2335 // This is a `case X:` that never falls through.
2336 if (previousCaseFellThrough) {
2337 // Because the previous cases fell through, we can't use a native switch-case here.
2338 fallthroughCases.push_back(&sc);
2339 foundFallthroughDefault |= sc.isDefault();
2340
2341 this->writeEmulatedSwitchFallthroughCases(fallthroughCases, valueExpr);
2342 fallthroughCases.clear();
2343
2344 // Fortunately, we're no longer falling through blocks, so we might be able to use a
2345 // native switch-case list again.
2346 previousCaseFellThrough = false;
2347 } else {
2348 // Emit a native switch-case block with a comma-separated case list.
2349 nativeCases.push_back(&sc);
2350 foundNativeDefault |= sc.isDefault();
2351
2352 this->writeSwitchCases(nativeCases);
2353 nativeCases.clear();
2354 }
2355 continue;
2356 }
2357
2358 // This case falls through, so it will need to be handled via emulation.
2359 // If we have put together a collection of "native" cases (cases that fall through with no
2360 // actual case-body), we will need to slide them over into the fallthrough-case list.
2361 fallthroughCases.push_back_n(nativeCases.size(), nativeCases.data());
2362 nativeCases.clear();
2363
2364 fallthroughCases.push_back(&sc);
2365 foundFallthroughDefault |= sc.isDefault();
2366 previousCaseFellThrough = true;
2367 }
2368
2369 // Finish out the remaining switch-cases.
2370 this->writeSwitchCases(nativeCases);
2371 nativeCases.clear();
2372
2373 this->writeEmulatedSwitchFallthroughCases(fallthroughCases, valueExpr);
2374 fallthroughCases.clear();
2375
2376 // WGSL requires a default case.
2377 if (!foundNativeDefault && !foundFallthroughDefault) {
2378 this->writeLine("case default {}");
2379 }
2380
2381 --fIndentation;
2382 this->writeLine("}");
2383 }
2384
writeVarDeclaration(const VarDeclaration & varDecl)2385 void WGSLCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl) {
2386 std::string initialValue =
2387 varDecl.value() ? this->assembleExpression(*varDecl.value(), Precedence::kAssignment)
2388 : std::string();
2389
2390 if (varDecl.var()->modifierFlags().isConst()) {
2391 // Use `const` at global scope, or if the value is a compile-time constant.
2392 SkASSERTF(varDecl.value(), "a constant variable must specify a value");
2393 this->write((!fAtFunctionScope || Analysis::IsCompileTimeConstant(*varDecl.value()))
2394 ? "const "
2395 : "let ");
2396 } else {
2397 this->write("var ");
2398 }
2399 this->write(this->assembleName(varDecl.var()->mangledName()));
2400 this->write(": ");
2401 this->write(to_wgsl_type(fContext, varDecl.var()->type(), &varDecl.var()->layout()));
2402
2403 if (varDecl.value()) {
2404 this->write(" = ");
2405 this->write(initialValue);
2406 }
2407
2408 this->write(";");
2409 }
2410
makeLValue(const Expression & e)2411 std::unique_ptr<WGSLCodeGenerator::LValue> WGSLCodeGenerator::makeLValue(const Expression& e) {
2412 if (e.is<VariableReference>()) {
2413 return std::make_unique<PointerLValue>(
2414 this->variableReferenceNameForLValue(e.as<VariableReference>()));
2415 }
2416 if (e.is<FieldAccess>()) {
2417 return std::make_unique<PointerLValue>(this->assembleFieldAccess(e.as<FieldAccess>()));
2418 }
2419 if (e.is<IndexExpression>()) {
2420 const IndexExpression& idx = e.as<IndexExpression>();
2421 if (idx.base()->type().isVector()) {
2422 // Rewrite indexed-swizzle accesses like `myVec.zyx[i]` into an index onto `myVec`.
2423 if (std::unique_ptr<Expression> rewrite =
2424 Transform::RewriteIndexedSwizzle(fContext, idx)) {
2425 return std::make_unique<VectorComponentLValue>(
2426 this->assembleExpression(*rewrite, Precedence::kAssignment));
2427 } else {
2428 return std::make_unique<VectorComponentLValue>(this->assembleIndexExpression(idx));
2429 }
2430 } else {
2431 return std::make_unique<PointerLValue>(this->assembleIndexExpression(idx));
2432 }
2433 }
2434 if (e.is<Swizzle>()) {
2435 const Swizzle& swizzle = e.as<Swizzle>();
2436 if (swizzle.components().size() == 1) {
2437 return std::make_unique<VectorComponentLValue>(this->assembleSwizzle(swizzle));
2438 } else {
2439 return std::make_unique<SwizzleLValue>(
2440 fContext,
2441 this->assembleExpression(*swizzle.base(), Precedence::kAssignment),
2442 swizzle.base()->type(),
2443 swizzle.components());
2444 }
2445 }
2446
2447 fContext.fErrors->error(e.fPosition, "unsupported lvalue type");
2448 return nullptr;
2449 }
2450
assembleExpression(const Expression & e,Precedence parentPrecedence)2451 std::string WGSLCodeGenerator::assembleExpression(const Expression& e,
2452 Precedence parentPrecedence) {
2453 switch (e.kind()) {
2454 case Expression::Kind::kBinary:
2455 return this->assembleBinaryExpression(e.as<BinaryExpression>(), parentPrecedence);
2456
2457 case Expression::Kind::kConstructorCompound:
2458 return this->assembleConstructorCompound(e.as<ConstructorCompound>());
2459
2460 case Expression::Kind::kConstructorArrayCast:
2461 // This is a no-op, since WGSL 1.0 doesn't have any concept of precision qualifiers.
2462 // When we add support for f16, this will need to copy the array contents.
2463 return this->assembleExpression(*e.as<ConstructorArrayCast>().argument(),
2464 parentPrecedence);
2465
2466 case Expression::Kind::kConstructorArray:
2467 case Expression::Kind::kConstructorCompoundCast:
2468 case Expression::Kind::kConstructorScalarCast:
2469 case Expression::Kind::kConstructorSplat:
2470 case Expression::Kind::kConstructorStruct:
2471 return this->assembleAnyConstructor(e.asAnyConstructor());
2472
2473 case Expression::Kind::kConstructorDiagonalMatrix:
2474 return this->assembleConstructorDiagonalMatrix(e.as<ConstructorDiagonalMatrix>());
2475
2476 case Expression::Kind::kConstructorMatrixResize:
2477 return this->assembleConstructorMatrixResize(e.as<ConstructorMatrixResize>());
2478
2479 case Expression::Kind::kEmpty:
2480 return "false";
2481
2482 case Expression::Kind::kFieldAccess:
2483 return this->assembleFieldAccess(e.as<FieldAccess>());
2484
2485 case Expression::Kind::kFunctionCall:
2486 return this->assembleFunctionCall(e.as<FunctionCall>(), parentPrecedence);
2487
2488 case Expression::Kind::kIndex:
2489 return this->assembleIndexExpression(e.as<IndexExpression>());
2490
2491 case Expression::Kind::kLiteral:
2492 return this->assembleLiteral(e.as<Literal>());
2493
2494 case Expression::Kind::kPrefix:
2495 return this->assemblePrefixExpression(e.as<PrefixExpression>(), parentPrecedence);
2496
2497 case Expression::Kind::kPostfix:
2498 return this->assemblePostfixExpression(e.as<PostfixExpression>(), parentPrecedence);
2499
2500 case Expression::Kind::kSetting:
2501 return this->assembleExpression(*e.as<Setting>().toLiteral(fCaps), parentPrecedence);
2502
2503 case Expression::Kind::kSwizzle:
2504 return this->assembleSwizzle(e.as<Swizzle>());
2505
2506 case Expression::Kind::kTernary:
2507 return this->assembleTernaryExpression(e.as<TernaryExpression>(), parentPrecedence);
2508
2509 case Expression::Kind::kVariableReference:
2510 return this->assembleVariableReference(e.as<VariableReference>());
2511
2512 default:
2513 SkDEBUGFAILF("unsupported expression:\n%s", e.description().c_str());
2514 return {};
2515 }
2516 }
2517
is_nontrivial_expression(const Expression & expr)2518 static bool is_nontrivial_expression(const Expression& expr) {
2519 // We consider a "trivial expression" one which we can repeat multiple times in the output
2520 // without being dangerous or spammy. We avoid emitting temporary variables for very trivial
2521 // expressions: literals, unadorned variable references, or constant vectors.
2522 if (expr.is<VariableReference>() || expr.is<Literal>()) {
2523 // Variables and literals are trivial; adding a let-declaration won't simplify anything.
2524 return false;
2525 }
2526 if (expr.type().isVector() && Analysis::IsConstantExpression(expr)) {
2527 // Compile-time constant vectors are also considered trivial; they're short and sweet.
2528 return false;
2529 }
2530 return true;
2531 }
2532
binary_op_is_ambiguous_in_wgsl(Operator op)2533 static bool binary_op_is_ambiguous_in_wgsl(Operator op) {
2534 // WGSL always requires parentheses for some operators which are deemed to be ambiguous.
2535 // (8.19. Operator Precedence and Associativity)
2536 switch (op.kind()) {
2537 case OperatorKind::LOGICALOR:
2538 case OperatorKind::LOGICALAND:
2539 case OperatorKind::BITWISEOR:
2540 case OperatorKind::BITWISEAND:
2541 case OperatorKind::BITWISEXOR:
2542 case OperatorKind::SHL:
2543 case OperatorKind::SHR:
2544 case OperatorKind::LT:
2545 case OperatorKind::GT:
2546 case OperatorKind::LTEQ:
2547 case OperatorKind::GTEQ:
2548 return true;
2549
2550 default:
2551 return false;
2552 }
2553 }
2554
binaryOpNeedsComponentwiseMatrixPolyfill(const Type & left,const Type & right,Operator op)2555 bool WGSLCodeGenerator::binaryOpNeedsComponentwiseMatrixPolyfill(const Type& left,
2556 const Type& right,
2557 Operator op) {
2558 switch (op.kind()) {
2559 case OperatorKind::SLASH:
2560 // WGSL does not natively support componentwise matrix-op-matrix for division.
2561 if (left.isMatrix() && right.isMatrix()) {
2562 return true;
2563 }
2564 [[fallthrough]];
2565
2566 case OperatorKind::PLUS:
2567 case OperatorKind::MINUS:
2568 // WGSL does not natively support componentwise matrix-op-scalar or scalar-op-matrix for
2569 // addition, subtraction or division.
2570 return (left.isMatrix() && right.isScalar()) ||
2571 (left.isScalar() && right.isMatrix());
2572
2573 default:
2574 return false;
2575 }
2576 }
2577
assembleBinaryExpression(const BinaryExpression & b,Precedence parentPrecedence)2578 std::string WGSLCodeGenerator::assembleBinaryExpression(const BinaryExpression& b,
2579 Precedence parentPrecedence) {
2580 return this->assembleBinaryExpression(*b.left(), b.getOperator(), *b.right(), b.type(),
2581 parentPrecedence);
2582 }
2583
assembleBinaryExpression(const Expression & left,Operator op,const Expression & right,const Type & resultType,Precedence parentPrecedence)2584 std::string WGSLCodeGenerator::assembleBinaryExpression(const Expression& left,
2585 Operator op,
2586 const Expression& right,
2587 const Type& resultType,
2588 Precedence parentPrecedence) {
2589 // If the operator is && or ||, we need to handle short-circuiting properly. Specifically, we
2590 // sometimes need to emit extra statements to paper over functionality that WGSL lacks, like
2591 // assignment in the middle of an expression. We need to guard those extra statements, to ensure
2592 // that they don't occur if the expression evaluation is short-circuited. Converting the
2593 // expression into an if-else block keeps the short-circuit property intact even when extra
2594 // statements are involved.
2595 // If the RHS doesn't have any side effects, then it's safe to just leave the expression as-is,
2596 // since we know that any possible extra statements are non-side-effecting.
2597 std::string expr;
2598 if (op.kind() == OperatorKind::LOGICALAND && Analysis::HasSideEffects(right)) {
2599 // Converts `left_expression && right_expression` into the following block:
2600
2601 // var _skTemp1: bool;
2602 // [[ prepare left_expression ]]
2603 // if left_expression {
2604 // [[ prepare right_expression ]]
2605 // _skTemp1 = right_expression;
2606 // } else {
2607 // _skTemp1 = false;
2608 // }
2609
2610 expr = this->writeScratchVar(resultType);
2611
2612 std::string leftExpr = this->assembleExpression(left, Precedence::kExpression);
2613 this->write("if ");
2614 this->write(leftExpr);
2615 this->writeLine(" {");
2616
2617 ++fIndentation;
2618 std::string rightExpr = this->assembleExpression(right, Precedence::kAssignment);
2619 this->write(expr);
2620 this->write(" = ");
2621 this->write(rightExpr);
2622 this->writeLine(";");
2623 --fIndentation;
2624
2625 this->writeLine("} else {");
2626
2627 ++fIndentation;
2628 this->write(expr);
2629 this->writeLine(" = false;");
2630 --fIndentation;
2631
2632 this->writeLine("}");
2633 return expr;
2634 }
2635
2636 if (op.kind() == OperatorKind::LOGICALOR && Analysis::HasSideEffects(right)) {
2637 // Converts `left_expression || right_expression` into the following block:
2638
2639 // var _skTemp1: bool;
2640 // [[ prepare left_expression ]]
2641 // if left_expression {
2642 // _skTemp1 = true;
2643 // } else {
2644 // [[ prepare right_expression ]]
2645 // _skTemp1 = right_expression;
2646 // }
2647
2648 expr = this->writeScratchVar(resultType);
2649
2650 std::string leftExpr = this->assembleExpression(left, Precedence::kExpression);
2651 this->write("if ");
2652 this->write(leftExpr);
2653 this->writeLine(" {");
2654
2655 ++fIndentation;
2656 this->write(expr);
2657 this->writeLine(" = true;");
2658 --fIndentation;
2659
2660 this->writeLine("} else {");
2661
2662 ++fIndentation;
2663 std::string rightExpr = this->assembleExpression(right, Precedence::kAssignment);
2664 this->write(expr);
2665 this->write(" = ");
2666 this->write(rightExpr);
2667 this->writeLine(";");
2668 --fIndentation;
2669
2670 this->writeLine("}");
2671 return expr;
2672 }
2673
2674 // Handle comma-expressions.
2675 if (op.kind() == OperatorKind::COMMA) {
2676 // The result from the left-expression is ignored, but its side effects must occur.
2677 this->assembleExpression(left, Precedence::kStatement);
2678
2679 // Evaluate the right side normally.
2680 return this->assembleExpression(right, parentPrecedence);
2681 }
2682
2683 // Handle assignment-expressions.
2684 if (op.isAssignment()) {
2685 std::unique_ptr<LValue> lvalue = this->makeLValue(left);
2686 if (!lvalue) {
2687 return "";
2688 }
2689
2690 if (op.kind() == OperatorKind::EQ) {
2691 // Evaluate the right-hand side of simple assignment (`a = b` --> `b`).
2692 expr = this->assembleExpression(right, Precedence::kAssignment);
2693 } else {
2694 // Evaluate the right-hand side of compound-assignment (`a += b` --> `a + b`).
2695 op = op.removeAssignment();
2696
2697 std::string lhs = lvalue->load();
2698 std::string rhs = this->assembleExpression(right, op.getBinaryPrecedence());
2699
2700 if (this->binaryOpNeedsComponentwiseMatrixPolyfill(left.type(), right.type(), op)) {
2701 if (is_nontrivial_expression(right)) {
2702 rhs = this->writeScratchLet(rhs);
2703 }
2704
2705 expr = this->assembleComponentwiseMatrixBinary(left.type(), right.type(),
2706 lhs, rhs, op);
2707 } else {
2708 expr = lhs + operator_name(op) + rhs;
2709 }
2710 }
2711
2712 // Emit the assignment statement (`a = a + b`).
2713 this->writeLine(lvalue->store(expr));
2714
2715 // Return the lvalue (`a`) as the result, since the value might be used by the caller.
2716 return lvalue->load();
2717 }
2718
2719 if (op.isEquality()) {
2720 return this->assembleEqualityExpression(left, right, op, parentPrecedence);
2721 }
2722
2723 Precedence precedence = op.getBinaryPrecedence();
2724 bool needParens = precedence >= parentPrecedence;
2725 if (binary_op_is_ambiguous_in_wgsl(op)) {
2726 precedence = Precedence::kParentheses;
2727 }
2728 if (needParens) {
2729 expr = "(";
2730 }
2731
2732 // If we are emitting `constant + constant`, this generally indicates that the values could not
2733 // be constant-folded. This happens when the values overflow or become nan. WGSL will refuse to
2734 // compile such expressions, as WGSL 1.0 has no infinity/nan support. However, the WGSL
2735 // compile-time check can be dodged by putting one side into a let-variable. This technically
2736 // gives us an indeterminate result, but the vast majority of backends will just calculate an
2737 // infinity or nan here, as we would expect. (skia:14385)
2738 bool bothSidesConstant = ConstantFolder::GetConstantValueOrNull(left) &&
2739 ConstantFolder::GetConstantValueOrNull(right);
2740
2741 std::string lhs = this->assembleExpression(left, precedence);
2742 std::string rhs = this->assembleExpression(right, precedence);
2743
2744 if (this->binaryOpNeedsComponentwiseMatrixPolyfill(left.type(), right.type(), op)) {
2745 if (bothSidesConstant || is_nontrivial_expression(left)) {
2746 lhs = this->writeScratchLet(lhs);
2747 }
2748 if (is_nontrivial_expression(right)) {
2749 rhs = this->writeScratchLet(rhs);
2750 }
2751
2752 expr += this->assembleComponentwiseMatrixBinary(left.type(), right.type(), lhs, rhs, op);
2753 } else {
2754 if (bothSidesConstant) {
2755 lhs = this->writeScratchLet(lhs);
2756 }
2757
2758 expr += lhs + operator_name(op) + rhs;
2759 }
2760
2761 if (needParens) {
2762 expr += ')';
2763 }
2764
2765 return expr;
2766 }
2767
assembleFieldAccess(const FieldAccess & f)2768 std::string WGSLCodeGenerator::assembleFieldAccess(const FieldAccess& f) {
2769 const Field* field = &f.base()->type().fields()[f.fieldIndex()];
2770 std::string expr;
2771
2772 if (FieldPolyfillInfo* polyfillInfo = fFieldPolyfillMap.find(field)) {
2773 // We found a matrix uniform. We are required to pass some matrix uniforms as array vectors,
2774 // since the std140 layout for a matrix assumes 4-column vectors for each row, and WGSL
2775 // tightly packs 2-column matrices. When emitting code, we replace the field-access
2776 // expression with a global variable which holds an unpacked version of the uniform.
2777 polyfillInfo->fWasAccessed = true;
2778
2779 // The polyfill can either be based directly onto a uniform in an interface block, or it
2780 // might be based on an index-expression onto a uniform if the interface block is arrayed.
2781 const Expression* base = f.base().get();
2782 const IndexExpression* indexExpr = nullptr;
2783 if (base->is<IndexExpression>()) {
2784 indexExpr = &base->as<IndexExpression>();
2785 base = indexExpr->base().get();
2786 }
2787
2788 SkASSERT(base->is<VariableReference>());
2789 expr = polyfillInfo->fReplacementName;
2790
2791 // If we had an index expression, we must append the index.
2792 if (indexExpr) {
2793 expr += '[';
2794 expr += this->assembleExpression(*indexExpr->index(), Precedence::kSequence);
2795 expr += ']';
2796 }
2797 return expr;
2798 }
2799
2800 switch (f.ownerKind()) {
2801 case FieldAccess::OwnerKind::kDefault:
2802 expr = this->assembleExpression(*f.base(), Precedence::kPostfix) + '.';
2803 break;
2804
2805 case FieldAccess::OwnerKind::kAnonymousInterfaceBlock:
2806 if (f.base()->is<VariableReference>() &&
2807 field->fLayout.fBuiltin != SK_POINTSIZE_BUILTIN) {
2808 expr = this->variablePrefix(*f.base()->as<VariableReference>().variable());
2809 }
2810 break;
2811 }
2812
2813 expr += this->assembleName(field->fName);
2814 return expr;
2815 }
2816
all_arguments_constant(const ExpressionArray & arguments)2817 static bool all_arguments_constant(const ExpressionArray& arguments) {
2818 // Returns true if all arguments in the ExpressionArray are compile-time constants. If we are
2819 // calling an intrinsic and all of its inputs are constant, but we didn't constant-fold it, this
2820 // generally indicates that constant-folding resulted in an infinity or nan. The WGSL compiler
2821 // will reject such an expression with a compile-time error. We can dodge the error, taking on
2822 // the risk of indeterminate behavior instead, by replacing one of the constant values with a
2823 // scratch let-variable. (skia:14385)
2824 for (const std::unique_ptr<Expression>& arg : arguments) {
2825 if (!ConstantFolder::GetConstantValueOrNull(*arg)) {
2826 return false;
2827 }
2828 }
2829 return true;
2830 }
2831
assembleSimpleIntrinsic(std::string_view intrinsicName,const FunctionCall & call)2832 std::string WGSLCodeGenerator::assembleSimpleIntrinsic(std::string_view intrinsicName,
2833 const FunctionCall& call) {
2834 // Invoke the function, passing each function argument.
2835 std::string expr = std::string(intrinsicName);
2836 expr.push_back('(');
2837 const ExpressionArray& args = call.arguments();
2838 auto separator = SkSL::String::Separator();
2839 bool allConstant = all_arguments_constant(call.arguments());
2840 for (int index = 0; index < args.size(); ++index) {
2841 expr += separator();
2842
2843 std::string argument = this->assembleExpression(*args[index], Precedence::kSequence);
2844 if (args[index]->type().isAtomic()) {
2845 // WGSL passes atomic values to intrinsics as pointers.
2846 expr += '&';
2847 expr += argument;
2848 } else if (allConstant && index == 0) {
2849 // We can use a scratch-let for argument 0 to dodge WGSL overflow errors. (skia:14385)
2850 expr += this->writeScratchLet(argument);
2851 } else {
2852 expr += argument;
2853 }
2854 }
2855 expr.push_back(')');
2856
2857 if (call.type().isVoid()) {
2858 this->write(expr);
2859 this->writeLine(";");
2860 return std::string();
2861 } else {
2862 return this->writeScratchLet(expr);
2863 }
2864 }
2865
assembleVectorizedIntrinsic(std::string_view intrinsicName,const FunctionCall & call)2866 std::string WGSLCodeGenerator::assembleVectorizedIntrinsic(std::string_view intrinsicName,
2867 const FunctionCall& call) {
2868 SkASSERT(!call.type().isVoid());
2869
2870 // Invoke the function, passing each function argument.
2871 std::string expr = std::string(intrinsicName);
2872 expr.push_back('(');
2873
2874 auto separator = SkSL::String::Separator();
2875 const ExpressionArray& args = call.arguments();
2876 bool returnsVector = call.type().isVector();
2877 bool allConstant = all_arguments_constant(call.arguments());
2878 for (int index = 0; index < args.size(); ++index) {
2879 expr += separator();
2880
2881 bool vectorize = returnsVector && args[index]->type().isScalar();
2882 if (vectorize) {
2883 expr += to_wgsl_type(fContext, call.type());
2884 expr.push_back('(');
2885 }
2886
2887 // We can use a scratch-let for argument 0 to dodge WGSL overflow errors. (skia:14385)
2888 std::string argument = this->assembleExpression(*args[index], Precedence::kSequence);
2889 expr += (allConstant && index == 0) ? this->writeScratchLet(argument)
2890 : argument;
2891 if (vectorize) {
2892 expr.push_back(')');
2893 }
2894 }
2895 expr.push_back(')');
2896
2897 return this->writeScratchLet(expr);
2898 }
2899
assembleUnaryOpIntrinsic(Operator op,const FunctionCall & call,Precedence parentPrecedence)2900 std::string WGSLCodeGenerator::assembleUnaryOpIntrinsic(Operator op,
2901 const FunctionCall& call,
2902 Precedence parentPrecedence) {
2903 SkASSERT(!call.type().isVoid());
2904
2905 bool needParens = Precedence::kPrefix >= parentPrecedence;
2906
2907 std::string expr;
2908 if (needParens) {
2909 expr.push_back('(');
2910 }
2911
2912 expr += operator_name(op);
2913 expr += this->assembleExpression(*call.arguments()[0], Precedence::kPrefix);
2914
2915 if (needParens) {
2916 expr.push_back(')');
2917 }
2918
2919 return expr;
2920 }
2921
assembleBinaryOpIntrinsic(Operator op,const FunctionCall & call,Precedence parentPrecedence)2922 std::string WGSLCodeGenerator::assembleBinaryOpIntrinsic(Operator op,
2923 const FunctionCall& call,
2924 Precedence parentPrecedence) {
2925 SkASSERT(!call.type().isVoid());
2926
2927 Precedence precedence = op.getBinaryPrecedence();
2928 bool needParens = precedence >= parentPrecedence ||
2929 binary_op_is_ambiguous_in_wgsl(op);
2930 std::string expr;
2931 if (needParens) {
2932 expr.push_back('(');
2933 }
2934
2935 // We can use a scratch-let for argument 0 to dodge WGSL overflow errors. (skia:14385)
2936 std::string argument = this->assembleExpression(*call.arguments()[0], precedence);
2937 expr += all_arguments_constant(call.arguments()) ? this->writeScratchLet(argument)
2938 : argument;
2939 expr += operator_name(op);
2940 expr += this->assembleExpression(*call.arguments()[1], precedence);
2941
2942 if (needParens) {
2943 expr.push_back(')');
2944 }
2945
2946 return expr;
2947 }
2948
2949 // Rewrite a WGSL intrinsic of the form "intrinsicName(in) -> struct" to the SkSL's
2950 // "intrinsicName(in, outField) -> returnField", where outField and returnField are the names of the
2951 // fields in the struct returned by the WGSL intrinsic.
assembleOutAssignedIntrinsic(std::string_view intrinsicName,std::string_view returnField,std::string_view outField,const FunctionCall & call)2952 std::string WGSLCodeGenerator::assembleOutAssignedIntrinsic(std::string_view intrinsicName,
2953 std::string_view returnField,
2954 std::string_view outField,
2955 const FunctionCall& call) {
2956 SkASSERT(call.type().componentType().isNumber());
2957 SkASSERT(call.arguments().size() == 2);
2958 SkASSERT(call.function().parameters()[1]->modifierFlags() & ModifierFlag::kOut);
2959
2960 std::string expr = std::string(intrinsicName);
2961 expr += "(";
2962
2963 // Invoke the intrinsic with the first parameter. Use a scratch-let if argument is a constant
2964 // to dodge WGSL overflow errors. (skia:14385)
2965 std::string argument = this->assembleExpression(*call.arguments()[0], Precedence::kSequence);
2966 expr += ConstantFolder::GetConstantValueOrNull(*call.arguments()[0])
2967 ? this->writeScratchLet(argument) : argument;
2968 expr += ")";
2969 // In WGSL the intrinsic returns a struct; assign it to a local so that its fields can be
2970 // accessed multiple times.
2971 expr = this->writeScratchLet(expr);
2972 expr += ".";
2973
2974 // Store the outField of `expr` to the intended "out" argument
2975 std::unique_ptr<LValue> lvalue = this->makeLValue(*call.arguments()[1]);
2976 if (!lvalue) {
2977 return "";
2978 }
2979 std::string outValue = expr;
2980 outValue += outField;
2981 this->writeLine(lvalue->store(outValue));
2982
2983 // And return the expression accessing the returnField.
2984 expr += returnField;
2985 return expr;
2986 }
2987
assemblePartialSampleCall(std::string_view functionName,const Expression & sampler,const Expression & coords)2988 std::string WGSLCodeGenerator::assemblePartialSampleCall(std::string_view functionName,
2989 const Expression& sampler,
2990 const Expression& coords) {
2991 // This function returns `functionName(inSampler_texture, inSampler_sampler, coords` without a
2992 // terminating comma or close-parenthesis. This allows the caller to add more arguments as
2993 // needed.
2994 SkASSERT(sampler.type().typeKind() == Type::TypeKind::kSampler);
2995 std::string expr = std::string(functionName) + '(';
2996 expr += this->assembleExpression(sampler, Precedence::kSequence);
2997 expr += kTextureSuffix;
2998 expr += ", ";
2999 expr += this->assembleExpression(sampler, Precedence::kSequence);
3000 expr += kSamplerSuffix;
3001 expr += ", ";
3002
3003 // Compute the sample coordinates, dividing out the Z if a vec3 was provided.
3004 SkASSERT(coords.type().isVector());
3005 if (coords.type().columns() == 3) {
3006 // The coordinates were passed as a vec3, so we need to emit `coords.xy / coords.z`.
3007 std::string vec3Coords = this->writeScratchLet(coords, Precedence::kMultiplicative);
3008 expr += vec3Coords + ".xy / " + vec3Coords + ".z";
3009 } else {
3010 // The coordinates should be a plain vec2; emit the expression as-is.
3011 SkASSERT(coords.type().columns() == 2);
3012 expr += this->assembleExpression(coords, Precedence::kSequence);
3013 }
3014
3015 return expr;
3016 }
3017
assembleComponentwiseMatrixBinary(const Type & leftType,const Type & rightType,const std::string & left,const std::string & right,Operator op)3018 std::string WGSLCodeGenerator::assembleComponentwiseMatrixBinary(const Type& leftType,
3019 const Type& rightType,
3020 const std::string& left,
3021 const std::string& right,
3022 Operator op) {
3023 bool leftIsMatrix = leftType.isMatrix();
3024 bool rightIsMatrix = rightType.isMatrix();
3025 const Type& matrixType = leftIsMatrix ? leftType : rightType;
3026
3027 std::string expr = to_wgsl_type(fContext, matrixType) + '(';
3028 auto separator = String::Separator();
3029 int columns = matrixType.columns();
3030 for (int c = 0; c < columns; ++c) {
3031 expr += separator();
3032 expr += left;
3033 if (leftIsMatrix) {
3034 expr += '[';
3035 expr += std::to_string(c);
3036 expr += ']';
3037 }
3038 expr += op.operatorName();
3039 expr += right;
3040 if (rightIsMatrix) {
3041 expr += '[';
3042 expr += std::to_string(c);
3043 expr += ']';
3044 }
3045 }
3046 return expr + ')';
3047 }
3048
assembleIntrinsicCall(const FunctionCall & call,IntrinsicKind kind,Precedence parentPrecedence)3049 std::string WGSLCodeGenerator::assembleIntrinsicCall(const FunctionCall& call,
3050 IntrinsicKind kind,
3051 Precedence parentPrecedence) {
3052 // Be careful: WGSL 1.0 will reject any intrinsic calls which can be constant-evaluated to
3053 // infinity or nan with a compile error. If all arguments to an intrinsic are compile-time
3054 // constants (`all_arguments_constant`), it is safest to copy one argument into a scratch-let so
3055 // that the call will be seen as runtime-evaluated, which defuses the overflow checks.
3056 // Don't worry; a competent driver should still optimize it away.
3057
3058 const ExpressionArray& arguments = call.arguments();
3059 switch (kind) {
3060 case k_atan_IntrinsicKind: {
3061 const char* name = (arguments.size() == 1) ? "atan" : "atan2";
3062 return this->assembleSimpleIntrinsic(name, call);
3063 }
3064 case k_dFdx_IntrinsicKind:
3065 return this->assembleSimpleIntrinsic("dpdx", call);
3066
3067 case k_dFdy_IntrinsicKind:
3068 // TODO(b/294274678): apply RTFlip here
3069 return this->assembleSimpleIntrinsic("dpdy", call);
3070
3071 case k_dot_IntrinsicKind: {
3072 if (arguments[0]->type().isScalar()) {
3073 return this->assembleBinaryOpIntrinsic(OperatorKind::STAR, call, parentPrecedence);
3074 }
3075 return this->assembleSimpleIntrinsic("dot", call);
3076 }
3077 case k_equal_IntrinsicKind:
3078 return this->assembleBinaryOpIntrinsic(OperatorKind::EQEQ, call, parentPrecedence);
3079
3080 case k_faceforward_IntrinsicKind: {
3081 if (arguments[0]->type().isScalar()) {
3082 // select(-N, N, (I * Nref) < 0)
3083 std::string N = this->writeNontrivialScratchLet(*arguments[0],
3084 Precedence::kAssignment);
3085 return this->writeScratchLet(
3086 "select(-" + N + ", " + N + ", " +
3087 this->assembleBinaryExpression(*arguments[1],
3088 OperatorKind::STAR,
3089 *arguments[2],
3090 arguments[1]->type(),
3091 Precedence::kRelational) +
3092 " < 0)");
3093 }
3094 return this->assembleSimpleIntrinsic("faceForward", call);
3095 }
3096 case k_frexp_IntrinsicKind:
3097 // SkSL frexp is "$genType fract = frexp($genType, out $genIType exp)" whereas WGSL
3098 // returns a struct with no out param: "let [fract, exp] = frexp($genType)".
3099 return this->assembleOutAssignedIntrinsic("frexp", "fract", "exp", call);
3100
3101 case k_greaterThan_IntrinsicKind:
3102 return this->assembleBinaryOpIntrinsic(OperatorKind::GT, call, parentPrecedence);
3103
3104 case k_greaterThanEqual_IntrinsicKind:
3105 return this->assembleBinaryOpIntrinsic(OperatorKind::GTEQ, call, parentPrecedence);
3106
3107 case k_inverse_IntrinsicKind:
3108 return this->assembleInversePolyfill(call);
3109
3110 case k_inversesqrt_IntrinsicKind:
3111 return this->assembleSimpleIntrinsic("inverseSqrt", call);
3112
3113 case k_lessThan_IntrinsicKind:
3114 return this->assembleBinaryOpIntrinsic(OperatorKind::LT, call, parentPrecedence);
3115
3116 case k_lessThanEqual_IntrinsicKind:
3117 return this->assembleBinaryOpIntrinsic(OperatorKind::LTEQ, call, parentPrecedence);
3118
3119 case k_matrixCompMult_IntrinsicKind: {
3120 // We use a scratch-let for arg0 to avoid the potential for WGSL overflow. (skia:14385)
3121 std::string arg0 = all_arguments_constant(arguments)
3122 ? this->writeScratchLet(*arguments[0], Precedence::kPostfix)
3123 : this->writeNontrivialScratchLet(*arguments[0], Precedence::kPostfix);
3124 std::string arg1 = this->writeNontrivialScratchLet(*arguments[1], Precedence::kPostfix);
3125 return this->writeScratchLet(
3126 this->assembleComponentwiseMatrixBinary(arguments[0]->type(),
3127 arguments[1]->type(),
3128 arg0,
3129 arg1,
3130 OperatorKind::STAR));
3131 }
3132 case k_mix_IntrinsicKind: {
3133 const char* name = arguments[2]->type().componentType().isBoolean() ? "select" : "mix";
3134 return this->assembleVectorizedIntrinsic(name, call);
3135 }
3136 case k_mod_IntrinsicKind: {
3137 // WGSL has no intrinsic equivalent to `mod`. Synthesize `x - y * floor(x / y)`.
3138 // We can use a scratch-let on one side to dodge WGSL overflow errors. In practice, I
3139 // can't find any values of x or y which would overflow, but it can't hurt. (skia:14385)
3140 std::string arg0 = all_arguments_constant(arguments)
3141 ? this->writeScratchLet(*arguments[0], Precedence::kAdditive)
3142 : this->writeNontrivialScratchLet(*arguments[0], Precedence::kAdditive);
3143 std::string arg1 = this->writeNontrivialScratchLet(*arguments[1],
3144 Precedence::kAdditive);
3145 return this->writeScratchLet(arg0 + " - " + arg1 + " * floor(" +
3146 arg0 + " / " + arg1 + ")");
3147 }
3148
3149 case k_modf_IntrinsicKind:
3150 // SkSL modf is "$genType fract = modf($genType, out $genType whole)" whereas WGSL
3151 // returns a struct with no out param: "let [fract, whole] = modf($genType)".
3152 return this->assembleOutAssignedIntrinsic("modf", "fract", "whole", call);
3153
3154 case k_normalize_IntrinsicKind: {
3155 const char* name = arguments[0]->type().isScalar() ? "sign" : "normalize";
3156 return this->assembleSimpleIntrinsic(name, call);
3157 }
3158 case k_not_IntrinsicKind:
3159 return this->assembleUnaryOpIntrinsic(OperatorKind::LOGICALNOT, call, parentPrecedence);
3160
3161 case k_notEqual_IntrinsicKind:
3162 return this->assembleBinaryOpIntrinsic(OperatorKind::NEQ, call, parentPrecedence);
3163
3164 case k_packHalf2x16_IntrinsicKind:
3165 return this->assembleSimpleIntrinsic("pack2x16float", call);
3166
3167 case k_packSnorm2x16_IntrinsicKind:
3168 return this->assembleSimpleIntrinsic("pack2x16snorm", call);
3169
3170 case k_packSnorm4x8_IntrinsicKind:
3171 return this->assembleSimpleIntrinsic("pack4x8snorm", call);
3172
3173 case k_packUnorm2x16_IntrinsicKind:
3174 return this->assembleSimpleIntrinsic("pack2x16unorm", call);
3175
3176 case k_packUnorm4x8_IntrinsicKind:
3177 return this->assembleSimpleIntrinsic("pack4x8unorm", call);
3178
3179 case k_reflect_IntrinsicKind:
3180 if (arguments[0]->type().isScalar()) {
3181 // I - 2 * N * I * N
3182 // We can use a scratch-let for N to dodge WGSL overflow errors. (skia:14385)
3183 std::string I = this->writeNontrivialScratchLet(*arguments[0],
3184 Precedence::kAdditive);
3185 std::string N = all_arguments_constant(arguments)
3186 ? this->writeScratchLet(*arguments[1], Precedence::kMultiplicative)
3187 : this->writeNontrivialScratchLet(*arguments[1], Precedence::kMultiplicative);
3188 return this->writeScratchLet(String::printf("%s - 2 * %s * %s * %s",
3189 I.c_str(), N.c_str(),
3190 I.c_str(), N.c_str()));
3191 }
3192 return this->assembleSimpleIntrinsic("reflect", call);
3193
3194 case k_refract_IntrinsicKind:
3195 if (arguments[0]->type().isScalar()) {
3196 // WGSL only implements refract for vectors; rather than reimplementing refract from
3197 // scratch, we can replace the call with `refract(float2(I,0), float2(N,0), eta).x`.
3198 std::string I = this->writeNontrivialScratchLet(*arguments[0],
3199 Precedence::kSequence);
3200 std::string N = this->writeNontrivialScratchLet(*arguments[1],
3201 Precedence::kSequence);
3202 // We can use a scratch-let for Eta to avoid WGSL overflow errors. (skia:14385)
3203 std::string Eta = all_arguments_constant(arguments)
3204 ? this->writeScratchLet(*arguments[2], Precedence::kSequence)
3205 : this->writeNontrivialScratchLet(*arguments[2], Precedence::kSequence);
3206 return this->writeScratchLet(
3207 String::printf("refract(vec2<%s>(%s, 0), vec2<%s>(%s, 0), %s).x",
3208 to_wgsl_type(fContext, arguments[0]->type()).c_str(),
3209 I.c_str(),
3210 to_wgsl_type(fContext, arguments[1]->type()).c_str(),
3211 N.c_str(),
3212 Eta.c_str()));
3213 }
3214 return this->assembleSimpleIntrinsic("refract", call);
3215
3216 case k_sample_IntrinsicKind: {
3217 // Determine if a bias argument was passed in.
3218 SkASSERT(arguments.size() == 2 || arguments.size() == 3);
3219 bool callIncludesBias = (arguments.size() == 3);
3220
3221 if (fProgram.fConfig->fSettings.fSharpenTextures || callIncludesBias) {
3222 // We need to supply a bias argument; this is a separate intrinsic in WGSL.
3223 std::string expr = this->assemblePartialSampleCall("textureSampleBias",
3224 *arguments[0],
3225 *arguments[1]);
3226 expr += ", ";
3227 if (callIncludesBias) {
3228 expr += this->assembleExpression(*arguments[2], Precedence::kAdditive) +
3229 " + ";
3230 }
3231 expr += skstd::to_string(fProgram.fConfig->fSettings.fSharpenTextures
3232 ? kSharpenTexturesBias
3233 : 0.0f);
3234 return expr + ')';
3235 }
3236
3237 // No bias is necessary, so we can call `textureSample` directly.
3238 return this->assemblePartialSampleCall("textureSample",
3239 *arguments[0],
3240 *arguments[1]) + ')';
3241 }
3242 case k_sampleLod_IntrinsicKind: {
3243 std::string expr = this->assemblePartialSampleCall("textureSampleLevel",
3244 *arguments[0],
3245 *arguments[1]);
3246 expr += ", " + this->assembleExpression(*arguments[2], Precedence::kSequence);
3247 return expr + ')';
3248 }
3249 case k_sampleGrad_IntrinsicKind: {
3250 std::string expr = this->assemblePartialSampleCall("textureSampleGrad",
3251 *arguments[0],
3252 *arguments[1]);
3253 expr += ", " + this->assembleExpression(*arguments[2], Precedence::kSequence);
3254 expr += ", " + this->assembleExpression(*arguments[3], Precedence::kSequence);
3255 return expr + ')';
3256 }
3257 case k_textureHeight_IntrinsicKind:
3258 return this->assembleSimpleIntrinsic("textureDimensions", call) + ".y";
3259
3260 case k_textureRead_IntrinsicKind: {
3261 // We need to inject an extra argument for the mip-level. We don't plan on using mipmaps
3262 // in our storage textures, so we can just pass zero.
3263 std::string tex = this->assembleExpression(*arguments[0], Precedence::kSequence);
3264 std::string pos = this->writeScratchLet(*arguments[1], Precedence::kSequence);
3265 return std::string("textureLoad(") + tex + ", " + pos + ", 0)";
3266 }
3267 case k_textureWidth_IntrinsicKind:
3268 return this->assembleSimpleIntrinsic("textureDimensions", call) + ".x";
3269
3270 case k_textureWrite_IntrinsicKind:
3271 return this->assembleSimpleIntrinsic("textureStore", call);
3272
3273 case k_unpackHalf2x16_IntrinsicKind:
3274 return this->assembleSimpleIntrinsic("unpack2x16float", call);
3275
3276 case k_unpackSnorm2x16_IntrinsicKind:
3277 return this->assembleSimpleIntrinsic("unpack2x16snorm", call);
3278
3279 case k_unpackSnorm4x8_IntrinsicKind:
3280 return this->assembleSimpleIntrinsic("unpack4x8snorm", call);
3281
3282 case k_unpackUnorm2x16_IntrinsicKind:
3283 return this->assembleSimpleIntrinsic("unpack2x16unorm", call);
3284
3285 case k_unpackUnorm4x8_IntrinsicKind:
3286 return this->assembleSimpleIntrinsic("unpack4x8unorm", call);
3287
3288 case k_clamp_IntrinsicKind:
3289 case k_max_IntrinsicKind:
3290 case k_min_IntrinsicKind:
3291 case k_smoothstep_IntrinsicKind:
3292 case k_step_IntrinsicKind:
3293 return this->assembleVectorizedIntrinsic(call.function().name(), call);
3294
3295 case k_abs_IntrinsicKind:
3296 case k_acos_IntrinsicKind:
3297 case k_all_IntrinsicKind:
3298 case k_any_IntrinsicKind:
3299 case k_asin_IntrinsicKind:
3300 case k_atomicAdd_IntrinsicKind:
3301 case k_atomicLoad_IntrinsicKind:
3302 case k_atomicStore_IntrinsicKind:
3303 case k_ceil_IntrinsicKind:
3304 case k_cos_IntrinsicKind:
3305 case k_cross_IntrinsicKind:
3306 case k_degrees_IntrinsicKind:
3307 case k_distance_IntrinsicKind:
3308 case k_exp_IntrinsicKind:
3309 case k_exp2_IntrinsicKind:
3310 case k_floor_IntrinsicKind:
3311 case k_fract_IntrinsicKind:
3312 case k_length_IntrinsicKind:
3313 case k_log_IntrinsicKind:
3314 case k_log2_IntrinsicKind:
3315 case k_radians_IntrinsicKind:
3316 case k_pow_IntrinsicKind:
3317 case k_saturate_IntrinsicKind:
3318 case k_sign_IntrinsicKind:
3319 case k_sin_IntrinsicKind:
3320 case k_sqrt_IntrinsicKind:
3321 case k_storageBarrier_IntrinsicKind:
3322 case k_tan_IntrinsicKind:
3323 case k_workgroupBarrier_IntrinsicKind:
3324 default:
3325 return this->assembleSimpleIntrinsic(call.function().name(), call);
3326 }
3327 }
3328
// WGSL source for `mat2_inverse`, the 2x2 polyfill of SkSL's `inverse` intrinsic. The result is
// the adjugate of `m` scaled by the reciprocal of its determinant.
static constexpr char kInverse2x2[] =
        "fn mat2_inverse(m: mat2x2<f32>) -> mat2x2<f32> {\n"
        "return mat2x2<f32>(m[1].y, -m[0].y, -m[1].x, m[0].x) * (1/determinant(m));\n"
        "}\n";
3334
// WGSL source for `mat3_inverse`, the 3x3 polyfill of SkSL's `inverse` intrinsic. The helper
// unpacks the matrix into scalars, computes the determinant from three cofactors, and returns
// the cofactor matrix scaled by the reciprocal of the determinant.
static constexpr char kInverse3x3[] =
        "fn mat3_inverse(m: mat3x3<f32>) -> mat3x3<f32> {\n"
        "let a00 = m[0].x; let a01 = m[0].y; let a02 = m[0].z;\n"
        "let a10 = m[1].x; let a11 = m[1].y; let a12 = m[1].z;\n"
        "let a20 = m[2].x; let a21 = m[2].y; let a22 = m[2].z;\n"
        "let b01 = a22*a11 - a12*a21;\n"
        "let b11 = -a22*a10 + a12*a20;\n"
        "let b21 = a21*a10 - a11*a20;\n"
        "let det = a00*b01 + a01*b11 + a02*b21;\n"
        "return mat3x3<f32>(b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),\n"
        "b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),\n"
        "b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);\n"
        "}\n";
3349
// WGSL source for `mat4_inverse`, the 4x4 polyfill of SkSL's `inverse` intrinsic. The helper
// unpacks the matrix into scalars, forms the twelve 2x2 sub-determinants b00..b11, derives the
// determinant from them, and returns the resulting matrix scaled by the reciprocal of the
// determinant.
static constexpr char kInverse4x4[] =
        "fn mat4_inverse(m: mat4x4<f32>) -> mat4x4<f32>{\n"
        "let a00 = m[0].x; let a01 = m[0].y; let a02 = m[0].z; let a03 = m[0].w;\n"
        "let a10 = m[1].x; let a11 = m[1].y; let a12 = m[1].z; let a13 = m[1].w;\n"
        "let a20 = m[2].x; let a21 = m[2].y; let a22 = m[2].z; let a23 = m[2].w;\n"
        "let a30 = m[3].x; let a31 = m[3].y; let a32 = m[3].z; let a33 = m[3].w;\n"
        "let b00 = a00*a11 - a01*a10;\n"
        "let b01 = a00*a12 - a02*a10;\n"
        "let b02 = a00*a13 - a03*a10;\n"
        "let b03 = a01*a12 - a02*a11;\n"
        "let b04 = a01*a13 - a03*a11;\n"
        "let b05 = a02*a13 - a03*a12;\n"
        "let b06 = a20*a31 - a21*a30;\n"
        "let b07 = a20*a32 - a22*a30;\n"
        "let b08 = a20*a33 - a23*a30;\n"
        "let b09 = a21*a32 - a22*a31;\n"
        "let b10 = a21*a33 - a23*a31;\n"
        "let b11 = a22*a33 - a23*a32;\n"
        "let det = b00*b11 - b01*b10 + b02*b09 + b03*b08 - b04*b07 + b05*b06;\n"
        "return mat4x4<f32>(a11*b11 - a12*b10 + a13*b09,\n"
        "a02*b10 - a01*b11 - a03*b09,\n"
        "a31*b05 - a32*b04 + a33*b03,\n"
        "a22*b04 - a21*b05 - a23*b03,\n"
        "a12*b08 - a10*b11 - a13*b07,\n"
        "a00*b11 - a02*b08 + a03*b07,\n"
        "a32*b02 - a30*b05 - a33*b01,\n"
        "a20*b05 - a22*b02 + a23*b01,\n"
        "a10*b10 - a11*b08 + a13*b06,\n"
        "a01*b08 - a00*b10 - a03*b06,\n"
        "a30*b04 - a31*b02 + a33*b00,\n"
        "a21*b02 - a20*b04 - a23*b00,\n"
        "a11*b07 - a10*b09 - a12*b06,\n"
        "a00*b09 - a01*b07 + a02*b06,\n"
        "a31*b01 - a30*b03 - a32*b00,\n"
        "a20*b03 - a21*b01 + a22*b00) * (1/det);\n"
        "}\n";
3387
assembleInversePolyfill(const FunctionCall & call)3388 std::string WGSLCodeGenerator::assembleInversePolyfill(const FunctionCall& call) {
3389 const ExpressionArray& arguments = call.arguments();
3390 const Type& type = arguments.front()->type();
3391
3392 // The `inverse` intrinsic should only accept a single-argument square matrix.
3393 // Once we implement f16 support, these polyfills will need to be updated to support `hmat`;
3394 // for the time being, all floats in WGSL are f32, so we don't need to worry about precision.
3395 SkASSERT(arguments.size() == 1);
3396 SkASSERT(type.isMatrix());
3397 SkASSERT(type.rows() == type.columns());
3398
3399 switch (type.slotCount()) {
3400 case 4:
3401 if (!fWrittenInverse2) {
3402 fWrittenInverse2 = true;
3403 fHeader.writeText(kInverse2x2);
3404 }
3405 return this->assembleSimpleIntrinsic("mat2_inverse", call);
3406
3407 case 9:
3408 if (!fWrittenInverse3) {
3409 fWrittenInverse3 = true;
3410 fHeader.writeText(kInverse3x3);
3411 }
3412 return this->assembleSimpleIntrinsic("mat3_inverse", call);
3413
3414 case 16:
3415 if (!fWrittenInverse4) {
3416 fWrittenInverse4 = true;
3417 fHeader.writeText(kInverse4x4);
3418 }
3419 return this->assembleSimpleIntrinsic("mat4_inverse", call);
3420
3421 default:
3422 // We only support square matrices.
3423 SkUNREACHABLE;
3424 }
3425 }
3426
assembleFunctionCall(const FunctionCall & call,Precedence parentPrecedence)3427 std::string WGSLCodeGenerator::assembleFunctionCall(const FunctionCall& call,
3428 Precedence parentPrecedence) {
3429 const FunctionDeclaration& func = call.function();
3430 std::string result;
3431
3432 // Many intrinsics need to be rewritten in WGSL.
3433 if (func.isIntrinsic()) {
3434 return this->assembleIntrinsicCall(call, func.intrinsicKind(), parentPrecedence);
3435 }
3436
3437 // We implement function out-parameters by declaring them as pointers. SkSL follows GLSL's
3438 // out-parameter semantics, in which out-parameters are only written back to the original
3439 // variable after the function's execution is complete (see
3440 // https://www.khronos.org/opengl/wiki/Core_Language_(GLSL)#Parameters).
3441 //
3442 // In addition, SkSL supports swizzles and array index expressions to be passed into
3443 // out-parameters; however, WGSL does not allow taking their address into a pointer.
3444 //
3445 // We support these by using LValues to create temporary copies and then pass pointers to the
3446 // copies. Once the function returns, we copy the values back to the LValue.
3447
3448 // First detect which arguments are passed to out-parameters.
3449 // TODO: rewrite this method in terms of LValues.
3450 const ExpressionArray& args = call.arguments();
3451 SkSpan<Variable* const> params = func.parameters();
3452 SkASSERT(SkToSizeT(args.size()) == params.size());
3453
3454 STArray<16, std::unique_ptr<LValue>> writeback;
3455 STArray<16, std::string> substituteArgument;
3456 writeback.reserve_exact(args.size());
3457 substituteArgument.reserve_exact(args.size());
3458
3459 for (int index = 0; index < args.size(); ++index) {
3460 if (params[index]->modifierFlags() & ModifierFlag::kOut) {
3461 std::unique_ptr<LValue> lvalue = this->makeLValue(*args[index]);
3462 if (params[index]->modifierFlags() & ModifierFlag::kIn) {
3463 // Load the lvalue's contents into the substitute argument.
3464 substituteArgument.push_back(this->writeScratchVar(args[index]->type(),
3465 lvalue->load()));
3466 } else {
3467 // Create a substitute argument, but leave it uninitialized.
3468 substituteArgument.push_back(this->writeScratchVar(args[index]->type()));
3469 }
3470 writeback.push_back(std::move(lvalue));
3471 } else {
3472 substituteArgument.push_back(std::string());
3473 writeback.push_back(nullptr);
3474 }
3475 }
3476
3477 std::string expr = this->assembleName(func.mangledName());
3478 expr.push_back('(');
3479 auto separator = SkSL::String::Separator();
3480
3481 if (std::string funcDepArgs = this->functionDependencyArgs(func); !funcDepArgs.empty()) {
3482 expr += funcDepArgs;
3483 separator();
3484 }
3485
3486 // Pass the function arguments, or any substitutes as needed.
3487 for (int index = 0; index < args.size(); ++index) {
3488 expr += separator();
3489 if (!substituteArgument[index].empty()) {
3490 // We need to take the address of the variable and pass it down as a pointer.
3491 expr += '&' + substituteArgument[index];
3492 } else if (args[index]->type().isSampler()) {
3493 // If the argument is a sampler, we need to pass the texture _and_ its associated
3494 // sampler. (Function parameter lists also convert sampler parameters into a matching
3495 // texture/sampler parameter pair.)
3496 expr += this->assembleExpression(*args[index], Precedence::kSequence);
3497 expr += kTextureSuffix;
3498 expr += ", ";
3499 expr += this->assembleExpression(*args[index], Precedence::kSequence);
3500 expr += kSamplerSuffix;
3501 } else {
3502 expr += this->assembleExpression(*args[index], Precedence::kSequence);
3503 }
3504 }
3505 expr += ')';
3506
3507 if (call.type().isVoid()) {
3508 // Making function calls that result in `void` is only valid in on the left side of a
3509 // comma-sequence, or in a top-level statement. Emit the function call as a top-level
3510 // statement and return an empty string, as the result will not be used.
3511 SkASSERT(parentPrecedence >= Precedence::kSequence);
3512 this->write(expr);
3513 this->writeLine(";");
3514 } else {
3515 result = this->writeScratchLet(expr);
3516 }
3517
3518 // Write the substitute arguments back into their lvalues.
3519 for (int index = 0; index < args.size(); ++index) {
3520 if (!substituteArgument[index].empty()) {
3521 this->writeLine(writeback[index]->store(substituteArgument[index]));
3522 }
3523 }
3524
3525 // Return the result of invoking the function.
3526 return result;
3527 }
3528
assembleIndexExpression(const IndexExpression & i)3529 std::string WGSLCodeGenerator::assembleIndexExpression(const IndexExpression& i) {
3530 // Put the index value into a let-expression.
3531 std::string idx = this->writeNontrivialScratchLet(*i.index(), Precedence::kExpression);
3532 return this->assembleExpression(*i.base(), Precedence::kPostfix) + "[" + idx + "]";
3533 }
3534
assembleLiteral(const Literal & l)3535 std::string WGSLCodeGenerator::assembleLiteral(const Literal& l) {
3536 const Type& type = l.type();
3537 if (type.isFloat() || type.isBoolean()) {
3538 return l.description(OperatorPrecedence::kExpression);
3539 }
3540 SkASSERT(type.isInteger());
3541 if (type.matches(*fContext.fTypes.fUInt)) {
3542 return std::to_string(l.intValue() & 0xffffffff) + "u";
3543 } else if (type.matches(*fContext.fTypes.fUShort)) {
3544 return std::to_string(l.intValue() & 0xffff) + "u";
3545 } else {
3546 return std::to_string(l.intValue());
3547 }
3548 }
3549
assembleIncrementExpr(const Type & type)3550 std::string WGSLCodeGenerator::assembleIncrementExpr(const Type& type) {
3551 // `type(`
3552 std::string expr = to_wgsl_type(fContext, type);
3553 expr.push_back('(');
3554
3555 // `1, 1, 1...)`
3556 auto separator = SkSL::String::Separator();
3557 for (int slots = type.slotCount(); slots > 0; --slots) {
3558 expr += separator();
3559 expr += "1";
3560 }
3561 expr.push_back(')');
3562 return expr;
3563 }
3564
assemblePrefixExpression(const PrefixExpression & p,Precedence parentPrecedence)3565 std::string WGSLCodeGenerator::assemblePrefixExpression(const PrefixExpression& p,
3566 Precedence parentPrecedence) {
3567 // Unary + does nothing, so we omit it from the output.
3568 Operator op = p.getOperator();
3569 if (op.kind() == Operator::Kind::PLUS) {
3570 return this->assembleExpression(*p.operand(), Precedence::kPrefix);
3571 }
3572
3573 // Pre-increment/decrement expressions have no direct equivalent in WGSL.
3574 if (op.kind() == Operator::Kind::PLUSPLUS || op.kind() == Operator::Kind::MINUSMINUS) {
3575 std::unique_ptr<LValue> lvalue = this->makeLValue(*p.operand());
3576 if (!lvalue) {
3577 return "";
3578 }
3579
3580 // Generate the new value: `lvalue + type(1, 1, 1...)`.
3581 std::string newValue =
3582 lvalue->load() +
3583 (p.getOperator().kind() == Operator::Kind::PLUSPLUS ? " + " : " - ") +
3584 this->assembleIncrementExpr(p.operand()->type());
3585 this->writeLine(lvalue->store(newValue));
3586 return lvalue->load();
3587 }
3588
3589 // WGSL natively supports unary negation/not expressions (!,~,-).
3590 SkASSERT(op.kind() == OperatorKind::LOGICALNOT ||
3591 op.kind() == OperatorKind::BITWISENOT ||
3592 op.kind() == OperatorKind::MINUS);
3593
3594 // The unary negation operator only applies to scalars and vectors. For other mathematical
3595 // objects (such as matrices) we can express it as a multiplication by -1.
3596 std::string expr;
3597 const bool needsNegation = op.kind() == Operator::Kind::MINUS &&
3598 !p.operand()->type().isScalar() && !p.operand()->type().isVector();
3599 const bool needParens = needsNegation || Precedence::kPrefix >= parentPrecedence;
3600
3601 if (needParens) {
3602 expr.push_back('(');
3603 }
3604
3605 if (needsNegation) {
3606 expr += "-1.0 * ";
3607 expr += this->assembleExpression(*p.operand(), Precedence::kMultiplicative);
3608 } else {
3609 expr += p.getOperator().tightOperatorName();
3610 expr += this->assembleExpression(*p.operand(), Precedence::kPrefix);
3611 }
3612
3613 if (needParens) {
3614 expr.push_back(')');
3615 }
3616
3617 return expr;
3618 }
3619
assemblePostfixExpression(const PostfixExpression & p,Precedence parentPrecedence)3620 std::string WGSLCodeGenerator::assemblePostfixExpression(const PostfixExpression& p,
3621 Precedence parentPrecedence) {
3622 SkASSERT(p.getOperator().kind() == Operator::Kind::PLUSPLUS ||
3623 p.getOperator().kind() == Operator::Kind::MINUSMINUS);
3624
3625 // Post-increment/decrement expressions have no direct equivalent in WGSL; they do exist as a
3626 // standalone statement for convenience, but these aren't the same as SkSL's post-increments.
3627 std::unique_ptr<LValue> lvalue = this->makeLValue(*p.operand());
3628 if (!lvalue) {
3629 return "";
3630 }
3631
3632 // If the expression is used, create a let-copy of the original value.
3633 // (At statement-level precedence, we know the value is unused and can skip this let-copy.)
3634 std::string originalValue;
3635 if (parentPrecedence != Precedence::kStatement) {
3636 originalValue = this->writeScratchLet(lvalue->load());
3637 }
3638 // Generate the new value: `lvalue + type(1, 1, 1...)`.
3639 std::string newValue = lvalue->load() +
3640 (p.getOperator().kind() == Operator::Kind::PLUSPLUS ? " + " : " - ") +
3641 this->assembleIncrementExpr(p.operand()->type());
3642 this->writeLine(lvalue->store(newValue));
3643
3644 return originalValue;
3645 }
3646
assembleSwizzle(const Swizzle & swizzle)3647 std::string WGSLCodeGenerator::assembleSwizzle(const Swizzle& swizzle) {
3648 return this->assembleExpression(*swizzle.base(), Precedence::kPostfix) + "." +
3649 Swizzle::MaskString(swizzle.components());
3650 }
3651
writeScratchVar(const Type & type,const std::string & value)3652 std::string WGSLCodeGenerator::writeScratchVar(const Type& type, const std::string& value) {
3653 std::string scratchVarName = "_skTemp" + std::to_string(fScratchCount++);
3654 this->write("var ");
3655 this->write(scratchVarName);
3656 this->write(": ");
3657 this->write(to_wgsl_type(fContext, type));
3658 if (!value.empty()) {
3659 this->write(" = ");
3660 this->write(value);
3661 }
3662 this->writeLine(";");
3663 return scratchVarName;
3664 }
3665
writeScratchLet(const std::string & expr)3666 std::string WGSLCodeGenerator::writeScratchLet(const std::string& expr) {
3667 std::string scratchVarName = "_skTemp" + std::to_string(fScratchCount++);
3668 this->write(fAtFunctionScope ? "let " : "const ");
3669 this->write(scratchVarName);
3670 this->write(" = ");
3671 this->write(expr);
3672 this->writeLine(";");
3673 return scratchVarName;
3674 }
3675
writeScratchLet(const Expression & expr,Precedence parentPrecedence)3676 std::string WGSLCodeGenerator::writeScratchLet(const Expression& expr,
3677 Precedence parentPrecedence) {
3678 return this->writeScratchLet(this->assembleExpression(expr, parentPrecedence));
3679 }
3680
writeNontrivialScratchLet(const Expression & expr,Precedence parentPrecedence)3681 std::string WGSLCodeGenerator::writeNontrivialScratchLet(const Expression& expr,
3682 Precedence parentPrecedence) {
3683 std::string result = this->assembleExpression(expr, parentPrecedence);
3684 return is_nontrivial_expression(expr) ? this->writeScratchLet(result)
3685 : result;
3686 }
3687
assembleTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)3688 std::string WGSLCodeGenerator::assembleTernaryExpression(const TernaryExpression& t,
3689 Precedence parentPrecedence) {
3690 std::string expr;
3691
3692 // The trivial case is when neither branch has side effects and evaluate to a scalar or vector
3693 // type. This can be represented with a call to the WGSL `select` intrinsic. Select doesn't
3694 // support short-circuiting, so we should only use it when both the true- and false-expressions
3695 // are trivial to evaluate.
3696 if ((t.type().isScalar() || t.type().isVector()) &&
3697 !Analysis::HasSideEffects(*t.test()) &&
3698 Analysis::IsTrivialExpression(*t.ifTrue()) &&
3699 Analysis::IsTrivialExpression(*t.ifFalse())) {
3700
3701 bool needParens = Precedence::kTernary >= parentPrecedence;
3702 if (needParens) {
3703 expr.push_back('(');
3704 }
3705 expr += "select(";
3706 expr += this->assembleExpression(*t.ifFalse(), Precedence::kSequence);
3707 expr += ", ";
3708 expr += this->assembleExpression(*t.ifTrue(), Precedence::kSequence);
3709 expr += ", ";
3710
3711 bool isVector = t.type().isVector();
3712 if (isVector) {
3713 // Splat the condition expression into a vector.
3714 expr += String::printf("vec%d<bool>(", t.type().columns());
3715 }
3716 expr += this->assembleExpression(*t.test(), Precedence::kSequence);
3717 if (isVector) {
3718 expr.push_back(')');
3719 }
3720 expr.push_back(')');
3721 if (needParens) {
3722 expr.push_back(')');
3723 }
3724 } else {
3725 // WGSL does not support ternary expressions. Instead, we hoist the expression out into the
3726 // surrounding block, convert it into an if statement, and write the result to a synthesized
3727 // variable. Instead of the original expression, we return that variable.
3728 expr = this->writeScratchVar(t.ifTrue()->type());
3729
3730 std::string testExpr = this->assembleExpression(*t.test(), Precedence::kExpression);
3731 this->write("if ");
3732 this->write(testExpr);
3733 this->writeLine(" {");
3734
3735 ++fIndentation;
3736 std::string trueExpr = this->assembleExpression(*t.ifTrue(), Precedence::kAssignment);
3737 this->write(expr);
3738 this->write(" = ");
3739 this->write(trueExpr);
3740 this->writeLine(";");
3741 --fIndentation;
3742
3743 this->writeLine("} else {");
3744
3745 ++fIndentation;
3746 std::string falseExpr = this->assembleExpression(*t.ifFalse(), Precedence::kAssignment);
3747 this->write(expr);
3748 this->write(" = ");
3749 this->write(falseExpr);
3750 this->writeLine(";");
3751 --fIndentation;
3752
3753 this->writeLine("}");
3754 }
3755 return expr;
3756 }
3757
variablePrefix(const Variable & v)3758 std::string WGSLCodeGenerator::variablePrefix(const Variable& v) {
3759 if (v.storage() == Variable::Storage::kGlobal) {
3760 // If the field refers to a pipeline IO parameter, then we access it via the synthesized IO
3761 // structs. We make an explicit exception for `sk_PointSize` which we declare as a
3762 // placeholder variable in global scope as it is not supported by WebGPU as a pipeline IO
3763 // parameter (see comments in `writeStageOutputStruct`).
3764 if (v.modifierFlags() & ModifierFlag::kIn) {
3765 return "_stageIn.";
3766 }
3767 if (v.modifierFlags() & ModifierFlag::kOut) {
3768 return "(*_stageOut).";
3769 }
3770
3771 // If the field refers to an anonymous-interface-block structure, access it via the
3772 // synthesized `_uniform0` or `_storage1` global.
3773 if (const InterfaceBlock* ib = v.interfaceBlock()) {
3774 const Type& ibType = ib->var()->type().componentType();
3775 if (const std::string* ibName = fInterfaceBlockNameMap.find(&ibType)) {
3776 return *ibName + '.';
3777 }
3778 }
3779
3780 // If the field refers to an top-level uniform, access it via the synthesized
3781 // `_globalUniforms` global. (Note that this should only occur in test code; Skia will
3782 // always put uniforms in an interface block.)
3783 if (is_in_global_uniforms(v)) {
3784 return "_globalUniforms.";
3785 }
3786 }
3787
3788 return "";
3789 }
3790
variableReferenceNameForLValue(const VariableReference & r)3791 std::string WGSLCodeGenerator::variableReferenceNameForLValue(const VariableReference& r) {
3792 const Variable& v = *r.variable();
3793
3794 if ((v.storage() == Variable::Storage::kParameter &&
3795 v.modifierFlags() & ModifierFlag::kOut)) {
3796 // This is an out-parameter; it's pointer-typed, so we need to dereference it. We wrap the
3797 // dereference in parentheses, in case the value is used in an access expression later.
3798 return "(*" + this->assembleName(v.mangledName()) + ')';
3799 }
3800
3801 return this->variablePrefix(v) + this->assembleName(v.mangledName());
3802 }
3803
assembleVariableReference(const VariableReference & r)3804 std::string WGSLCodeGenerator::assembleVariableReference(const VariableReference& r) {
3805 // TODO(b/294274678): Correctly handle RTFlip for built-ins.
3806 const Variable& v = *r.variable();
3807
3808 // Insert a conversion expression if this is a built-in variable whose type differs from the
3809 // SkSL.
3810 std::string expr;
3811 std::optional<std::string_view> conversion = needs_builtin_type_conversion(v);
3812 if (conversion.has_value()) {
3813 expr += *conversion;
3814 expr.push_back('(');
3815 }
3816
3817 expr += this->variableReferenceNameForLValue(r);
3818
3819 if (conversion.has_value()) {
3820 expr.push_back(')');
3821 }
3822
3823 return expr;
3824 }
3825
assembleAnyConstructor(const AnyConstructor & c)3826 std::string WGSLCodeGenerator::assembleAnyConstructor(const AnyConstructor& c) {
3827 std::string expr = to_wgsl_type(fContext, c.type());
3828 expr.push_back('(');
3829 auto separator = SkSL::String::Separator();
3830 for (const auto& e : c.argumentSpan()) {
3831 expr += separator();
3832 expr += this->assembleExpression(*e, Precedence::kSequence);
3833 }
3834 expr.push_back(')');
3835 return expr;
3836 }
3837
assembleConstructorCompound(const ConstructorCompound & c)3838 std::string WGSLCodeGenerator::assembleConstructorCompound(const ConstructorCompound& c) {
3839 if (c.type().isVector()) {
3840 return this->assembleConstructorCompoundVector(c);
3841 } else if (c.type().isMatrix()) {
3842 return this->assembleConstructorCompoundMatrix(c);
3843 } else {
3844 fContext.fErrors->error(c.fPosition, "unsupported compound constructor");
3845 return {};
3846 }
3847 }
3848
assembleConstructorCompoundVector(const ConstructorCompound & c)3849 std::string WGSLCodeGenerator::assembleConstructorCompoundVector(const ConstructorCompound& c) {
3850 // WGSL supports constructing vectors from a mix of scalars and vectors but
3851 // not matrices (see https://www.w3.org/TR/WGSL/#type-constructor-expr).
3852 //
3853 // SkSL supports vec4(mat2x2) which we handle specially.
3854 if (c.type().columns() == 4 && c.argumentSpan().size() == 1) {
3855 const Expression& arg = *c.argumentSpan().front();
3856 if (arg.type().isMatrix()) {
3857 SkASSERT(arg.type().columns() == 2);
3858 SkASSERT(arg.type().rows() == 2);
3859
3860 std::string matrix = this->writeNontrivialScratchLet(arg, Precedence::kPostfix);
3861 return String::printf("%s(%s[0], %s[1])", to_wgsl_type(fContext, c.type()).c_str(),
3862 matrix.c_str(),
3863 matrix.c_str());
3864 }
3865 }
3866 return this->assembleAnyConstructor(c);
3867 }
3868
assembleConstructorCompoundMatrix(const ConstructorCompound & ctor)3869 std::string WGSLCodeGenerator::assembleConstructorCompoundMatrix(const ConstructorCompound& ctor) {
3870 SkASSERT(ctor.type().isMatrix());
3871
3872 std::string expr = to_wgsl_type(fContext, ctor.type()) + '(';
3873 auto separator = String::Separator();
3874 for (const std::unique_ptr<Expression>& arg : ctor.arguments()) {
3875 SkASSERT(arg->type().isScalar() || arg->type().isVector());
3876
3877 if (arg->type().isScalar()) {
3878 expr += separator();
3879 expr += this->assembleExpression(*arg, Precedence::kSequence);
3880 } else {
3881 std::string inner = this->writeNontrivialScratchLet(*arg, Precedence::kSequence);
3882 int numSlots = arg->type().slotCount();
3883 for (int slot = 0; slot < numSlots; ++slot) {
3884 String::appendf(&expr, "%s%s[%d]", separator().c_str(), inner.c_str(), slot);
3885 }
3886 }
3887 }
3888 return expr + ')';
3889 }
3890
assembleConstructorDiagonalMatrix(const ConstructorDiagonalMatrix & c)3891 std::string WGSLCodeGenerator::assembleConstructorDiagonalMatrix(
3892 const ConstructorDiagonalMatrix& c) {
3893 const Type& type = c.type();
3894 SkASSERT(type.isMatrix());
3895 SkASSERT(c.argument()->type().isScalar());
3896
3897 // Evaluate the inner-expression, creating a scratch variable if necessary.
3898 std::string inner = this->writeNontrivialScratchLet(*c.argument(), Precedence::kAssignment);
3899
3900 // Assemble a diagonal-matrix expression.
3901 std::string expr = to_wgsl_type(fContext, type) + '(';
3902 auto separator = String::Separator();
3903 for (int col = 0; col < type.columns(); ++col) {
3904 for (int row = 0; row < type.rows(); ++row) {
3905 expr += separator();
3906 if (col == row) {
3907 expr += inner;
3908 } else {
3909 expr += "0.0";
3910 }
3911 }
3912 }
3913 return expr + ')';
3914 }
3915
assembleConstructorMatrixResize(const ConstructorMatrixResize & ctor)3916 std::string WGSLCodeGenerator::assembleConstructorMatrixResize(
3917 const ConstructorMatrixResize& ctor) {
3918 std::string source = this->writeScratchLet(this->assembleExpression(*ctor.argument(),
3919 Precedence::kSequence));
3920 int columns = ctor.type().columns();
3921 int rows = ctor.type().rows();
3922 int sourceColumns = ctor.argument()->type().columns();
3923 int sourceRows = ctor.argument()->type().rows();
3924 auto separator = String::Separator();
3925 std::string expr = to_wgsl_type(fContext, ctor.type()) + '(';
3926
3927 for (int c = 0; c < columns; ++c) {
3928 for (int r = 0; r < rows; ++r) {
3929 expr += separator();
3930 if (c < sourceColumns && r < sourceRows) {
3931 String::appendf(&expr, "%s[%d][%d]", source.c_str(), c, r);
3932 } else if (r == c) {
3933 expr += "1.0";
3934 } else {
3935 expr += "0.0";
3936 }
3937 }
3938 }
3939
3940 return expr + ')';
3941 }
3942
assembleEqualityExpression(const Type & left,const std::string & leftName,const Type & right,const std::string & rightName,Operator op,Precedence parentPrecedence)3943 std::string WGSLCodeGenerator::assembleEqualityExpression(const Type& left,
3944 const std::string& leftName,
3945 const Type& right,
3946 const std::string& rightName,
3947 Operator op,
3948 Precedence parentPrecedence) {
3949 SkASSERT(op.kind() == OperatorKind::EQEQ || op.kind() == OperatorKind::NEQ);
3950
3951 std::string expr;
3952 bool isEqual = (op.kind() == Operator::Kind::EQEQ);
3953 const char* const combiner = isEqual ? " && " : " || ";
3954
3955 if (left.isMatrix()) {
3956 // Each matrix column must be compared as if it were an individual vector.
3957 SkASSERT(right.isMatrix());
3958 SkASSERT(left.rows() == right.rows());
3959 SkASSERT(left.columns() == right.columns());
3960 int columns = left.columns();
3961 const Type& vecType = left.columnType(fContext);
3962 const char* separator = "(";
3963 for (int index = 0; index < columns; ++index) {
3964 expr += separator;
3965 std::string suffix = '[' + std::to_string(index) + ']';
3966 expr += this->assembleEqualityExpression(vecType, leftName + suffix,
3967 vecType, rightName + suffix,
3968 op, Precedence::kParentheses);
3969 separator = combiner;
3970 }
3971 return expr + ')';
3972 }
3973
3974 if (left.isArray()) {
3975 SkASSERT(right.matches(left));
3976 const Type& indexedType = left.componentType();
3977 const char* separator = "(";
3978 for (int index = 0; index < left.columns(); ++index) {
3979 expr += separator;
3980 std::string suffix = '[' + std::to_string(index) + ']';
3981 expr += this->assembleEqualityExpression(indexedType, leftName + suffix,
3982 indexedType, rightName + suffix,
3983 op, Precedence::kParentheses);
3984 separator = combiner;
3985 }
3986 return expr + ')';
3987 }
3988
3989 if (left.isStruct()) {
3990 // Recursively compare every field in the struct.
3991 SkASSERT(right.matches(left));
3992 SkSpan<const Field> fields = left.fields();
3993
3994 const char* separator = "(";
3995 for (const Field& field : fields) {
3996 expr += separator;
3997 expr += this->assembleEqualityExpression(
3998 *field.fType, leftName + '.' + this->assembleName(field.fName),
3999 *field.fType, rightName + '.' + this->assembleName(field.fName),
4000 op, Precedence::kParentheses);
4001 separator = combiner;
4002 }
4003 return expr + ')';
4004 }
4005
4006 if (left.isVector()) {
4007 // Compare vectors via `all(x == y)` or `any(x != y)`.
4008 SkASSERT(right.isVector());
4009 SkASSERT(left.slotCount() == right.slotCount());
4010
4011 expr += isEqual ? "all(" : "any(";
4012 expr += leftName;
4013 expr += operator_name(op);
4014 expr += rightName;
4015 return expr + ')';
4016 }
4017
4018 // Compare scalars via `x == y`.
4019 SkASSERT(right.isScalar());
4020 if (parentPrecedence < Precedence::kSequence) {
4021 expr = '(';
4022 }
4023 expr += leftName;
4024 expr += operator_name(op);
4025 expr += rightName;
4026 if (parentPrecedence < Precedence::kSequence) {
4027 expr += ')';
4028 }
4029 return expr;
4030 }
4031
assembleEqualityExpression(const Expression & left,const Expression & right,Operator op,Precedence parentPrecedence)4032 std::string WGSLCodeGenerator::assembleEqualityExpression(const Expression& left,
4033 const Expression& right,
4034 Operator op,
4035 Precedence parentPrecedence) {
4036 std::string leftName, rightName;
4037 if (left.type().isScalar() || left.type().isVector()) {
4038 // WGSL supports scalar and vector comparisons natively. We know the expressions will only
4039 // be emitted once, so there isn't a benefit to creating a let-declaration.
4040 leftName = this->assembleExpression(left, Precedence::kParentheses);
4041 rightName = this->assembleExpression(right, Precedence::kParentheses);
4042 } else {
4043 leftName = this->writeNontrivialScratchLet(left, Precedence::kAssignment);
4044 rightName = this->writeNontrivialScratchLet(right, Precedence::kAssignment);
4045 }
4046 return this->assembleEqualityExpression(left.type(), leftName, right.type(), rightName,
4047 op, parentPrecedence);
4048 }
4049
writeProgramElement(const ProgramElement & e)4050 void WGSLCodeGenerator::writeProgramElement(const ProgramElement& e) {
4051 switch (e.kind()) {
4052 case ProgramElement::Kind::kExtension:
4053 // TODO(skia:13092): WGSL supports extensions via the "enable" directive
4054 // (https://www.w3.org/TR/WGSL/#enable-extensions-sec ). While we could easily emit this
4055 // directive, we should first ensure that all possible SkSL extension names are
4056 // converted to their appropriate WGSL extension.
4057 break;
4058 case ProgramElement::Kind::kGlobalVar:
4059 this->writeGlobalVarDeclaration(e.as<GlobalVarDeclaration>());
4060 break;
4061 case ProgramElement::Kind::kInterfaceBlock:
4062 // All interface block declarations are handled explicitly as the "program header" in
4063 // generateCode().
4064 break;
4065 case ProgramElement::Kind::kStructDefinition:
4066 this->writeStructDefinition(e.as<StructDefinition>());
4067 break;
4068 case ProgramElement::Kind::kFunctionPrototype:
4069 // A WGSL function declaration must contain its body and the function name is in scope
4070 // for the entire program (see https://www.w3.org/TR/WGSL/#function-declaration and
4071 // https://www.w3.org/TR/WGSL/#declaration-and-scope).
4072 //
4073 // As such, we don't emit function prototypes.
4074 break;
4075 case ProgramElement::Kind::kFunction:
4076 this->writeFunction(e.as<FunctionDefinition>());
4077 break;
4078 case ProgramElement::Kind::kModifiers:
4079 this->writeModifiersDeclaration(e.as<ModifiersDeclaration>());
4080 break;
4081 default:
4082 SkDEBUGFAILF("unsupported program element: %s\n", e.description().c_str());
4083 break;
4084 }
4085 }
4086
writeTextureOrSampler(const Variable & var,int bindingLocation,std::string_view suffix,std::string_view wgslType)4087 void WGSLCodeGenerator::writeTextureOrSampler(const Variable& var,
4088 int bindingLocation,
4089 std::string_view suffix,
4090 std::string_view wgslType) {
4091 if (var.type().dimensions() != SpvDim2D) {
4092 // Skia currently only uses 2D textures.
4093 fContext.fErrors->error(var.varDeclaration()->position(), "unsupported texture dimensions");
4094 return;
4095 }
4096
4097 this->write("@group(");
4098 this->write(std::to_string(std::max(0, var.layout().fSet)));
4099 this->write(") @binding(");
4100 this->write(std::to_string(bindingLocation));
4101 this->write(") var ");
4102 this->write(this->assembleName(var.mangledName()));
4103 this->write(suffix);
4104 this->write(": ");
4105 this->write(wgslType);
4106 this->writeLine(";");
4107 }
4108
writeGlobalVarDeclaration(const GlobalVarDeclaration & d)4109 void WGSLCodeGenerator::writeGlobalVarDeclaration(const GlobalVarDeclaration& d) {
4110 const VarDeclaration& decl = d.varDeclaration();
4111 const Variable& var = *decl.var();
4112 if ((var.modifierFlags() & (ModifierFlag::kIn | ModifierFlag::kOut)) ||
4113 is_in_global_uniforms(var)) {
4114 // Pipeline stage I/O parameters and top-level (non-block) uniforms are handled specially
4115 // in generateCode().
4116 return;
4117 }
4118
4119 const Type::TypeKind varKind = var.type().typeKind();
4120 if (varKind == Type::TypeKind::kSampler) {
4121 // If the sampler binding was unassigned, provide a scratch value; this will make
4122 // golden-output tests pass, but will not actually be usable for drawing.
4123 int samplerLocation = var.layout().fSampler >= 0 ? var.layout().fSampler
4124 : 10000 + fScratchCount++;
4125 this->writeTextureOrSampler(var, samplerLocation, kSamplerSuffix, "sampler");
4126
4127 // If the texture binding was unassigned, provide a scratch value (for golden-output tests).
4128 int textureLocation = var.layout().fTexture >= 0 ? var.layout().fTexture
4129 : 10000 + fScratchCount++;
4130 this->writeTextureOrSampler(var, textureLocation, kTextureSuffix, "texture_2d<f32>");
4131 return;
4132 }
4133
4134 if (varKind == Type::TypeKind::kTexture) {
4135 // If a binding location was unassigned, provide a scratch value (for golden-output tests).
4136 int textureLocation = var.layout().fBinding >= 0 ? var.layout().fBinding
4137 : 10000 + fScratchCount++;
4138 // For a texture without an associated sampler, we don't apply a suffix.
4139 this->writeTextureOrSampler(var, textureLocation, /*suffix=*/"",
4140 to_wgsl_type(fContext, var.type(), &var.layout()));
4141 return;
4142 }
4143
4144 std::string initializer;
4145 if (decl.value()) {
4146 // We assume here that the initial-value expression will not emit any helper statements.
4147 // Initial-value expressions are required to pass IsConstantExpression, which limits the
4148 // blast radius to constructors, literals, and other constant values/variables.
4149 initializer += " = ";
4150 initializer += this->assembleExpression(*decl.value(), Precedence::kAssignment);
4151 }
4152
4153 if (var.modifierFlags().isConst()) {
4154 this->write("const ");
4155 } else if (var.modifierFlags().isWorkgroup()) {
4156 this->write("var<workgroup> ");
4157 } else if (var.modifierFlags().isPixelLocal()) {
4158 this->write("var<pixel_local> ");
4159 } else {
4160 this->write("var<private> ");
4161 }
4162 this->write(this->assembleName(var.mangledName()));
4163 this->write(": " + to_wgsl_type(fContext, var.type(), &var.layout()));
4164 this->write(initializer);
4165 this->writeLine(";");
4166 }
4167
writeStructDefinition(const StructDefinition & s)4168 void WGSLCodeGenerator::writeStructDefinition(const StructDefinition& s) {
4169 const Type& type = s.type();
4170 this->writeLine("struct " + type.displayName() + " {");
4171 this->writeFields(type.fields(), /*memoryLayout=*/nullptr);
4172 this->writeLine("};");
4173 }
4174
writeModifiersDeclaration(const ModifiersDeclaration & modifiers)4175 void WGSLCodeGenerator::writeModifiersDeclaration(const ModifiersDeclaration& modifiers) {
4176 LayoutFlags flags = modifiers.layout().fFlags;
4177 flags &= ~(LayoutFlag::kLocalSizeX | LayoutFlag::kLocalSizeY | LayoutFlag::kLocalSizeZ);
4178 if (flags != LayoutFlag::kNone) {
4179 fContext.fErrors->error(modifiers.position(), "unsupported declaration");
4180 return;
4181 }
4182
4183 if (modifiers.layout().fLocalSizeX >= 0) {
4184 fLocalSizeX = modifiers.layout().fLocalSizeX;
4185 }
4186 if (modifiers.layout().fLocalSizeY >= 0) {
4187 fLocalSizeY = modifiers.layout().fLocalSizeY;
4188 }
4189 if (modifiers.layout().fLocalSizeZ >= 0) {
4190 fLocalSizeZ = modifiers.layout().fLocalSizeZ;
4191 }
4192 }
4193
writeFields(SkSpan<const Field> fields,const MemoryLayout * memoryLayout)4194 void WGSLCodeGenerator::writeFields(SkSpan<const Field> fields, const MemoryLayout* memoryLayout) {
4195 fIndentation++;
4196
4197 // TODO(skia:14370): array uniforms may need manual fixup for std140 padding. (Those uniforms
4198 // will also need special handling when they are accessed, or passed to functions.)
4199 for (size_t index = 0; index < fields.size(); ++index) {
4200 const Field& field = fields[index];
4201 if (memoryLayout && !memoryLayout->isSupported(*field.fType)) {
4202 // Reject types that aren't supported by the memory layout.
4203 fContext.fErrors->error(field.fPosition, "type '" + std::string(field.fType->name()) +
4204 "' is not permitted here");
4205 return;
4206 }
4207
4208 // Prepend @size(n) to enforce the offsets from the SkSL layout. (This is effectively
4209 // a gadget that we can use to insert padding between elements.)
4210 if (index < fields.size() - 1) {
4211 int thisFieldOffset = field.fLayout.fOffset;
4212 int nextFieldOffset = fields[index + 1].fLayout.fOffset;
4213 if (index == 0 && thisFieldOffset > 0) {
4214 fContext.fErrors->error(field.fPosition, "field must have an offset of zero");
4215 return;
4216 }
4217 if (thisFieldOffset >= 0 && nextFieldOffset > thisFieldOffset) {
4218 this->write("@size(");
4219 this->write(std::to_string(nextFieldOffset - thisFieldOffset));
4220 this->write(") ");
4221 }
4222 }
4223
4224 this->write(this->assembleName(field.fName));
4225 this->write(": ");
4226 if (const FieldPolyfillInfo* info = fFieldPolyfillMap.find(&field)) {
4227 if (info->fIsArray) {
4228 // This properly handles arrays of matrices, as well as arrays of other primitives.
4229 SkASSERT(field.fType->isArray());
4230 this->write("array<_skArrayElement_");
4231 this->write(field.fType->abbreviatedName());
4232 this->write(", ");
4233 this->write(std::to_string(field.fType->columns()));
4234 this->write(">");
4235 } else if (info->fIsMatrix) {
4236 this->write("_skMatrix");
4237 this->write(std::to_string(field.fType->columns()));
4238 this->write(std::to_string(field.fType->rows()));
4239 } else {
4240 SkDEBUGFAILF("need polyfill for %s", info->fReplacementName.c_str());
4241 }
4242 } else {
4243 this->write(to_wgsl_type(fContext, *field.fType, &field.fLayout));
4244 }
4245 this->writeLine(",");
4246 }
4247
4248 fIndentation--;
4249 }
4250
writeEnables()4251 void WGSLCodeGenerator::writeEnables() {
4252 this->writeLine("diagnostic(off, derivative_uniformity);");
4253 this->writeLine("diagnostic(off, chromium.unreachable_code);");
4254
4255 if (fRequirements.fPixelLocalExtension) {
4256 this->writeLine("enable chromium_experimental_pixel_local;");
4257 }
4258 if (fProgram.fInterface.fUseLastFragColor) {
4259 this->writeLine("enable chromium_experimental_framebuffer_fetch;");
4260 }
4261 if (fProgram.fInterface.fOutputSecondaryColor) {
4262 this->writeLine("enable dual_source_blending;");
4263 }
4264 }
4265
needsStageInputStruct() const4266 bool WGSLCodeGenerator::needsStageInputStruct() const {
4267 // It is illegal to declare a struct with no members; we can't emit a placeholder empty stage
4268 // input struct.
4269 return !fPipelineInputs.empty();
4270 }
4271
writeStageInputStruct()4272 void WGSLCodeGenerator::writeStageInputStruct() {
4273 if (!this->needsStageInputStruct()) {
4274 return;
4275 }
4276
4277 std::string_view structNamePrefix = pipeline_struct_prefix(fProgram.fConfig->fKind);
4278 SkASSERT(!structNamePrefix.empty());
4279
4280 this->write("struct ");
4281 this->write(structNamePrefix);
4282 this->writeLine("In {");
4283 fIndentation++;
4284
4285 for (const Variable* v : fPipelineInputs) {
4286 if (v->type().isInterfaceBlock()) {
4287 for (const Field& f : v->type().fields()) {
4288 this->writePipelineIODeclaration(f.fLayout, *f.fType, f.fName, Delimiter::kComma);
4289 }
4290 } else {
4291 this->writePipelineIODeclaration(v->layout(), v->type(), v->mangledName(),
4292 Delimiter::kComma);
4293 }
4294 }
4295
4296 fIndentation--;
4297 this->writeLine("};");
4298 }
4299
needsStageOutputStruct() const4300 bool WGSLCodeGenerator::needsStageOutputStruct() const {
4301 // It is illegal to declare a struct with no members. However, vertex programs will _always_
4302 // have an output stage in WGSL, because the spec requires them to emit `@builtin(position)`.
4303 // So we always synthesize a reference to `sk_Position` even if the program doesn't need it.
4304 return !fPipelineOutputs.empty() || ProgramConfig::IsVertex(fProgram.fConfig->fKind);
4305 }
4306
writeStageOutputStruct()4307 void WGSLCodeGenerator::writeStageOutputStruct() {
4308 if (!this->needsStageOutputStruct()) {
4309 return;
4310 }
4311
4312 std::string_view structNamePrefix = pipeline_struct_prefix(fProgram.fConfig->fKind);
4313 SkASSERT(!structNamePrefix.empty());
4314
4315 this->write("struct ");
4316 this->write(structNamePrefix);
4317 this->writeLine("Out {");
4318 fIndentation++;
4319
4320 bool declaredPositionBuiltin = false;
4321 bool requiresPointSizeBuiltin = false;
4322 for (const Variable* v : fPipelineOutputs) {
4323 if (v->type().isInterfaceBlock()) {
4324 for (const auto& f : v->type().fields()) {
4325 this->writePipelineIODeclaration(f.fLayout, *f.fType, f.fName, Delimiter::kComma);
4326 if (f.fLayout.fBuiltin == SK_POSITION_BUILTIN) {
4327 declaredPositionBuiltin = true;
4328 } else if (f.fLayout.fBuiltin == SK_POINTSIZE_BUILTIN) {
4329 // sk_PointSize is explicitly not supported by `builtin_from_sksl_name` so
4330 // writePipelineIODeclaration will never write it. We mark it here if the
4331 // declaration is needed so we can synthesize it below.
4332 requiresPointSizeBuiltin = true;
4333 }
4334 }
4335 } else {
4336 this->writePipelineIODeclaration(v->layout(), v->type(), v->mangledName(),
4337 Delimiter::kComma);
4338 }
4339 }
4340
4341 // A vertex program must include the `position` builtin in its entrypoint's return type.
4342 const bool positionBuiltinRequired = ProgramConfig::IsVertex(fProgram.fConfig->fKind);
4343 if (positionBuiltinRequired && !declaredPositionBuiltin) {
4344 this->writeLine("@builtin(position) sk_Position: vec4<f32>,");
4345 }
4346
4347 fIndentation--;
4348 this->writeLine("};");
4349
4350 // In WebGPU/WGSL, the vertex stage does not support a point-size output and the size
4351 // of a point primitive is always 1 pixel (see https://github.com/gpuweb/gpuweb/issues/332).
4352 //
4353 // There isn't anything we can do to emulate this correctly at this stage so we synthesize a
4354 // placeholder global variable that has no effect. Programs should not rely on sk_PointSize when
4355 // using the Dawn backend.
4356 if (ProgramConfig::IsVertex(fProgram.fConfig->fKind) && requiresPointSizeBuiltin) {
4357 this->writeLine("/* unsupported */ var<private> sk_PointSize: f32;");
4358 }
4359 }
4360
prepareUniformPolyfillsForInterfaceBlock(const InterfaceBlock * interfaceBlock,std::string_view instanceName,MemoryLayout::Standard nativeLayout)4361 void WGSLCodeGenerator::prepareUniformPolyfillsForInterfaceBlock(
4362 const InterfaceBlock* interfaceBlock,
4363 std::string_view instanceName,
4364 MemoryLayout::Standard nativeLayout) {
4365 SkSL::MemoryLayout std140(MemoryLayout::Standard::k140);
4366 SkSL::MemoryLayout native(nativeLayout);
4367
4368 const Type& structType = interfaceBlock->var()->type().componentType();
4369 for (const Field& field : structType.fields()) {
4370 const Type* type = field.fType;
4371 bool needsArrayPolyfill = false;
4372 bool needsMatrixPolyfill = false;
4373
4374 auto isPolyfillableMatrixType = [&](const Type* type) {
4375 return type->isMatrix() && std140.stride(*type) != native.stride(*type);
4376 };
4377
4378 if (isPolyfillableMatrixType(type)) {
4379 // Matrices will be represented as 16-byte aligned arrays in std140, and reconstituted
4380 // into proper matrices as they are later accessed. We need to synthesize polyfill.
4381 needsMatrixPolyfill = true;
4382 } else if (type->isArray() && !type->isUnsizedArray() &&
4383 !type->componentType().isOpaque()) {
4384 const Type* innerType = &type->componentType();
4385 if (isPolyfillableMatrixType(innerType)) {
4386 // Use a polyfill when the array contains a matrix that requires polyfill.
4387 needsArrayPolyfill = true;
4388 needsMatrixPolyfill = true;
4389 } else if (native.size(*innerType) < 16) {
4390 // Use a polyfill when the array elements are smaller than 16 bytes, since std140
4391 // will pad elements to a 16-byte stride.
4392 needsArrayPolyfill = true;
4393 }
4394 }
4395
4396 if (needsArrayPolyfill || needsMatrixPolyfill) {
4397 // Add a polyfill for this matrix type.
4398 FieldPolyfillInfo info;
4399 info.fInterfaceBlock = interfaceBlock;
4400 info.fReplacementName = "_skUnpacked_" + std::string(instanceName) + '_' +
4401 this->assembleName(field.fName);
4402 info.fIsArray = needsArrayPolyfill;
4403 info.fIsMatrix = needsMatrixPolyfill;
4404 fFieldPolyfillMap.set(&field, info);
4405 }
4406 }
4407 }
4408
writeUniformsAndBuffers()4409 void WGSLCodeGenerator::writeUniformsAndBuffers() {
4410 for (const ProgramElement* e : fProgram.elements()) {
4411 // Iterate through the interface blocks.
4412 if (!e->is<InterfaceBlock>()) {
4413 continue;
4414 }
4415 const InterfaceBlock& ib = e->as<InterfaceBlock>();
4416
4417 // Determine if this interface block holds uniforms, buffers, or something else (skip it).
4418 std::string_view addressSpace;
4419 std::string_view accessMode;
4420 MemoryLayout::Standard nativeLayout;
4421 if (ib.var()->modifierFlags().isUniform()) {
4422 addressSpace = "uniform";
4423 nativeLayout = MemoryLayout::Standard::kWGSLUniform_Base;
4424 } else if (ib.var()->modifierFlags().isBuffer()) {
4425 addressSpace = "storage";
4426 nativeLayout = MemoryLayout::Standard::kWGSLStorage_Base;
4427 accessMode = ib.var()->modifierFlags().isReadOnly() ? ", read" : ", read_write";
4428 } else {
4429 continue;
4430 }
4431
4432 // If we have an anonymous interface block, assign a name like `_uniform0` or `_storage1`.
4433 std::string instanceName;
4434 if (ib.instanceName().empty()) {
4435 instanceName = "_" + std::string(addressSpace) + std::to_string(fScratchCount++);
4436 fInterfaceBlockNameMap[&ib.var()->type().componentType()] = instanceName;
4437 } else {
4438 instanceName = std::string(ib.instanceName());
4439 }
4440
4441 this->prepareUniformPolyfillsForInterfaceBlock(&ib, instanceName, nativeLayout);
4442
4443 // Create a struct to hold all of the fields from this InterfaceBlock.
4444 SkASSERT(!ib.typeName().empty());
4445 this->write("struct ");
4446 this->write(ib.typeName());
4447 this->writeLine(" {");
4448
4449 // Find the struct type and fields used by this interface block.
4450 const Type& ibType = ib.var()->type().componentType();
4451 SkASSERT(ibType.isStruct());
4452
4453 SkSpan<const Field> ibFields = ibType.fields();
4454 SkASSERT(!ibFields.empty());
4455
4456 MemoryLayout layout(MemoryLayout::Standard::k140);
4457 this->writeFields(ibFields, &layout);
4458 this->writeLine("};");
4459 this->write("@group(");
4460 this->write(std::to_string(std::max(0, ib.var()->layout().fSet)));
4461 this->write(") @binding(");
4462 this->write(std::to_string(std::max(0, ib.var()->layout().fBinding)));
4463 this->write(") var<");
4464 this->write(addressSpace);
4465 this->write(accessMode);
4466 this->write("> ");
4467 this->write(instanceName);
4468 this->write(" : ");
4469 this->write(to_wgsl_type(fContext, ib.var()->type(), &ib.var()->layout()));
4470 this->writeLine(";");
4471 }
4472 }
4473
writeNonBlockUniformsForTests()4474 void WGSLCodeGenerator::writeNonBlockUniformsForTests() {
4475 bool declaredUniformsStruct = false;
4476
4477 for (const ProgramElement* e : fProgram.elements()) {
4478 if (e->is<GlobalVarDeclaration>()) {
4479 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
4480 const Variable& var = *decls.varDeclaration().var();
4481 if (is_in_global_uniforms(var)) {
4482 if (!declaredUniformsStruct) {
4483 this->write("struct _GlobalUniforms {\n");
4484 declaredUniformsStruct = true;
4485 }
4486 this->write(" ");
4487 this->writeVariableDecl(var.layout(), var.type(), var.mangledName(),
4488 Delimiter::kComma);
4489 }
4490 }
4491 }
4492 if (declaredUniformsStruct) {
4493 int binding = fProgram.fConfig->fSettings.fDefaultUniformBinding;
4494 int set = fProgram.fConfig->fSettings.fDefaultUniformSet;
4495 this->write("};\n");
4496 this->write("@binding(" + std::to_string(binding) + ") ");
4497 this->write("@group(" + std::to_string(set) + ") ");
4498 this->writeLine("var<uniform> _globalUniforms: _GlobalUniforms;");
4499 }
4500 }
4501
functionDependencyArgs(const FunctionDeclaration & f)4502 std::string WGSLCodeGenerator::functionDependencyArgs(const FunctionDeclaration& f) {
4503 WGSLFunctionDependencies* deps = fRequirements.fDependencies.find(&f);
4504 std::string args;
4505 if (deps && *deps) {
4506 const char* separator = "";
4507 if (*deps & WGSLFunctionDependency::kPipelineInputs) {
4508 args += "_stageIn";
4509 separator = ", ";
4510 }
4511 if (*deps & WGSLFunctionDependency::kPipelineOutputs) {
4512 args += separator;
4513 args += "_stageOut";
4514 }
4515 }
4516 return args;
4517 }
4518
writeFunctionDependencyParams(const FunctionDeclaration & f)4519 bool WGSLCodeGenerator::writeFunctionDependencyParams(const FunctionDeclaration& f) {
4520 WGSLFunctionDependencies* deps = fRequirements.fDependencies.find(&f);
4521 if (!deps || !*deps) {
4522 return false;
4523 }
4524
4525 std::string_view structNamePrefix = pipeline_struct_prefix(fProgram.fConfig->fKind);
4526 if (structNamePrefix.empty()) {
4527 return false;
4528 }
4529 const char* separator = "";
4530 if (*deps & WGSLFunctionDependency::kPipelineInputs) {
4531 this->write("_stageIn: ");
4532 separator = ", ";
4533 this->write(structNamePrefix);
4534 this->write("In");
4535 }
4536 if (*deps & WGSLFunctionDependency::kPipelineOutputs) {
4537 this->write(separator);
4538 this->write("_stageOut: ptr<function, ");
4539 this->write(structNamePrefix);
4540 this->write("Out>");
4541 }
4542 return true;
4543 }
4544
#if defined(SK_ENABLE_WGSL_VALIDATION)
// Parses `wgsl` with Tint and reports whether it is a valid program.
//
//   reporter: receives the failure message in SKSL_STANDALONE builds; otherwise unused.
//   wgsl:     the generated WGSL source text to validate.
//   warnings: assigned Tint's formatted diagnostics when the program is valid but Tint still
//             reported diagnostics; left untouched otherwise.
//
// Returns true if Tint reported no errors.
static bool validate_wgsl(ErrorReporter& reporter, const std::string& wgsl, std::string* warnings) {
    // Enable the WGSL optional features that Skia might rely on.
    // NOTE(review): writeEnables() can also emit `enable chromium_experimental_framebuffer_fetch;`
    // but that extension is not added here -- confirm whether it should be allowed as well.
    tint::wgsl::reader::Options options;
    options.allowed_features.extensions.insert(
            tint::wgsl::Extension::kChromiumExperimentalPixelLocal);
    options.allowed_features.extensions.insert(tint::wgsl::Extension::kDualSourceBlending);

    // Run the generated WGSL through Tint's reader.
    tint::Source::File srcFile("", wgsl);
    tint::Program program(tint::wgsl::reader::Parse(&srcFile, options));
    const auto& diagnostics = program.Diagnostics();

    if (!diagnostics.ContainsErrors()) {
        if (!diagnostics.empty()) {
            // The program is valid but contains warnings. Report them as-is.
            tint::diag::Formatter diagFormatter;
            *warnings = diagFormatter.Format(diagnostics).Plain();
        }
        return true;
    }

    // The program isn't valid WGSL.
#if defined(SKSL_STANDALONE)
    reporter.error(Position(), std::string("Tint compilation failed.\n\n") + wgsl);
#else
    // In debug, report the error via SkDEBUGFAILF. The generated program is appended to Tint's
    // diagnostics for ease of debugging.
    tint::diag::Formatter diagFormatter;
    std::string diagOutput = diagFormatter.Format(diagnostics).Plain();
    diagOutput += "\n";
    diagOutput += wgsl;
    SkDEBUGFAILF("%s", diagOutput.c_str());
#endif
    return false;
}
#endif  // defined(SK_ENABLE_WGSL_VALIDATION)
4582
ToWGSL(Program & program,const ShaderCaps * caps,OutputStream & out)4583 bool ToWGSL(Program& program, const ShaderCaps* caps, OutputStream& out) {
4584 TRACE_EVENT0("skia.shaders", "SkSL::ToWGSL");
4585 SkASSERT(caps != nullptr);
4586
4587 program.fContext->fErrors->setSource(*program.fSource);
4588 #ifdef SK_ENABLE_WGSL_VALIDATION
4589 StringStream wgsl;
4590 WGSLCodeGenerator cg(program.fContext.get(), caps, &program, &wgsl);
4591 bool result = cg.generateCode();
4592 if (result) {
4593 std::string wgslString = wgsl.str();
4594 std::string warnings;
4595 result = validate_wgsl(*program.fContext->fErrors, wgslString, &warnings);
4596 if (!warnings.empty()) {
4597 out.writeText("/* Tint reported warnings. */\n\n");
4598 }
4599 out.writeString(wgslString);
4600 }
4601 #else
4602 WGSLCodeGenerator cg(program.fContext.get(), caps, &program, &out);
4603 bool result = cg.generateCode();
4604 #endif
4605 program.fContext->fErrors->setSource(std::string_view());
4606
4607 return result;
4608 }
4609
ToWGSL(Program & program,const ShaderCaps * caps,std::string * out)4610 bool ToWGSL(Program& program, const ShaderCaps* caps, std::string* out) {
4611 StringStream buffer;
4612 if (!ToWGSL(program, caps, buffer)) {
4613 return false;
4614 }
4615 *out = buffer.str();
4616 return true;
4617 }
4618
4619 } // namespace SkSL
4620