1 /*
2 * Copyright 2022 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/codegen/SkSLWGSLCodeGenerator.h"
9
10 #include <memory>
11 #include <optional>
12 #include <string>
13 #include <vector>
14
15 #include "include/core/SkSpan.h"
16 #include "include/core/SkTypes.h"
17 #include "include/private/SkBitmaskEnum.h"
18 #include "include/private/SkSLIRNode.h"
19 #include "include/private/SkSLLayout.h"
20 #include "include/private/SkSLModifiers.h"
21 #include "include/private/SkSLProgramElement.h"
22 #include "include/private/SkSLStatement.h"
23 #include "include/private/SkSLString.h"
24 #include "include/private/SkSLSymbol.h"
25 #include "include/private/base/SkTArray.h"
26 #include "include/private/base/SkTo.h"
27 #include "include/sksl/SkSLErrorReporter.h"
28 #include "include/sksl/SkSLOperator.h"
29 #include "include/sksl/SkSLPosition.h"
30 #include "src/sksl/SkSLAnalysis.h"
31 #include "src/sksl/SkSLBuiltinTypes.h"
32 #include "src/sksl/SkSLCompiler.h"
33 #include "src/sksl/SkSLContext.h"
34 #include "src/sksl/SkSLOutputStream.h"
35 #include "src/sksl/SkSLProgramSettings.h"
36 #include "src/sksl/SkSLStringStream.h"
37 #include "src/sksl/SkSLUtil.h"
38 #include "src/sksl/analysis/SkSLProgramVisitor.h"
39 #include "src/sksl/ir/SkSLBinaryExpression.h"
40 #include "src/sksl/ir/SkSLBlock.h"
41 #include "src/sksl/ir/SkSLConstructor.h"
42 #include "src/sksl/ir/SkSLConstructorCompound.h"
43 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
44 #include "src/sksl/ir/SkSLExpression.h"
45 #include "src/sksl/ir/SkSLExpressionStatement.h"
46 #include "src/sksl/ir/SkSLFieldAccess.h"
47 #include "src/sksl/ir/SkSLFunctionCall.h"
48 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
49 #include "src/sksl/ir/SkSLFunctionDefinition.h"
50 #include "src/sksl/ir/SkSLIfStatement.h"
51 #include "src/sksl/ir/SkSLIndexExpression.h"
52 #include "src/sksl/ir/SkSLInterfaceBlock.h"
53 #include "src/sksl/ir/SkSLLiteral.h"
54 #include "src/sksl/ir/SkSLProgram.h"
55 #include "src/sksl/ir/SkSLReturnStatement.h"
56 #include "src/sksl/ir/SkSLStructDefinition.h"
57 #include "src/sksl/ir/SkSLSwizzle.h"
58 #include "src/sksl/ir/SkSLSymbolTable.h"
59 #include "src/sksl/ir/SkSLTernaryExpression.h"
60 #include "src/sksl/ir/SkSLType.h"
61 #include "src/sksl/ir/SkSLVarDeclarations.h"
62 #include "src/sksl/ir/SkSLVariable.h"
63 #include "src/sksl/ir/SkSLVariableReference.h"
64
65 // TODO(skia:13092): This is a temporary debug feature. Remove when the implementation is
66 // complete and this is no longer needed.
67 #define DUMP_SRC_IR 0
68
69 namespace SkSL {
70
71 enum class ProgramKind : int8_t;
72
73 namespace {
74
// Address spaces that may appear as the first parameter of a WGSL pointer type (`ptr<AS, T>`).
// See https://www.w3.org/TR/WGSL/#memory-view-types
enum class PtrAddressSpace {
    kFunction,  // emitted as "function"
    kPrivate,   // emitted as "private"
    kStorage,   // emitted as "storage"
};
81
pipeline_struct_prefix(ProgramKind kind)82 std::string_view pipeline_struct_prefix(ProgramKind kind) {
83 if (ProgramConfig::IsVertex(kind)) {
84 return "VS";
85 }
86 if (ProgramConfig::IsFragment(kind)) {
87 return "FS";
88 }
89 return "";
90 }
91
address_space_to_str(PtrAddressSpace addressSpace)92 std::string_view address_space_to_str(PtrAddressSpace addressSpace) {
93 switch (addressSpace) {
94 case PtrAddressSpace::kFunction:
95 return "function";
96 case PtrAddressSpace::kPrivate:
97 return "private";
98 case PtrAddressSpace::kStorage:
99 return "storage";
100 }
101 SkDEBUGFAIL("unsupported ptr address space");
102 return "unsupported";
103 }
104
to_scalar_type(const Type & type)105 std::string_view to_scalar_type(const Type& type) {
106 SkASSERT(type.typeKind() == Type::TypeKind::kScalar);
107 switch (type.numberKind()) {
108 // Floating-point numbers in WebGPU currently always have 32-bit footprint and
109 // relaxed-precision is not supported without extensions. f32 is the only floating-point
110 // number type in WGSL (see the discussion on https://github.com/gpuweb/gpuweb/issues/658).
111 case Type::NumberKind::kFloat:
112 return "f32";
113 case Type::NumberKind::kSigned:
114 return "i32";
115 case Type::NumberKind::kUnsigned:
116 return "u32";
117 case Type::NumberKind::kBoolean:
118 return "bool";
119 case Type::NumberKind::kNonnumeric:
120 [[fallthrough]];
121 default:
122 break;
123 }
124 return type.name();
125 }
126
127 // Convert a SkSL type to a WGSL type. Handles all plain types except structure types
128 // (see https://www.w3.org/TR/WGSL/#plain-types-section).
to_wgsl_type(const Type & type)129 std::string to_wgsl_type(const Type& type) {
130 switch (type.typeKind()) {
131 case Type::TypeKind::kScalar:
132 return std::string(to_scalar_type(type));
133 case Type::TypeKind::kVector: {
134 std::string_view ct = to_scalar_type(type.componentType());
135 return String::printf("vec%d<%.*s>", type.columns(), (int)ct.length(), ct.data());
136 }
137 case Type::TypeKind::kMatrix: {
138 std::string_view ct = to_scalar_type(type.componentType());
139 return String::printf(
140 "mat%dx%d<%.*s>", type.columns(), type.rows(), (int)ct.length(), ct.data());
141 }
142 case Type::TypeKind::kArray: {
143 std::string elementType = to_wgsl_type(type.componentType());
144 if (type.isUnsizedArray()) {
145 return String::printf("array<%s>", elementType.c_str());
146 }
147 return String::printf("array<%s, %d>", elementType.c_str(), type.columns());
148 }
149 default:
150 break;
151 }
152 return std::string(type.name());
153 }
154
155 // Create a mangled WGSL type name that can be used in function and variable declarations (regular
156 // type names cannot be used in this manner since they may contain tokens that are not allowed in
157 // symbol names).
to_mangled_wgsl_type_name(const Type & type)158 std::string to_mangled_wgsl_type_name(const Type& type) {
159 switch (type.typeKind()) {
160 case Type::TypeKind::kScalar:
161 return std::string(to_scalar_type(type));
162 case Type::TypeKind::kVector: {
163 std::string_view ct = to_scalar_type(type.componentType());
164 return String::printf("vec%d%.*s", type.columns(), (int)ct.length(), ct.data());
165 }
166 case Type::TypeKind::kMatrix: {
167 std::string_view ct = to_scalar_type(type.componentType());
168 return String::printf(
169 "mat%dx%d%.*s", type.columns(), type.rows(), (int)ct.length(), ct.data());
170 }
171 case Type::TypeKind::kArray: {
172 std::string elementType = to_wgsl_type(type.componentType());
173 if (type.isUnsizedArray()) {
174 return String::printf("arrayof%s", elementType.c_str());
175 }
176 return String::printf("array%dof%s", type.columns(), elementType.c_str());
177 }
178 default:
179 break;
180 }
181 return std::string(type.name());
182 }
183
to_ptr_type(const Type & type,PtrAddressSpace addressSpace=PtrAddressSpace::kFunction)184 std::string to_ptr_type(const Type& type,
185 PtrAddressSpace addressSpace = PtrAddressSpace::kFunction) {
186 return "ptr<" + std::string(address_space_to_str(addressSpace)) + ", " + to_wgsl_type(type) +
187 ">";
188 }
189
wgsl_builtin_name(WGSLCodeGenerator::Builtin builtin)190 std::string_view wgsl_builtin_name(WGSLCodeGenerator::Builtin builtin) {
191 using Builtin = WGSLCodeGenerator::Builtin;
192 switch (builtin) {
193 case Builtin::kVertexIndex:
194 return "vertex_index";
195 case Builtin::kInstanceIndex:
196 return "instance_index";
197 case Builtin::kPosition:
198 return "position";
199 case Builtin::kFrontFacing:
200 return "front_facing";
201 case Builtin::kSampleIndex:
202 return "sample_index";
203 case Builtin::kFragDepth:
204 return "frag_depth";
205 case Builtin::kSampleMask:
206 return "sample_mask";
207 case Builtin::kLocalInvocationId:
208 return "local_invocation_id";
209 case Builtin::kLocalInvocationIndex:
210 return "local_invocation_index";
211 case Builtin::kGlobalInvocationId:
212 return "global_invocation_id";
213 case Builtin::kWorkgroupId:
214 return "workgroup_id";
215 case Builtin::kNumWorkgroups:
216 return "num_workgroups";
217 default:
218 break;
219 }
220
221 SkDEBUGFAIL("unsupported builtin");
222 return "unsupported";
223 }
224
wgsl_builtin_type(WGSLCodeGenerator::Builtin builtin)225 std::string_view wgsl_builtin_type(WGSLCodeGenerator::Builtin builtin) {
226 using Builtin = WGSLCodeGenerator::Builtin;
227 switch (builtin) {
228 case Builtin::kVertexIndex:
229 return "u32";
230 case Builtin::kInstanceIndex:
231 return "u32";
232 case Builtin::kPosition:
233 return "vec4<f32>";
234 case Builtin::kFrontFacing:
235 return "bool";
236 case Builtin::kSampleIndex:
237 return "u32";
238 case Builtin::kFragDepth:
239 return "f32";
240 case Builtin::kSampleMask:
241 return "u32";
242 case Builtin::kLocalInvocationId:
243 return "vec3<u32>";
244 case Builtin::kLocalInvocationIndex:
245 return "u32";
246 case Builtin::kGlobalInvocationId:
247 return "vec3<u32>";
248 case Builtin::kWorkgroupId:
249 return "vec3<u32>";
250 case Builtin::kNumWorkgroups:
251 return "vec3<u32>";
252 default:
253 break;
254 }
255
256 SkDEBUGFAIL("unsupported builtin");
257 return "unsupported";
258 }
259
260 // Some built-in variables have a type that differs from their SkSL counterpart (e.g. signed vs
261 // unsigned integer). We handle these cases with an explicit type conversion during a variable
262 // reference. Returns the WGSL type of the conversion target if conversion is needed, otherwise
263 // returns std::nullopt.
needs_builtin_type_conversion(const Variable & v)264 std::optional<std::string_view> needs_builtin_type_conversion(const Variable& v) {
265 switch (v.modifiers().fLayout.fBuiltin) {
266 case SK_VERTEXID_BUILTIN:
267 case SK_INSTANCEID_BUILTIN:
268 return {"i32"};
269 default:
270 break;
271 }
272 return std::nullopt;
273 }
274
275 // Map a SkSL builtin flag to a WGSL builtin kind. Returns std::nullopt if `builtin` is not
276 // not supported for WGSL.
277 //
278 // Also see //src/sksl/sksl_vert.sksl and //src/sksl/sksl_frag.sksl for supported built-ins.
builtin_from_sksl_name(int builtin)279 std::optional<WGSLCodeGenerator::Builtin> builtin_from_sksl_name(int builtin) {
280 using Builtin = WGSLCodeGenerator::Builtin;
281 switch (builtin) {
282 case SK_POSITION_BUILTIN:
283 [[fallthrough]];
284 case SK_FRAGCOORD_BUILTIN:
285 return {Builtin::kPosition};
286 case SK_VERTEXID_BUILTIN:
287 return {Builtin::kVertexIndex};
288 case SK_INSTANCEID_BUILTIN:
289 return {Builtin::kInstanceIndex};
290 case SK_CLOCKWISE_BUILTIN:
291 // TODO(skia:13092): While `front_facing` is the corresponding built-in, it does not
292 // imply a particular winding order. We correctly compute the face orientation based
293 // on how Skia configured the render pipeline for all references to this built-in
294 // variable (see `SkSL::Program::Inputs::fUseFlipRTUniform`).
295 return {Builtin::kFrontFacing};
296 default:
297 break;
298 }
299 return std::nullopt;
300 }
301
top_level_symbol_table(const FunctionDefinition & f)302 const SymbolTable* top_level_symbol_table(const FunctionDefinition& f) {
303 return f.body()->as<Block>().symbolTable()->fParent.get();
304 }
305
delimiter_to_str(WGSLCodeGenerator::Delimiter delimiter)306 const char* delimiter_to_str(WGSLCodeGenerator::Delimiter delimiter) {
307 using Delim = WGSLCodeGenerator::Delimiter;
308 switch (delimiter) {
309 case Delim::kComma:
310 return ",";
311 case Delim::kSemicolon:
312 return ";";
313 case Delim::kNone:
314 default:
315 break;
316 }
317 return "";
318 }
319
320 // FunctionDependencyResolver visits the IR tree rooted at a particular function definition and
321 // computes that function's dependencies on pipeline stage IO parameters. These are later used to
322 // synthesize arguments when writing out function definitions.
323 class FunctionDependencyResolver : public ProgramVisitor {
324 public:
325 using Deps = WGSLCodeGenerator::FunctionDependencies;
326 using DepsMap = WGSLCodeGenerator::ProgramRequirements::DepsMap;
327
FunctionDependencyResolver(const Program * p,const FunctionDeclaration * f,DepsMap * programDependencyMap)328 FunctionDependencyResolver(const Program* p,
329 const FunctionDeclaration* f,
330 DepsMap* programDependencyMap)
331 : fProgram(p), fFunction(f), fDependencyMap(programDependencyMap) {}
332
resolve()333 Deps resolve() {
334 fDeps = Deps::kNone;
335 this->visit(*fProgram);
336 return fDeps;
337 }
338
339 private:
visitProgramElement(const ProgramElement & p)340 bool visitProgramElement(const ProgramElement& p) override {
341 // Only visit the program that matches the requested function.
342 if (p.is<FunctionDefinition>() && &p.as<FunctionDefinition>().declaration() == fFunction) {
343 return INHERITED::visitProgramElement(p);
344 }
345 // Continue visiting other program elements.
346 return false;
347 }
348
visitExpression(const Expression & e)349 bool visitExpression(const Expression& e) override {
350 if (e.is<VariableReference>()) {
351 const VariableReference& v = e.as<VariableReference>();
352 const Modifiers& modifiers = v.variable()->modifiers();
353 if (v.variable()->storage() == Variable::Storage::kGlobal) {
354 if (modifiers.fFlags & Modifiers::kIn_Flag) {
355 fDeps |= Deps::kPipelineInputs;
356 }
357 if (modifiers.fFlags & Modifiers::kOut_Flag) {
358 fDeps |= Deps::kPipelineOutputs;
359 }
360 }
361 } else if (e.is<FunctionCall>()) {
362 // The current function that we're processing (`fFunction`) inherits the dependencies of
363 // functions that it makes calls to, because the pipeline stage IO parameters need to be
364 // passed down as an argument.
365 const FunctionCall& callee = e.as<FunctionCall>();
366
367 // Don't process a function again if we have already resolved it.
368 Deps* found = fDependencyMap->find(&callee.function());
369 if (found) {
370 fDeps |= *found;
371 } else {
372 // Store the dependencies that have been discovered for the current function so far.
373 // If `callee` directly or indirectly calls the current function, then this value
374 // will prevent an infinite recursion.
375 fDependencyMap->set(fFunction, fDeps);
376
377 // Separately traverse the called function's definition and determine its
378 // dependencies.
379 FunctionDependencyResolver resolver(fProgram, &callee.function(), fDependencyMap);
380 Deps calleeDeps = resolver.resolve();
381
382 // Store the callee's dependencies in the global map to avoid processing
383 // the function again for future calls.
384 fDependencyMap->set(&callee.function(), calleeDeps);
385
386 // Add to the current function's dependencies.
387 fDeps |= calleeDeps;
388 }
389 }
390 return INHERITED::visitExpression(e);
391 }
392
393 const Program* const fProgram;
394 const FunctionDeclaration* const fFunction;
395 DepsMap* const fDependencyMap;
396 Deps fDeps = Deps::kNone;
397
398 using INHERITED = ProgramVisitor;
399 };
400
resolve_program_requirements(const Program * program)401 WGSLCodeGenerator::ProgramRequirements resolve_program_requirements(const Program* program) {
402 bool mainNeedsCoordsArgument = false;
403 WGSLCodeGenerator::ProgramRequirements::DepsMap dependencies;
404
405 for (const ProgramElement* e : program->elements()) {
406 if (!e->is<FunctionDefinition>()) {
407 continue;
408 }
409
410 const FunctionDeclaration& decl = e->as<FunctionDefinition>().declaration();
411 if (decl.isMain()) {
412 for (const Variable* v : decl.parameters()) {
413 if (v->modifiers().fLayout.fBuiltin == SK_MAIN_COORDS_BUILTIN) {
414 mainNeedsCoordsArgument = true;
415 break;
416 }
417 }
418 }
419
420 FunctionDependencyResolver resolver(program, &decl, &dependencies);
421 dependencies.set(&decl, resolver.resolve());
422 }
423
424 return WGSLCodeGenerator::ProgramRequirements(std::move(dependencies), mainNeedsCoordsArgument);
425 }
426
count_pipeline_inputs(const Program * program)427 int count_pipeline_inputs(const Program* program) {
428 int inputCount = 0;
429 for (const ProgramElement* e : program->elements()) {
430 if (e->is<GlobalVarDeclaration>()) {
431 const Variable* v = e->as<GlobalVarDeclaration>().varDeclaration().var();
432 if (v->modifiers().fFlags & Modifiers::kIn_Flag) {
433 inputCount++;
434 }
435 } else if (e->is<InterfaceBlock>()) {
436 const Variable* v = e->as<InterfaceBlock>().var();
437 if (v->modifiers().fFlags & Modifiers::kIn_Flag) {
438 inputCount++;
439 }
440 }
441 }
442 return inputCount;
443 }
444
is_in_global_uniforms(const Variable & var)445 static bool is_in_global_uniforms(const Variable& var) {
446 SkASSERT(var.storage() == VariableStorage::kGlobal);
447 return var.modifiers().fFlags & Modifiers::kUniform_Flag && !var.type().isOpaque();
448 }
449
450 } // namespace
451
generateCode()452 bool WGSLCodeGenerator::generateCode() {
453 // The resources of a WGSL program are structured in the following way:
454 // - Vertex and fragment stage attribute inputs and outputs are bundled
455 // inside synthetic structs called VSIn/VSOut/FSIn/FSOut.
456 // - All uniform and storage type resources are declared in global scope.
457 this->preprocessProgram();
458
459 StringStream header;
460 {
461 AutoOutputStream outputToHeader(this, &header, &fIndentation);
462 // TODO(skia:13092): Implement the following:
463 // - global uniform/storage resource declarations, including interface blocks.
464 this->writeStageInputStruct();
465 this->writeStageOutputStruct();
466 this->writeNonBlockUniformsForTests();
467 }
468 StringStream body;
469 {
470 AutoOutputStream outputToBody(this, &body, &fIndentation);
471 for (const ProgramElement* e : fProgram.elements()) {
472 this->writeProgramElement(*e);
473 }
474
475 // TODO(skia:13092): This is a temporary debug feature. Remove when the implementation is
476 // complete and this is no longer needed.
477 #if DUMP_SRC_IR
478 this->writeLine("\n----------");
479 this->writeLine("Source IR:\n");
480 for (const ProgramElement* e : fProgram.elements()) {
481 this->writeLine(e->description().c_str());
482 }
483 #endif
484 }
485
486 write_stringstream(header, *fOut);
487 write_stringstream(fExtraFunctions, *fOut);
488 write_stringstream(body, *fOut);
489 return fContext.fErrors->errorCount() == 0;
490 }
491
preprocessProgram()492 void WGSLCodeGenerator::preprocessProgram() {
493 fRequirements = resolve_program_requirements(&fProgram);
494 fPipelineInputCount = count_pipeline_inputs(&fProgram);
495 }
496
write(std::string_view s)497 void WGSLCodeGenerator::write(std::string_view s) {
498 if (s.empty()) {
499 return;
500 }
501 if (fAtLineStart) {
502 for (int i = 0; i < fIndentation; i++) {
503 fOut->writeText(" ");
504 }
505 }
506 fOut->writeText(std::string(s).c_str());
507 fAtLineStart = false;
508 }
509
writeLine(std::string_view s)510 void WGSLCodeGenerator::writeLine(std::string_view s) {
511 this->write(s);
512 fOut->writeText("\n");
513 fAtLineStart = true;
514 }
515
finishLine()516 void WGSLCodeGenerator::finishLine() {
517 if (!fAtLineStart) {
518 this->writeLine();
519 }
520 }
521
writeName(std::string_view name)522 void WGSLCodeGenerator::writeName(std::string_view name) {
523 // Add underscore before name to avoid conflict with reserved words.
524 if (fReservedWords.contains(name)) {
525 this->write("_");
526 }
527 this->write(name);
528 }
529
writeVariableDecl(const Type & type,std::string_view name,Delimiter delimiter)530 void WGSLCodeGenerator::writeVariableDecl(const Type& type,
531 std::string_view name,
532 Delimiter delimiter) {
533 this->writeName(name);
534 this->write(": " + to_wgsl_type(type));
535 this->writeLine(delimiter_to_str(delimiter));
536 }
537
writePipelineIODeclaration(Modifiers modifiers,const Type & type,std::string_view name,Delimiter delimiter)538 void WGSLCodeGenerator::writePipelineIODeclaration(Modifiers modifiers,
539 const Type& type,
540 std::string_view name,
541 Delimiter delimiter) {
542 // In WGSL, an entry-point IO parameter is "one of either a built-in value or
543 // assigned a location". However, some SkSL declarations, specifically sk_FragColor, can
544 // contain both a location and a builtin modifier. In addition, WGSL doesn't have a built-in
545 // equivalent for sk_FragColor as it relies on the user-defined location for a render
546 // target.
547 //
548 // Instead of special-casing sk_FragColor, we just give higher precedence to a location
549 // modifier if a declaration happens to both have a location and it's a built-in.
550 //
551 // Also see:
552 // https://www.w3.org/TR/WGSL/#input-output-locations
553 // https://www.w3.org/TR/WGSL/#attribute-location
554 // https://www.w3.org/TR/WGSL/#builtin-inputs-outputs
555 int location = modifiers.fLayout.fLocation;
556 if (location >= 0) {
557 this->writeUserDefinedIODecl(type, name, location, delimiter);
558 } else if (modifiers.fLayout.fBuiltin >= 0) {
559 auto builtin = builtin_from_sksl_name(modifiers.fLayout.fBuiltin);
560 if (builtin.has_value()) {
561 this->writeBuiltinIODecl(type, name, *builtin, delimiter);
562 }
563 }
564 }
565
writeUserDefinedIODecl(const Type & type,std::string_view name,int location,Delimiter delimiter)566 void WGSLCodeGenerator::writeUserDefinedIODecl(const Type& type,
567 std::string_view name,
568 int location,
569 Delimiter delimiter) {
570 this->write("@location(" + std::to_string(location) + ") ");
571
572 // "User-defined IO of scalar or vector integer type must always be specified as
573 // @interpolate(flat)" (see https://www.w3.org/TR/WGSL/#interpolation)
574 if (type.isInteger() || (type.isVector() && type.componentType().isInteger())) {
575 this->write("@interpolate(flat) ");
576 }
577
578 this->writeVariableDecl(type, name, delimiter);
579 }
580
writeBuiltinIODecl(const Type & type,std::string_view name,Builtin builtin,Delimiter delimiter)581 void WGSLCodeGenerator::writeBuiltinIODecl(const Type& type,
582 std::string_view name,
583 Builtin builtin,
584 Delimiter delimiter) {
585 this->write("@builtin(");
586 this->write(wgsl_builtin_name(builtin));
587 this->write(") ");
588
589 this->writeName(name);
590 this->write(": ");
591 this->write(wgsl_builtin_type(builtin));
592 this->writeLine(delimiter_to_str(delimiter));
593 }
594
writeFunction(const FunctionDefinition & f)595 void WGSLCodeGenerator::writeFunction(const FunctionDefinition& f) {
596 this->writeFunctionDeclaration(f.declaration());
597 this->write(" ");
598 this->writeBlock(f.body()->as<Block>());
599
600 if (f.declaration().isMain()) {
601 // We just emitted the user-defined main function. Next, we generate a program entry point
602 // that calls the user-defined main.
603 this->writeEntryPoint(f);
604 }
605 }
606
writeFunctionDeclaration(const FunctionDeclaration & f)607 void WGSLCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& f) {
608 this->write("fn ");
609 this->write(f.mangledName());
610 this->write("(");
611 auto separator = SkSL::String::Separator();
612 if (this->writeFunctionDependencyParams(f)) {
613 separator(); // update the separator as parameters have been written
614 }
615 for (const Variable* param : f.parameters()) {
616 this->write(separator());
617 this->writeName(param->mangledName());
618 this->write(": ");
619
620 // Declare an "out" function parameter as a pointer.
621 if (param->modifiers().fFlags & Modifiers::kOut_Flag) {
622 this->write(to_ptr_type(param->type()));
623 } else {
624 this->write(to_wgsl_type(param->type()));
625 }
626 }
627 this->write(")");
628 if (!f.returnType().isVoid()) {
629 this->write(" -> ");
630 this->write(to_wgsl_type(f.returnType()));
631 }
632 }
633
writeEntryPoint(const FunctionDefinition & main)634 void WGSLCodeGenerator::writeEntryPoint(const FunctionDefinition& main) {
635 SkASSERT(main.declaration().isMain());
636
637 // The input and output parameters for a vertex/fragment stage entry point function have the
638 // FSIn/FSOut/VSIn/VSOut struct types that have been synthesized in generateCode(). An entry
639 // point always has the same signature and acts as a trampoline to the user-defined main
640 // function.
641 std::string outputType;
642 if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
643 this->write("@vertex fn vertexMain(");
644 if (fPipelineInputCount > 0) {
645 this->write("_stageIn: VSIn");
646 }
647 this->writeLine(") -> VSOut {");
648 outputType = "VSOut";
649 } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
650 this->write("@fragment fn fragmentMain(");
651 if (fPipelineInputCount > 0) {
652 this->write("_stageIn: FSIn");
653 }
654 this->writeLine(") -> FSOut {");
655 outputType = "FSOut";
656 } else {
657 fContext.fErrors->error(Position(), "program kind not supported");
658 return;
659 }
660
661 // Declare the stage output struct.
662 fIndentation++;
663 this->write("var _stageOut: ");
664 this->write(outputType);
665 this->writeLine(";");
666
667 // Generate assignment to sk_FragColor built-in if the user-defined main returns a color.
668 if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
669 const SymbolTable* symbolTable = top_level_symbol_table(main);
670 const Symbol* symbol = symbolTable->find("sk_FragColor");
671 SkASSERT(symbol);
672 if (main.declaration().returnType().matches(symbol->type())) {
673 this->write("_stageOut.sk_FragColor = ");
674 }
675 }
676
677 // Generate the function call to the user-defined main:
678 this->write(main.declaration().mangledName());
679 this->write("(");
680 auto separator = SkSL::String::Separator();
681 FunctionDependencies* deps = fRequirements.dependencies.find(&main.declaration());
682 if (deps) {
683 if ((*deps & FunctionDependencies::kPipelineInputs) != FunctionDependencies::kNone) {
684 this->write(separator());
685 this->write("_stageIn");
686 }
687 if ((*deps & FunctionDependencies::kPipelineOutputs) != FunctionDependencies::kNone) {
688 this->write(separator());
689 this->write("&_stageOut");
690 }
691 }
692 // TODO(armansito): Handle arbitrary parameters.
693 if (main.declaration().parameters().size() != 0) {
694 const Variable* v = main.declaration().parameters()[0];
695 const Type& type = v->type();
696 if (v->modifiers().fLayout.fBuiltin == SK_MAIN_COORDS_BUILTIN) {
697 if (!type.matches(*fContext.fTypes.fFloat2)) {
698 fContext.fErrors->error(
699 main.fPosition,
700 "main function has unsupported parameter: " + type.description());
701 return;
702 }
703
704 this->write(separator());
705 this->write("_stageIn.sk_FragCoord.xy");
706 }
707 }
708 this->writeLine(");");
709 this->writeLine("return _stageOut;");
710
711 fIndentation--;
712 this->writeLine("}");
713 }
714
writeStatement(const Statement & s)715 void WGSLCodeGenerator::writeStatement(const Statement& s) {
716 switch (s.kind()) {
717 case Statement::Kind::kBlock:
718 this->writeBlock(s.as<Block>());
719 break;
720 case Statement::Kind::kExpression:
721 this->writeExpressionStatement(s.as<ExpressionStatement>());
722 break;
723 case Statement::Kind::kIf:
724 this->writeIfStatement(s.as<IfStatement>());
725 break;
726 case Statement::Kind::kReturn:
727 this->writeReturnStatement(s.as<ReturnStatement>());
728 break;
729 case Statement::Kind::kVarDeclaration:
730 this->writeVarDeclaration(s.as<VarDeclaration>());
731 break;
732 default:
733 SkDEBUGFAILF("unsupported statement (kind: %d) %s",
734 static_cast<int>(s.kind()), s.description().c_str());
735 break;
736 }
737 }
738
writeStatements(const StatementArray & statements)739 void WGSLCodeGenerator::writeStatements(const StatementArray& statements) {
740 for (const auto& s : statements) {
741 if (!s->isEmpty()) {
742 this->writeStatement(*s);
743 this->finishLine();
744 }
745 }
746 }
747
writeBlock(const Block & b)748 void WGSLCodeGenerator::writeBlock(const Block& b) {
749 // Write scope markers if this block is a scope, or if the block is empty (since we need to emit
750 // something here to make the code valid).
751 bool isScope = b.isScope() || b.isEmpty();
752 if (isScope) {
753 this->writeLine("{");
754 fIndentation++;
755 }
756 this->writeStatements(b.children());
757 if (isScope) {
758 fIndentation--;
759 this->writeLine("}");
760 }
761 }
762
writeExpressionStatement(const ExpressionStatement & s)763 void WGSLCodeGenerator::writeExpressionStatement(const ExpressionStatement& s) {
764 if (Analysis::HasSideEffects(*s.expression())) {
765 this->writeExpression(*s.expression(), Precedence::kTopLevel);
766 this->write(";");
767 }
768 }
769
writeIfStatement(const IfStatement & s)770 void WGSLCodeGenerator::writeIfStatement(const IfStatement& s) {
771 this->write("if (");
772 this->writeExpression(*s.test(), Precedence::kTopLevel);
773 this->write(") ");
774 this->writeStatement(*s.ifTrue());
775 if (s.ifFalse()) {
776 this->write("else ");
777 this->writeStatement(*s.ifFalse());
778 }
779 }
780
writeReturnStatement(const ReturnStatement & s)781 void WGSLCodeGenerator::writeReturnStatement(const ReturnStatement& s) {
782 this->write("return");
783 if (s.expression()) {
784 this->write(" ");
785 this->writeExpression(*s.expression(), Precedence::kTopLevel);
786 }
787 this->write(";");
788 }
789
writeVarDeclaration(const VarDeclaration & varDecl)790 void WGSLCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl) {
791 bool isConst = varDecl.var()->modifiers().fFlags & Modifiers::kConst_Flag;
792 if (isConst) {
793 this->write("let ");
794 } else {
795 this->write("var ");
796 }
797 this->writeName(varDecl.var()->mangledName());
798 this->write(": ");
799 this->write(to_wgsl_type(varDecl.var()->type()));
800
801 if (varDecl.value()) {
802 this->write(" = ");
803 this->writeExpression(*varDecl.value(), Precedence::kTopLevel);
804 } else if (isConst) {
805 SkDEBUGFAILF("A let-declared constant must specify a value");
806 }
807
808 this->write(";");
809 }
810
writeExpression(const Expression & e,Precedence parentPrecedence)811 void WGSLCodeGenerator::writeExpression(const Expression& e, Precedence parentPrecedence) {
812 switch (e.kind()) {
813 case Expression::Kind::kBinary:
814 this->writeBinaryExpression(e.as<BinaryExpression>(), parentPrecedence);
815 break;
816 case Expression::Kind::kConstructorCompound:
817 this->writeConstructorCompound(e.as<ConstructorCompound>(), parentPrecedence);
818 break;
819 case Expression::Kind::kConstructorCompoundCast:
820 case Expression::Kind::kConstructorScalarCast:
821 case Expression::Kind::kConstructorSplat:
822 this->writeAnyConstructor(e.asAnyConstructor(), parentPrecedence);
823 break;
824 case Expression::Kind::kConstructorDiagonalMatrix:
825 this->writeConstructorDiagonalMatrix(e.as<ConstructorDiagonalMatrix>(),
826 parentPrecedence);
827 break;
828 case Expression::Kind::kFieldAccess:
829 this->writeFieldAccess(e.as<FieldAccess>());
830 break;
831 case Expression::Kind::kFunctionCall:
832 this->writeFunctionCall(e.as<FunctionCall>());
833 break;
834 case Expression::Kind::kIndex:
835 this->writeIndexExpression(e.as<IndexExpression>());
836 break;
837 case Expression::Kind::kLiteral:
838 this->writeLiteral(e.as<Literal>());
839 break;
840 case Expression::Kind::kSwizzle:
841 this->writeSwizzle(e.as<Swizzle>());
842 break;
843 case Expression::Kind::kTernary:
844 this->writeTernaryExpression(e.as<TernaryExpression>(), parentPrecedence);
845 break;
846 case Expression::Kind::kVariableReference:
847 this->writeVariableReference(e.as<VariableReference>());
848 break;
849 default:
850 SkDEBUGFAILF("unsupported expression (kind: %d) %s",
851 static_cast<int>(e.kind()),
852 e.description().c_str());
853 break;
854 }
855 }
856
writeBinaryExpression(const BinaryExpression & b,Precedence parentPrecedence)857 void WGSLCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
858 Precedence parentPrecedence) {
859 const Expression& left = *b.left();
860 const Expression& right = *b.right();
861 Operator op = b.getOperator();
862
863 // The equality and comparison operators are only supported for scalar and vector types.
864 if (op.isEquality() && !left.type().isScalar() && !left.type().isVector()) {
865 if (left.type().isMatrix()) {
866 if (op.kind() == OperatorKind::NEQ) {
867 this->write("!");
868 }
869 this->writeMatrixEquality(left, right);
870 return;
871 }
872
873 // TODO(skia:13092): Synthesize helper functions for structs and arrays.
874 return;
875 }
876
877 Precedence precedence = op.getBinaryPrecedence();
878 bool needParens = precedence >= parentPrecedence;
879
880 // The equality operators ('=='/'!=') in WGSL apply component-wise to vectors and result in a
881 // vector. We need to reduce the value to a boolean.
882 if (left.type().isVector()) {
883 if (op.kind() == Operator::Kind::EQEQ) {
884 this->write("all");
885 needParens = true;
886 } else if (op.kind() == Operator::Kind::NEQ) {
887 this->write("any");
888 needParens = true;
889 }
890 }
891
892 if (needParens) {
893 this->write("(");
894 }
895
896 // TODO(skia:13092): Correctly handle the case when lhs is a pointer.
897
898 this->writeExpression(left, precedence);
899 this->write(op.operatorName());
900 this->writeExpression(right, precedence);
901
902 if (needParens) {
903 this->write(")");
904 }
905 }
906
writeFieldAccess(const FieldAccess & f)907 void WGSLCodeGenerator::writeFieldAccess(const FieldAccess& f) {
908 const Type::Field* field = &f.base()->type().fields()[f.fieldIndex()];
909 if (FieldAccess::OwnerKind::kDefault == f.ownerKind()) {
910 this->writeExpression(*f.base(), Precedence::kPostfix);
911 this->write(".");
912 } else {
913 // We are accessing a field in an anonymous interface block. If the field refers to a
914 // pipeline IO parameter, then we access it via the synthesized IO structs. We make an
915 // explicit exception for `sk_PointSize` which we declare as a placeholder variable in
916 // global scope as it is not supported by WebGPU as a pipeline IO parameter (see comments
917 // in `writeStageOutputStruct`).
918 const Variable& v = *f.base()->as<VariableReference>().variable();
919 if (v.modifiers().fFlags & Modifiers::kIn_Flag) {
920 this->write("_stageIn.");
921 } else if (v.modifiers().fFlags & Modifiers::kOut_Flag &&
922 field->fModifiers.fLayout.fBuiltin != SK_POINTSIZE_BUILTIN) {
923 this->write("(*_stageOut).");
924 } else {
925 // TODO(skia:13092): Reference the variable using the base name used for its
926 // uniform/storage block global declaration.
927 }
928 }
929 this->writeName(field->fName);
930 }
931
writeFunctionCall(const FunctionCall & c)932 void WGSLCodeGenerator::writeFunctionCall(const FunctionCall& c) {
933 const FunctionDeclaration& func = c.function();
934
935 // TODO(skia:13092): Handle intrinsic call as many of them need to be rewritten.
936
937 // We implement function out-parameters by declaring them as pointers. SkSL follows GLSL's
938 // out-parameter semantics, in which out-parameters are only written back to the original
939 // variable after the function's execution is complete (see
940 // https://www.khronos.org/opengl/wiki/Core_Language_(GLSL)#Parameters).
941 //
942 // In addition, SkSL supports swizzles and array index expressions to be passed into
943 // out-parameters however WGSL does not allow taking their address into a pointer.
944 //
945 // We support these by wrapping each function call in a special helper, which internally stores
946 // all out parameters in temporaries.
947
948 // First detect which arguments are passed to out-parameters.
949 const ExpressionArray& args = c.arguments();
950 const std::vector<Variable*>& params = func.parameters();
951 SkASSERT(SkToSizeT(args.size()) == params.size());
952
953 bool foundOutParam = false;
954 SkSTArray<16, VariableReference*> outVars;
955 outVars.push_back_n(args.size(), static_cast<VariableReference*>(nullptr));
956
957 for (int i = 0; i < args.size(); ++i) {
958 if (params[i]->modifiers().fFlags & Modifiers::kOut_Flag) {
959 // Find the expression's inner variable being written to. Assignability was verified at
960 // IR generation time, so this should always succeed.
961 Analysis::AssignmentInfo info;
962 SkAssertResult(Analysis::IsAssignable(*args[i], &info));
963 outVars[i] = info.fAssignedVar;
964 foundOutParam = true;
965 }
966 }
967
968 if (foundOutParam) {
969 this->writeName(this->writeOutParamHelper(c, args, outVars));
970 } else {
971 this->writeName(func.mangledName());
972 }
973
974 this->write("(");
975 auto separator = SkSL::String::Separator();
976 if (this->writeFunctionDependencyArgs(func)) {
977 separator();
978 }
979 for (int i = 0; i < args.size(); ++i) {
980 this->write(separator());
981 if (outVars[i]) {
982 // We need to take the address of the variable and pass it down as a pointer.
983 this->write("&");
984 this->writeExpression(*outVars[i], Precedence::kSequence);
985 } else {
986 this->writeExpression(*args[i], Precedence::kSequence);
987 }
988 }
989 this->write(")");
990 }
991
writeIndexExpression(const IndexExpression & i)992 void WGSLCodeGenerator::writeIndexExpression(const IndexExpression& i) {
993 this->writeExpression(*i.base(), Precedence::kPostfix);
994 this->write("[");
995 this->writeExpression(*i.index(), Precedence::kTopLevel);
996 this->write("]");
997 }
998
writeLiteral(const Literal & l)999 void WGSLCodeGenerator::writeLiteral(const Literal& l) {
1000 const Type& type = l.type();
1001 if (type.isFloat() || type.isBoolean()) {
1002 this->write(l.description(OperatorPrecedence::kTopLevel));
1003 return;
1004 }
1005 SkASSERT(type.isInteger());
1006 if (type.matches(*fContext.fTypes.fUInt)) {
1007 this->write(std::to_string(l.intValue() & 0xffffffff));
1008 this->write("u");
1009 } else if (type.matches(*fContext.fTypes.fUShort)) {
1010 this->write(std::to_string(l.intValue() & 0xffff));
1011 this->write("u");
1012 } else {
1013 this->write(std::to_string(l.intValue()));
1014 }
1015 }
1016
writeSwizzle(const Swizzle & swizzle)1017 void WGSLCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
1018 this->writeExpression(*swizzle.base(), Precedence::kPostfix);
1019 this->write(".");
1020 for (int c : swizzle.components()) {
1021 SkASSERT(c >= 0 && c <= 3);
1022 this->write(&("x\0y\0z\0w\0"[c * 2]));
1023 }
1024 }
1025
writeTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)1026 void WGSLCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
1027 Precedence parentPrecedence) {
1028 bool needParens = Precedence::kTernary >= parentPrecedence;
1029 if (needParens) {
1030 this->write("(");
1031 }
1032
1033 // The trivial case is when neither branch has side effects and evaluate to a scalar or vector
1034 // type. This can be represented with a call to the WGSL `select` intrinsic although it doesn't
1035 // support short-circuiting.
1036 if ((t.type().isScalar() || t.type().isVector()) && !Analysis::HasSideEffects(*t.ifTrue()) &&
1037 !Analysis::HasSideEffects(*t.ifFalse())) {
1038 this->write("select(");
1039 this->writeExpression(*t.ifFalse(), Precedence::kTernary);
1040 this->write(", ");
1041 this->writeExpression(*t.ifTrue(), Precedence::kTernary);
1042 this->write(", ");
1043
1044 bool isVector = t.type().isVector();
1045 if (isVector) {
1046 // Splat the condition expression into a vector.
1047 this->write(String::printf("vec%d<bool>(", t.type().columns()));
1048 }
1049 this->writeExpression(*t.test(), Precedence::kTernary);
1050 if (isVector) {
1051 this->write(")");
1052 }
1053 this->write(")");
1054 if (needParens) {
1055 this->write(")");
1056 }
1057 return;
1058 }
1059
1060 // TODO(skia:13092): WGSL does not support ternary expressions. To replicate the required
1061 // short-circuting behavior we need to hoist the expression out into the surrounding block,
1062 // convert it into an if statement that writes the result to a synthesized variable, and replace
1063 // the original expression with a reference to that variable.
1064 //
1065 // Once hoisting is supported, we may want to use that for vector type expressions as well,
1066 // since select above does a component-wise select
1067 }
1068
writeVariableReference(const VariableReference & r)1069 void WGSLCodeGenerator::writeVariableReference(const VariableReference& r) {
1070 // TODO(skia:13092): Correctly handle RTflip for built-ins.
1071 const Variable& v = *r.variable();
1072
1073 // Insert a conversion expression if this is a built-in variable whose type differs from the
1074 // SkSL.
1075 std::optional<std::string_view> conversion = needs_builtin_type_conversion(v);
1076 if (conversion.has_value()) {
1077 this->write(*conversion);
1078 this->write("(");
1079 }
1080
1081 bool needsDeref = false;
1082 bool isSynthesizedOutParamArg = fOutParamArgVars.contains(&v);
1083
1084 // When a variable is referenced in the context of a synthesized out-parameter helper argument,
1085 // two special rules apply:
1086 // 1. If it's accessed via a pipeline I/O or global uniforms struct, it should instead
1087 // be referenced by name (since it's actually referring to a function parameter).
1088 // 2. Its type should be treated as a pointer and should be dereferenced as such.
1089 if (v.storage() == Variable::Storage::kGlobal && !isSynthesizedOutParamArg) {
1090 if (v.modifiers().fFlags & Modifiers::kIn_Flag) {
1091 this->write("_stageIn.");
1092 } else if (v.modifiers().fFlags & Modifiers::kOut_Flag) {
1093 this->write("(*_stageOut).");
1094 } else if (is_in_global_uniforms(v)) {
1095 this->write("_globalUniforms.");
1096 }
1097 } else if ((v.storage() == Variable::Storage::kParameter &&
1098 v.modifiers().fFlags & Modifiers::kOut_Flag) ||
1099 isSynthesizedOutParamArg) {
1100 // This is an out-parameter and its type is a pointer, which we need to dereference.
1101 // We wrap the dereference in parentheses in case the value is used in an access expression
1102 // later.
1103 needsDeref = true;
1104 this->write("(*");
1105 }
1106
1107 this->writeName(v.mangledName());
1108 if (needsDeref) {
1109 this->write(")");
1110 }
1111 if (conversion.has_value()) {
1112 this->write(")");
1113 }
1114 }
1115
writeAnyConstructor(const AnyConstructor & c,Precedence parentPrecedence)1116 void WGSLCodeGenerator::writeAnyConstructor(const AnyConstructor& c, Precedence parentPrecedence) {
1117 this->write(to_wgsl_type(c.type()));
1118 this->write("(");
1119 auto separator = SkSL::String::Separator();
1120 for (const auto& e : c.argumentSpan()) {
1121 this->write(separator());
1122 this->writeExpression(*e, Precedence::kSequence);
1123 }
1124 this->write(")");
1125 }
1126
writeConstructorCompound(const ConstructorCompound & c,Precedence parentPrecedence)1127 void WGSLCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
1128 Precedence parentPrecedence) {
1129 // TODO(skia:13092): Support matrix constructors
1130 if (c.type().isVector()) {
1131 this->writeConstructorCompoundVector(c, parentPrecedence);
1132 } else {
1133 fContext.fErrors->error(c.fPosition, "unsupported compound constructor");
1134 }
1135 }
1136
writeConstructorCompoundVector(const ConstructorCompound & c,Precedence parentPrecedence)1137 void WGSLCodeGenerator::writeConstructorCompoundVector(const ConstructorCompound& c,
1138 Precedence parentPrecedence) {
1139 // WGSL supports constructing vectors from a mix of scalars and vectors but
1140 // not matrices (see https://www.w3.org/TR/WGSL/#type-constructor-expr).
1141 //
1142 // SkSL supports vec4(mat2x2) which we handle specially.
1143 if (c.type().columns() == 4 && c.argumentSpan().size() == 1) {
1144 const Expression& arg = *c.argumentSpan().front();
1145 if (arg.type().isMatrix()) {
1146 // This is the vec4(mat2x2) case.
1147 SkASSERT(arg.type().columns() == 2);
1148 SkASSERT(arg.type().rows() == 2);
1149
1150 // Generate a helper so that the argument expression gets evaluated once.
1151 std::string name = String::printf("%s_from_%s",
1152 to_mangled_wgsl_type_name(c.type()).c_str(),
1153 to_mangled_wgsl_type_name(arg.type()).c_str());
1154 if (!fHelpers.contains(name)) {
1155 fHelpers.add(name);
1156 std::string returnType = to_wgsl_type(c.type());
1157 std::string argType = to_wgsl_type(arg.type());
1158 fExtraFunctions.printf(
1159 "fn %s(x: %s) -> %s {\n return %s(x[0].xy, x[1].xy);\n}\n",
1160 name.c_str(),
1161 argType.c_str(),
1162 returnType.c_str(),
1163 returnType.c_str());
1164 }
1165 this->write(name);
1166 this->write("(");
1167 this->writeExpression(arg, Precedence::kSequence);
1168 this->write(")");
1169 return;
1170 }
1171 }
1172 this->writeAnyConstructor(c, parentPrecedence);
1173 }
1174
writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix & c,Precedence parentPrecedence)1175 void WGSLCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
1176 Precedence parentPrecedence) {
1177 const Type& type = c.type();
1178 SkASSERT(type.isMatrix());
1179 SkASSERT(c.argument()->type().isScalar());
1180
1181 // Generate a helper so that the argument expression gets evaluated once.
1182 std::string name = String::printf("%s_diagonal", to_mangled_wgsl_type_name(type).c_str());
1183 if (!fHelpers.contains(name)) {
1184 fHelpers.add(name);
1185
1186 std::string typeName = to_wgsl_type(type);
1187 fExtraFunctions.printf("fn %s(x: %s) -> %s {\n",
1188 name.c_str(),
1189 to_wgsl_type(c.argument()->type()).c_str(),
1190 typeName.c_str());
1191 fExtraFunctions.printf(" return %s(", typeName.c_str());
1192 auto separator = String::Separator();
1193 for (int col = 0; col < type.columns(); ++col) {
1194 for (int row = 0; row < type.rows(); ++row) {
1195 fExtraFunctions.printf("%s%s", separator().c_str(), (col == row) ? "x" : "0.0");
1196 }
1197 }
1198 fExtraFunctions.printf(");\n}\n");
1199 }
1200 this->write(name);
1201 this->write("(");
1202 this->writeExpression(*c.argument(), Precedence::kSequence);
1203 this->write(")");
1204 }
1205
writeMatrixEquality(const Expression & left,const Expression & right)1206 void WGSLCodeGenerator::writeMatrixEquality(const Expression& left, const Expression& right) {
1207 const Type& leftType = left.type();
1208 const Type& rightType = right.type();
1209 SkASSERT(leftType.isMatrix());
1210 SkASSERT(rightType.isMatrix());
1211 SkASSERT(leftType.rows() == rightType.rows());
1212 SkASSERT(leftType.columns() == rightType.columns());
1213
1214 std::string name = String::printf("%s_eq_%s",
1215 to_mangled_wgsl_type_name(leftType).c_str(),
1216 to_mangled_wgsl_type_name(rightType).c_str());
1217 if (!fHelpers.contains(name)) {
1218 fHelpers.add(name);
1219 fExtraFunctions.printf("fn %s(left: %s, right: %s) -> bool {\n return ",
1220 name.c_str(),
1221 to_wgsl_type(leftType).c_str(),
1222 to_wgsl_type(rightType).c_str());
1223 const char* separator = "";
1224 for (int i = 0; i < leftType.columns(); ++i) {
1225 fExtraFunctions.printf("%sall(left[%d] == right[%d])", separator, i, i);
1226 separator = " &&\n ";
1227 }
1228 fExtraFunctions.printf(";\n}\n");
1229 }
1230 this->write(name);
1231 this->write("(");
1232 this->writeExpression(left, Precedence::kSequence);
1233 this->write(", ");
1234 this->writeExpression(right, Precedence::kSequence);
1235 this->write(")");
1236 }
1237
writeProgramElement(const ProgramElement & e)1238 void WGSLCodeGenerator::writeProgramElement(const ProgramElement& e) {
1239 switch (e.kind()) {
1240 case ProgramElement::Kind::kExtension:
1241 // TODO(skia:13092): WGSL supports extensions via the "enable" directive
1242 // (https://www.w3.org/TR/WGSL/#language-extensions). While we could easily emit this
1243 // directive, we should first ensure that all possible SkSL extension names are
1244 // converted to their appropriate WGSL extension. Currently there are no known supported
1245 // WGSL extensions aside from the hypotheticals listed in the spec.
1246 break;
1247 case ProgramElement::Kind::kGlobalVar:
1248 this->writeGlobalVarDeclaration(e.as<GlobalVarDeclaration>());
1249 break;
1250 case ProgramElement::Kind::kInterfaceBlock:
1251 // All interface block declarations are handled explicitly as the "program header" in
1252 // generateCode().
1253 break;
1254 case ProgramElement::Kind::kStructDefinition:
1255 this->writeStructDefinition(e.as<StructDefinition>());
1256 break;
1257 case ProgramElement::Kind::kFunctionPrototype:
1258 // A WGSL function declaration must contain its body and the function name is in scope
1259 // for the entire program (see https://www.w3.org/TR/WGSL/#function-declaration and
1260 // https://www.w3.org/TR/WGSL/#declaration-and-scope).
1261 //
1262 // As such, we don't emit function prototypes.
1263 break;
1264 case ProgramElement::Kind::kFunction:
1265 this->writeFunction(e.as<FunctionDefinition>());
1266 break;
1267 default:
1268 SkDEBUGFAILF("unsupported program element: %s\n", e.description().c_str());
1269 break;
1270 }
1271 }
1272
writeGlobalVarDeclaration(const GlobalVarDeclaration & d)1273 void WGSLCodeGenerator::writeGlobalVarDeclaration(const GlobalVarDeclaration& d) {
1274 const Variable& var = *d.declaration()->as<VarDeclaration>().var();
1275 if ((var.modifiers().fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag)) ||
1276 is_in_global_uniforms(var)) {
1277 // Pipeline stage I/O parameters and top-level (non-block) uniforms are handled specially
1278 // in generateCode().
1279 return;
1280 }
1281
1282 // TODO(skia:13092): Implement workgroup variable decoration
1283 this->write("var<private> ");
1284 this->writeVariableDecl(var.type(), var.name(), Delimiter::kSemicolon);
1285 }
1286
writeStructDefinition(const StructDefinition & s)1287 void WGSLCodeGenerator::writeStructDefinition(const StructDefinition& s) {
1288 const Type& type = s.type();
1289 this->writeLine("struct " + type.displayName() + " {");
1290 fIndentation++;
1291 this->writeFields(SkSpan(type.fields()), type.fPosition);
1292 fIndentation--;
1293 this->writeLine("};");
1294 }
1295
writeFields(SkSpan<const Type::Field> fields,Position parentPos,const MemoryLayout *)1296 void WGSLCodeGenerator::writeFields(SkSpan<const Type::Field> fields,
1297 Position parentPos,
1298 const MemoryLayout*) {
1299 // TODO(skia:13092): Check alignment against `layout` constraints, if present. A layout
1300 // constraint will be specified for interface blocks and for structs that appear in a block.
1301 for (const Type::Field& field : fields) {
1302 const Type* fieldType = field.fType;
1303 this->writeVariableDecl(*fieldType, field.fName, Delimiter::kComma);
1304 }
1305 }
1306
writeStageInputStruct()1307 void WGSLCodeGenerator::writeStageInputStruct() {
1308 std::string_view structNamePrefix = pipeline_struct_prefix(fProgram.fConfig->fKind);
1309 if (structNamePrefix.empty()) {
1310 // There's no need to declare pipeline stage outputs.
1311 return;
1312 }
1313
1314 // It is illegal to declare a struct with no members.
1315 if (fPipelineInputCount < 1) {
1316 return;
1317 }
1318
1319 this->write("struct ");
1320 this->write(structNamePrefix);
1321 this->writeLine("In {");
1322 fIndentation++;
1323
1324 bool declaredFragCoordsBuiltin = false;
1325 for (const ProgramElement* e : fProgram.elements()) {
1326 if (e->is<GlobalVarDeclaration>()) {
1327 const Variable* v = e->as<GlobalVarDeclaration>().declaration()
1328 ->as<VarDeclaration>().var();
1329 if (v->modifiers().fFlags & Modifiers::kIn_Flag) {
1330 this->writePipelineIODeclaration(v->modifiers(), v->type(), v->mangledName(),
1331 Delimiter::kComma);
1332 if (v->modifiers().fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
1333 declaredFragCoordsBuiltin = true;
1334 }
1335 }
1336 } else if (e->is<InterfaceBlock>()) {
1337 const Variable* v = e->as<InterfaceBlock>().var();
1338 // Merge all the members of `in` interface blocks to the input struct, which are
1339 // specified as either "builtin" or with a "layout(location=".
1340 //
1341 // TODO(armansito): Is it legal to have an interface block without a storage qualifier
1342 // but with members that have individual storage qualifiers?
1343 if (v->modifiers().fFlags & Modifiers::kIn_Flag) {
1344 for (const auto& f : v->type().fields()) {
1345 this->writePipelineIODeclaration(f.fModifiers, *f.fType, f.fName,
1346 Delimiter::kComma);
1347 if (f.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
1348 declaredFragCoordsBuiltin = true;
1349 }
1350 }
1351 }
1352 }
1353 }
1354
1355 if (ProgramConfig::IsFragment(fProgram.fConfig->fKind) &&
1356 fRequirements.mainNeedsCoordsArgument && !declaredFragCoordsBuiltin) {
1357 this->writeLine("@builtin(position) sk_FragCoord: vec4<f32>,");
1358 }
1359
1360 fIndentation--;
1361 this->writeLine("};");
1362 }
1363
writeStageOutputStruct()1364 void WGSLCodeGenerator::writeStageOutputStruct() {
1365 std::string_view structNamePrefix = pipeline_struct_prefix(fProgram.fConfig->fKind);
1366 if (structNamePrefix.empty()) {
1367 // There's no need to declare pipeline stage outputs.
1368 return;
1369 }
1370
1371 this->write("struct ");
1372 this->write(structNamePrefix);
1373 this->writeLine("Out {");
1374 fIndentation++;
1375
1376 // TODO(skia:13092): Remember all variables that are added to the output struct here so they
1377 // can be referenced correctly when handling variable references.
1378 bool declaredPositionBuiltin = false;
1379 bool requiresPointSizeBuiltin = false;
1380 for (const ProgramElement* e : fProgram.elements()) {
1381 if (e->is<GlobalVarDeclaration>()) {
1382 const Variable* v = e->as<GlobalVarDeclaration>().declaration()
1383 ->as<VarDeclaration>().var();
1384 if (v->modifiers().fFlags & Modifiers::kOut_Flag) {
1385 this->writePipelineIODeclaration(v->modifiers(), v->type(), v->mangledName(),
1386 Delimiter::kComma);
1387 }
1388 } else if (e->is<InterfaceBlock>()) {
1389 const Variable* v = e->as<InterfaceBlock>().var();
1390 // Merge all the members of `out` interface blocks to the output struct, which are
1391 // specified as either "builtin" or with a "layout(location=".
1392 //
1393 // TODO(armansito): Is it legal to have an interface block without a storage qualifier
1394 // but with members that have individual storage qualifiers?
1395 if (v->modifiers().fFlags & Modifiers::kOut_Flag) {
1396 for (const auto& f : v->type().fields()) {
1397 this->writePipelineIODeclaration(f.fModifiers, *f.fType, f.fName,
1398 Delimiter::kComma);
1399 if (f.fModifiers.fLayout.fBuiltin == SK_POSITION_BUILTIN) {
1400 declaredPositionBuiltin = true;
1401 } else if (f.fModifiers.fLayout.fBuiltin == SK_POINTSIZE_BUILTIN) {
1402 // sk_PointSize is explicitly not supported by `builtin_from_sksl_name` so
1403 // writePipelineIODeclaration will never write it. We mark it here if the
1404 // declaration is needed so we can synthesize it below.
1405 requiresPointSizeBuiltin = true;
1406 }
1407 }
1408 }
1409 }
1410 }
1411
1412 // A vertex program must include the `position` builtin in its entry point return type.
1413 if (ProgramConfig::IsVertex(fProgram.fConfig->fKind) && !declaredPositionBuiltin) {
1414 this->writeLine("@builtin(position) sk_Position: vec4<f32>,");
1415 }
1416
1417 fIndentation--;
1418 this->writeLine("};");
1419
1420 // In WebGPU/WGSL, the vertex stage does not support a point-size output and the size
1421 // of a point primitive is always 1 pixel (see https://github.com/gpuweb/gpuweb/issues/332).
1422 //
1423 // There isn't anything we can do to emulate this correctly at this stage so we
1424 // synthesize a placeholder variable that has no effect. Programs should not rely on
1425 // sk_PointSize when using the Dawn backend.
1426 if (ProgramConfig::IsVertex(fProgram.fConfig->fKind) && requiresPointSizeBuiltin) {
1427 this->writeLine("/* unsupported */ var<private> sk_PointSize: f32;");
1428 }
1429 }
1430
writeNonBlockUniformsForTests()1431 void WGSLCodeGenerator::writeNonBlockUniformsForTests() {
1432 for (const ProgramElement* e : fProgram.elements()) {
1433 if (e->is<GlobalVarDeclaration>()) {
1434 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
1435 const Variable& var = *decls.varDeclaration().var();
1436 if (is_in_global_uniforms(var)) {
1437 if (!fDeclaredUniformsStruct) {
1438 this->write("struct _GlobalUniforms {\n");
1439 fDeclaredUniformsStruct = true;
1440 }
1441 this->write(" ");
1442 this->writeVariableDecl(var.type(), var.mangledName(), Delimiter::kComma);
1443 }
1444 }
1445 }
1446 if (fDeclaredUniformsStruct) {
1447 int binding = fProgram.fConfig->fSettings.fDefaultUniformBinding;
1448 int set = fProgram.fConfig->fSettings.fDefaultUniformSet;
1449 this->write("};\n");
1450 this->write("@binding(" + std::to_string(binding) + ") ");
1451 this->write("@group(" + std::to_string(set) + ") ");
1452 this->writeLine("var<uniform> _globalUniforms: _GlobalUniforms;");
1453 }
1454 }
1455
writeFunctionDependencyArgs(const FunctionDeclaration & f)1456 bool WGSLCodeGenerator::writeFunctionDependencyArgs(const FunctionDeclaration& f) {
1457 FunctionDependencies* deps = fRequirements.dependencies.find(&f);
1458 if (!deps || *deps == FunctionDependencies::kNone) {
1459 return false;
1460 }
1461
1462 const char* separator = "";
1463 if ((*deps & FunctionDependencies::kPipelineInputs) != FunctionDependencies::kNone) {
1464 this->write("_stageIn");
1465 separator = ", ";
1466 }
1467 if ((*deps & FunctionDependencies::kPipelineOutputs) != FunctionDependencies::kNone) {
1468 this->write(separator);
1469 this->write("_stageOut");
1470 }
1471 return true;
1472 }
1473
writeFunctionDependencyParams(const FunctionDeclaration & f)1474 bool WGSLCodeGenerator::writeFunctionDependencyParams(const FunctionDeclaration& f) {
1475 FunctionDependencies* deps = fRequirements.dependencies.find(&f);
1476 if (!deps || *deps == FunctionDependencies::kNone) {
1477 return false;
1478 }
1479
1480 std::string_view structNamePrefix = pipeline_struct_prefix(fProgram.fConfig->fKind);
1481 if (structNamePrefix.empty()) {
1482 return false;
1483 }
1484 const char* separator = "";
1485 if ((*deps & FunctionDependencies::kPipelineInputs) != FunctionDependencies::kNone) {
1486 this->write("_stageIn: ");
1487 separator = ", ";
1488 this->write(structNamePrefix);
1489 this->write("In");
1490 }
1491 if ((*deps & FunctionDependencies::kPipelineOutputs) != FunctionDependencies::kNone) {
1492 this->write(separator);
1493 this->write("_stageOut: ptr<function, ");
1494 this->write(structNamePrefix);
1495 this->write("Out>");
1496 }
1497 return true;
1498 }
1499
// Synthesizes a WGSL helper function that wraps a call to `c.function()` so that arguments
// bound to out-parameters go through function-local temporaries, and returns the helper's name.
// The helper text is appended to `fExtraFunctions`. While this runs, regular output is
// redirected to a temporary stream, so nothing is written at the current output position.
//
// `args` are the call's arguments. `outVars[i]` is non-null exactly for the arguments bound to
// an out-parameter, and holds the root variable being assigned through that argument.
std::string WGSLCodeGenerator::writeOutParamHelper(const FunctionCall& c,
                                                   const ExpressionArray& args,
                                                   const SkTArray<VariableReference*>& outVars) {
    // It's possible for out-param function arguments to contain an out-param function call
    // expression. Emit the function into a temporary stream to prevent the nested helper from
    // clobbering the current helper as we recursively evaluate argument expressions.
    StringStream tmpStream;
    AutoOutputStream outputToExtraFunctions(this, &tmpStream, &fIndentation);

    // Reset the line start state while the AutoOutputStream is active. We restore it later
    // before the function returns.
    bool atLineStart = fAtLineStart;
    fAtLineStart = false;
    const FunctionDeclaration& func = c.function();

    // Synthesize a helper function that takes the same inputs as `func`, except where `outVars`
    // is non-null. In those slots, the parameter is named after the root variable, uses the
    // VariableReference's type, and is declared as a pointer. An illustrative signature (exact
    // types come from to_wgsl_type/to_ptr_type):
    //
    //   fn _outParamHelper_0_originalFuncName(_var0: T0, _var1: T1, outVar: <ptr to T2>) -> R {
    //
    // The numeric suffix comes from post-incrementing `fSwizzleHelperCount`, which keeps each
    // synthesized name distinct.
    std::string name =
            "_outParamHelper_" + std::to_string(fSwizzleHelperCount++) + "_" + func.mangledName();
    auto separator = SkSL::String::Separator();
    this->write("fn ");
    this->write(name);
    this->write("(");
    // Implicit pipeline I/O parameters come first. If any were written, consume the separator's
    // empty first value so the next parameter is preceded by ", ".
    if (this->writeFunctionDependencyParams(func)) {
        separator();
    }

    SkASSERT(outVars.size() == args.size());
    SkASSERT(SkToSizeT(outVars.size()) == func.parameters().size());

    // We need to detect cases where the caller passes the same variable as an out-param more
    // than once and avoid redeclaring the variable name. This is also a situation that is not
    // permitted by WGSL aliasing rules (see https://www.w3.org/TR/WGSL/#aliasing). Because the
    // parameter is redundant and we don't actually ever reference it, we give it a placeholder
    // name.
    //
    // The enclosing helper's set (if any) is stashed so that nested helpers start from an empty
    // set. It is restored at the end of this function.
    auto parentOutParamArgVars = std::move(fOutParamArgVars);
    SkASSERT(fOutParamArgVars.empty());

    for (int i = 0; i < args.size(); ++i) {
        this->write(separator());

        if (outVars[i]) {
            // Registering the variable makes writeVariableReference() emit it as a bare,
            // dereferenced pointer parameter inside this helper.
            const Variable* var = outVars[i]->variable();
            if (!fOutParamArgVars.contains(var)) {
                fOutParamArgVars.add(var);
                this->writeName(var->mangledName());
            } else {
                this->write("_unused");
                this->write(std::to_string(i));
            }
        } else {
            this->write("_var");
            this->write(std::to_string(i));
        }

        this->write(": ");

        // Declare the parameter using the type of argument variable. If the complete argument
        // is an access or swizzle expression, the target assignment will be resolved below
        // when we copy the value to the out-parameter.
        const Type& type = outVars[i] ? outVars[i]->type() : args[i]->type();

        // Declare an out-parameter as a pointer.
        if (func.parameters()[i]->modifiers().fFlags & Modifiers::kOut_Flag) {
            this->write(to_ptr_type(type));
        } else {
            this->write(to_wgsl_type(type));
        }
    }

    this->write(")");
    if (!func.returnType().isVoid()) {
        this->write(" -> ");
        this->write(to_wgsl_type(func.returnType()));
    }
    this->writeLine(" {");
    ++fIndentation;

    // Declare a temporary `_var<i>` for each out-parameter, typed after the full argument
    // expression. For pure `out` parameters no initializer is written; WGSL zero-initializes
    // function-scope `var` declarations that lack one.
    for (int i = 0; i < outVars.size(); ++i) {
        if (!outVars[i]) {
            continue;
        }
        this->write("var ");
        this->write("_var");
        this->write(std::to_string(i));
        this->write(": ");
        this->write(to_wgsl_type(args[i]->type()));

        // If this is an inout parameter then we need to copy the input argument into the
        // parameter per https://www.khronos.org/opengl/wiki/Core_Language_(GLSL)#Parameters.
        //
        // NOTE(review): `*args[i]` is written verbatim inside the helper, and only the root
        // variable is passed in. If the argument indexes with another caller-local variable
        // (e.g. `arr[j]`), `j` does not appear to be in scope here. Confirm that such arguments
        // are rejected or handled elsewhere.
        if (func.parameters()[i]->modifiers().fFlags & Modifiers::kIn_Flag) {
            this->write(" = ");
            this->writeExpression(*args[i], Precedence::kAssignment);
        }

        this->writeLine(";");
    }

    // Call the function we're wrapping. If it has a return type, then store it so it can be
    // returned later.
    bool hasReturn = !c.type().isVoid();
    if (hasReturn) {
        this->write("var _return: ");
        this->write(to_wgsl_type(c.type()));
        this->write(" = ");
    }

    // Write the function call.
    this->writeName(func.mangledName());
    this->write("(");
    auto newSeparator = SkSL::String::Separator();
    if (this->writeFunctionDependencyArgs(func)) {
        newSeparator();
    }
    for (int i = 0; i < args.size(); ++i) {
        this->write(newSeparator());
        // All forwarded arguments now have a name of the form "_var<i>" (e.g. _var0, _var1,
        // etc.). Non-out arguments are the helper's by-value parameters. Out arguments are the
        // local temporaries declared above, whose address is taken again because the wrapped
        // function expects a pointer.
        if (outVars[i]) {
            this->write("&");
        }
        this->write("_var");
        this->write(std::to_string(i));
    }
    this->writeLine(");");

    // Copy the temporary variables back into the original out-parameters.
    for (int i = 0; i < outVars.size(); ++i) {
        if (!outVars[i]) {
            continue;
        }
        // TODO(skia:13092): WGSL does not support assigning to a swizzle
        // (see https://github.com/gpuweb/gpuweb/issues/737). These will require special
        // treatment when they appear on the lhs of an assignment.
        this->writeExpression(*args[i], Precedence::kAssignment);
        this->write(" = _var");
        this->write(std::to_string(i));
        this->writeLine(";");
    }

    // Return
    if (hasReturn) {
        this->writeLine("return _return;");
    }

    --fIndentation;
    this->writeLine("}");

    // Write the function out to `fExtraFunctions`.
    write_stringstream(tmpStream, fExtraFunctions);

    // Restore any global state
    fOutParamArgVars = std::move(parentOutParamArgVars);
    fAtLineStart = atLineStart;
    return name;
}
1660
1661 } // namespace SkSL
1662