• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2024 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
#include "compiler/translator/wgsl/TranslatorWGSL.h"

#include <cmath>
#include <iostream>
#include <limits>
#include <variant>

#include "GLSLANG/ShaderLang.h"
#include "common/log_utils.h"
#include "common/span.h"
#include "compiler/translator/BaseTypes.h"
#include "compiler/translator/Common.h"
#include "compiler/translator/Diagnostics.h"
#include "compiler/translator/ImmutableString.h"
#include "compiler/translator/ImmutableStringBuilder.h"
#include "compiler/translator/InfoSink.h"
#include "compiler/translator/IntermNode.h"
#include "compiler/translator/Operator_autogen.h"
#include "compiler/translator/OutputTree.h"
#include "compiler/translator/StaticType.h"
#include "compiler/translator/SymbolUniqueId.h"
#include "compiler/translator/Types.h"
#include "compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.h"
#include "compiler/translator/tree_ops/RewriteArrayOfArrayOfOpaqueUniforms.h"
#include "compiler/translator/tree_ops/RewriteStructSamplers.h"
#include "compiler/translator/tree_ops/SeparateStructFromUniformDeclarations.h"
#include "compiler/translator/tree_util/BuiltIn_autogen.h"
#include "compiler/translator/tree_util/FindMain.h"
#include "compiler/translator/tree_util/IntermNode_util.h"
#include "compiler/translator/tree_util/IntermTraverse.h"
#include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
#include "compiler/translator/wgsl/OutputUniformBlocks.h"
#include "compiler/translator/wgsl/RewritePipelineVariables.h"
#include "compiler/translator/wgsl/Utils.h"
39 
40 namespace sh
41 {
42 namespace
43 {
44 
// Compile-time debugging switches, both off by default. Judging by their names they gate dumping
// the AST before translation and dumping the generated WGSL; their uses are elsewhere in the file
// (NOTE(review): confirm against the translate entry point).
constexpr bool kOutputTreeBeforeTranslation{false};
constexpr bool kOutputTranslatedShader{false};
47 
48 struct VarDecl
49 {
50     const SymbolType symbolType = SymbolType::Empty;
51     const ImmutableString &symbolName;
52     const TType &type;
53 };
54 
IsDefaultUniform(const TType & type)55 bool IsDefaultUniform(const TType &type)
56 {
57     return type.getQualifier() == EvqUniform && type.getInterfaceBlock() == nullptr &&
58            !IsOpaqueType(type.getBasicType());
59 }
60 
61 // When emitting a list of statements, this determines whether a semicolon follows the statement.
RequiresSemicolonTerminator(TIntermNode & node)62 bool RequiresSemicolonTerminator(TIntermNode &node)
63 {
64     if (node.getAsBlock())
65     {
66         return false;
67     }
68     if (node.getAsLoopNode())
69     {
70         return false;
71     }
72     if (node.getAsSwitchNode())
73     {
74         return false;
75     }
76     if (node.getAsIfElseNode())
77     {
78         return false;
79     }
80     if (node.getAsFunctionDefinition())
81     {
82         return false;
83     }
84     if (node.getAsCaseNode())
85     {
86         return false;
87     }
88 
89     return true;
90 }
91 
92 // For pretty formatting of the resulting WGSL text.
NewlinePad(TIntermNode & node)93 bool NewlinePad(TIntermNode &node)
94 {
95     if (node.getAsFunctionDefinition())
96     {
97         return true;
98     }
99     if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
100     {
101         ASSERT(declNode->getChildCount() == 1);
102         TIntermNode &childNode = *declNode->getChildNode(0);
103         if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
104         {
105             const TVariable &var = symbolNode->variable();
106             return var.getType().isStructSpecifier();
107         }
108         return false;
109     }
110     return false;
111 }
112 
113 // A traverser that generates WGSL as it walks the AST.
114 class OutputWGSLTraverser : public TIntermTraverser
115 {
116   public:
117     OutputWGSLTraverser(TInfoSinkBase *sink,
118                         RewritePipelineVarOutput *rewritePipelineVarOutput,
119                         UniformBlockMetadata *uniformBlockMetadata,
120                         WGSLGenerationMetadataForUniforms *arrayElementTypesInUniforms);
121     ~OutputWGSLTraverser() override;
122 
123   protected:
124     void visitSymbol(TIntermSymbol *node) override;
125     void visitConstantUnion(TIntermConstantUnion *node) override;
126     bool visitSwizzle(Visit visit, TIntermSwizzle *node) override;
127     bool visitBinary(Visit visit, TIntermBinary *node) override;
128     bool visitUnary(Visit visit, TIntermUnary *node) override;
129     bool visitTernary(Visit visit, TIntermTernary *node) override;
130     bool visitIfElse(Visit visit, TIntermIfElse *node) override;
131     bool visitSwitch(Visit visit, TIntermSwitch *node) override;
132     bool visitCase(Visit visit, TIntermCase *node) override;
133     void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
134     bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
135     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
136     bool visitBlock(Visit visit, TIntermBlock *node) override;
137     bool visitGlobalQualifierDeclaration(Visit visit,
138                                          TIntermGlobalQualifierDeclaration *node) override;
139     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
140     bool visitLoop(Visit visit, TIntermLoop *node) override;
141     bool visitBranch(Visit visit, TIntermBranch *node) override;
142     void visitPreprocessorDirective(TIntermPreprocessorDirective *node) override;
143 
144   private:
145     struct EmitVariableDeclarationConfig
146     {
147         EmitTypeConfig typeConfig;
148         bool isParameter            = false;
149         bool disableStructSpecifier = false;
150         bool needsVar               = false;
151         bool isGlobalScope          = false;
152     };
153 
154     void groupedTraverse(TIntermNode &node);
155     void emitNameOf(const VarDecl &decl);
156     void emitBareTypeName(const TType &type);
157     void emitType(const TType &type);
158     void emitSingleConstant(const TConstantUnion *const constUnion);
159     const TConstantUnion *emitConstantUnionArray(const TConstantUnion *const constUnion,
160                                                  const size_t size);
161     const TConstantUnion *emitConstantUnion(const TType &type,
162                                             const TConstantUnion *constUnionBegin);
163     const TField &getDirectField(const TIntermTyped &fieldsNode, TIntermTyped &indexNode);
164     void emitIndentation();
165     void emitOpenBrace();
166     void emitCloseBrace();
167     bool emitBlock(angle::Span<TIntermNode *> nodes);
168     void emitFunctionSignature(const TFunction &func);
169     void emitFunctionReturn(const TFunction &func);
170     void emitFunctionParameter(const TFunction &func, const TVariable &param);
171     void emitStructDeclaration(const TType &type);
172     void emitVariableDeclaration(const VarDecl &decl,
173                                  const EmitVariableDeclarationConfig &evdConfig);
174     void emitArrayIndex(TIntermTyped &leftNode, TIntermTyped &rightNode);
175     void emitStructIndex(TIntermBinary *binaryNode);
176     void emitStructIndexNoUnwrapping(TIntermBinary *binaryNode);
177     void emitTextureBuiltin(const TOperator op, const TIntermSequence &args);
178 
179     bool emitForLoop(TIntermLoop *);
180     bool emitWhileLoop(TIntermLoop *);
181     bool emulateDoWhileLoop(TIntermLoop *);
182 
183     TInfoSinkBase &mSink;
184     const RewritePipelineVarOutput *mRewritePipelineVarOutput;
185     const UniformBlockMetadata *mUniformBlockMetadata;
186     WGSLGenerationMetadataForUniforms *mWGSLGenerationMetadataForUniforms;
187 
188     int mIndentLevel        = -1;
189     int mLastIndentationPos = -1;
190 };
191 
OutputWGSLTraverser(TInfoSinkBase * sink,RewritePipelineVarOutput * rewritePipelineVarOutput,UniformBlockMetadata * uniformBlockMetadata,WGSLGenerationMetadataForUniforms * wgslGenerationMetadataForUniforms)192 OutputWGSLTraverser::OutputWGSLTraverser(
193     TInfoSinkBase *sink,
194     RewritePipelineVarOutput *rewritePipelineVarOutput,
195     UniformBlockMetadata *uniformBlockMetadata,
196     WGSLGenerationMetadataForUniforms *wgslGenerationMetadataForUniforms)
197     : TIntermTraverser(true, false, false),
198       mSink(*sink),
199       mRewritePipelineVarOutput(rewritePipelineVarOutput),
200       mUniformBlockMetadata(uniformBlockMetadata),
201       mWGSLGenerationMetadataForUniforms(wgslGenerationMetadataForUniforms)
202 {}
203 
// Defaulted out of line. The traverser only holds a reference to the sink and non-owning pointers
// to metadata, so there is nothing extra to release.
OutputWGSLTraverser::~OutputWGSLTraverser() = default;
205 
groupedTraverse(TIntermNode & node)206 void OutputWGSLTraverser::groupedTraverse(TIntermNode &node)
207 {
208     // TODO(anglebug.com/42267100): to make generated code more readable, do not always
209     // emit parentheses like WGSL is some Lisp dialect.
210     const bool emitParens = true;
211 
212     if (emitParens)
213     {
214         mSink << "(";
215     }
216 
217     node.traverse(this);
218 
219     if (emitParens)
220     {
221         mSink << ")";
222     }
223 }
224 
emitNameOf(const VarDecl & decl)225 void OutputWGSLTraverser::emitNameOf(const VarDecl &decl)
226 {
227     WriteNameOf(mSink, decl.symbolType, decl.symbolName);
228 }
229 
emitIndentation()230 void OutputWGSLTraverser::emitIndentation()
231 {
232     ASSERT(mIndentLevel >= 0);
233 
234     if (mLastIndentationPos == mSink.size())
235     {
236         return;  // Line is already indented.
237     }
238 
239     for (int i = 0; i < mIndentLevel; ++i)
240     {
241         mSink << "  ";
242     }
243 
244     mLastIndentationPos = mSink.size();
245 }
246 
emitOpenBrace()247 void OutputWGSLTraverser::emitOpenBrace()
248 {
249     ASSERT(mIndentLevel >= 0);
250 
251     emitIndentation();
252     mSink << "{\n";
253     ++mIndentLevel;
254 }
255 
emitCloseBrace()256 void OutputWGSLTraverser::emitCloseBrace()
257 {
258     ASSERT(mIndentLevel >= 1);
259 
260     --mIndentLevel;
261     emitIndentation();
262     mSink << "}";
263 }
264 
visitSymbol(TIntermSymbol * symbolNode)265 void OutputWGSLTraverser::visitSymbol(TIntermSymbol *symbolNode)
266 {
267 
268     const TVariable &var = symbolNode->variable();
269     const TType &type    = var.getType();
270     ASSERT(var.symbolType() != SymbolType::Empty);
271 
272     if (type.getBasicType() == TBasicType::EbtVoid)
273     {
274         UNREACHABLE();
275     }
276     else
277     {
278         // Accesses of pipeline variables should be rewritten as struct accesses.
279         if (mRewritePipelineVarOutput->IsInputVar(var.uniqueId()))
280         {
281             mSink << kBuiltinInputStructName << "." << var.name();
282         }
283         else if (mRewritePipelineVarOutput->IsOutputVar(var.uniqueId()))
284         {
285             mSink << kBuiltinOutputStructName << "." << var.name();
286         }
287         // Accesses of basic uniforms need to be converted to struct accesses.
288         else if (IsDefaultUniform(type))
289         {
290             mSink << kDefaultUniformBlockVarName << "." << var.name();
291         }
292         else
293         {
294             WriteNameOf(mSink, var);
295         }
296 
297         if (var.symbolType() == SymbolType::BuiltIn)
298         {
299             ASSERT(mRewritePipelineVarOutput->IsInputVar(var.uniqueId()) ||
300                    mRewritePipelineVarOutput->IsOutputVar(var.uniqueId()) ||
301                    var.uniqueId() == BuiltInId::gl_DepthRange);
302             // TODO(anglebug.com/42267100): support gl_DepthRange.
303             // Match the name of the struct field in `mRewritePipelineVarOutput`.
304             mSink << "_";
305         }
306     }
307 }
308 
emitSingleConstant(const TConstantUnion * const constUnion)309 void OutputWGSLTraverser::emitSingleConstant(const TConstantUnion *const constUnion)
310 {
311     switch (constUnion->getType())
312     {
313         case TBasicType::EbtBool:
314         {
315             mSink << (constUnion->getBConst() ? "true" : "false");
316         }
317         break;
318 
319         case TBasicType::EbtFloat:
320         {
321             float value = constUnion->getFConst();
322             if (std::isnan(value))
323             {
324                 UNIMPLEMENTED();
325                 // TODO(anglebug.com/42267100): this is not a valid constant in WGPU.
326                 // You can't even do something like bitcast<f32>(0xffffffffu).
327                 // The WGSL compiler still complains. I think this is because
328                 // WGSL supports implementations compiling with -ffastmath and
329                 // therefore nans and infinities are assumed to not exist.
330                 // See also https://github.com/gpuweb/gpuweb/issues/3749.
331                 mSink << "NAN_INVALID";
332             }
333             else if (std::isinf(value))
334             {
335                 UNIMPLEMENTED();
336                 // see above.
337                 mSink << "INFINITY_INVALID";
338             }
339             else
340             {
341                 mSink << value << "f";
342             }
343         }
344         break;
345 
346         case TBasicType::EbtInt:
347         {
348             mSink << constUnion->getIConst() << "i";
349         }
350         break;
351 
352         case TBasicType::EbtUInt:
353         {
354             mSink << constUnion->getUConst() << "u";
355         }
356         break;
357 
358         default:
359         {
360             UNIMPLEMENTED();
361         }
362     }
363 }
364 
// Emits `size` consecutive scalar constants starting at `constUnion`, separated by ", ".
// Returns a pointer just past the last constant consumed.
const TConstantUnion *OutputWGSLTraverser::emitConstantUnionArray(
    const TConstantUnion *const constUnion,
    const size_t size)
{
    const TConstantUnion *const end = constUnion + size;
    const TConstantUnion *cursor    = constUnion;
    for (; cursor != end; ++cursor)
    {
        if (cursor != constUnion)
        {
            mSink << ", ";
        }
        emitSingleConstant(cursor);
    }
    return cursor;
}
381 
// Emits a constant value of type `type` whose scalar components start at `constUnionBegin`, and
// returns a pointer just past the last component consumed. Arrays and structs are emitted as
// constructor calls whose arguments are emitted recursively with their own types. Vectors and
// matrices are emitted as a type constructor over their scalars. Plain scalars are emitted as a
// bare literal.
const TConstantUnion *OutputWGSLTraverser::emitConstantUnion(const TType &type,
                                                             const TConstantUnion *constUnionBegin)
{
    const TConstantUnion *constUnionCurr = constUnionBegin;

    if (type.isArray())
    {
        // WGSL array constructors take exactly one argument per element, each of the element
        // type: `array<T, N>(e0, e1, ...)`. Emit each element recursively with the element type
        // so that arrays of vectors, matrices, structs and nested arrays get correctly typed
        // arguments. This also keeps a one-element array wrapped in its constructor.
        emitType(type);
        mSink << "(";
        TType elementType(type);
        elementType.toArrayElementType();
        const unsigned int arraySize = type.getOutermostArraySize();
        for (unsigned int i = 0; i < arraySize; ++i)
        {
            if (i != 0)
            {
                mSink << ", ";
            }
            constUnionCurr = emitConstantUnion(elementType, constUnionCurr);
        }
        mSink << ")";
        return constUnionCurr;
    }

    if (const TStructure *structure = type.getStruct())
    {
        emitType(type);
        // Structs are constructed with parentheses in WGSL.
        mSink << "(";
        // Emit the constructor parameters. Both GLSL and WGSL require there to be the same number
        // of parameters as struct fields.
        const TFieldList &fields = structure->fields();
        for (size_t i = 0; i < fields.size(); ++i)
        {
            if (i != 0)
            {
                mSink << ", ";
            }
            const TType *fieldType = fields[i]->type();
            constUnionCurr         = emitConstantUnion(*fieldType, constUnionCurr);
        }
        mSink << ")";
        return constUnionCurr;
    }

    // Scalars, vectors and matrices. If the size is more than 1 (vectors and matrices), the value
    // is written as a type constructor with parentheses around its scalar components.
    const size_t size    = type.getObjectSize();
    const bool writeType = size > 1;
    if (writeType)
    {
        emitType(type);
        mSink << "(";
    }
    constUnionCurr = emitConstantUnionArray(constUnionCurr, size);
    if (writeType)
    {
        mSink << ")";
    }
    return constUnionCurr;
}
425 
visitConstantUnion(TIntermConstantUnion * constValueNode)426 void OutputWGSLTraverser::visitConstantUnion(TIntermConstantUnion *constValueNode)
427 {
428     emitConstantUnion(constValueNode->getType(), constValueNode->getConstantValue());
429 }
430 
visitSwizzle(Visit,TIntermSwizzle * swizzleNode)431 bool OutputWGSLTraverser::visitSwizzle(Visit, TIntermSwizzle *swizzleNode)
432 {
433     groupedTraverse(*swizzleNode->getOperand());
434     mSink << "." << swizzleNode->getOffsetsAsXYZW();
435 
436     return false;
437 }
438 
GetOperatorString(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1,const TType * argType2)439 const char *GetOperatorString(TOperator op,
440                               const TType &resultType,
441                               const TType *argType0,
442                               const TType *argType1,
443                               const TType *argType2)
444 {
445     switch (op)
446     {
447         case TOperator::EOpComma:
448             // WGSL does not have a comma operator or any other way to implement "statement list as
449             // an expression", so nested expressions will have to be pulled out into statements.
450             UNIMPLEMENTED();
451             return "TODO_operator";
452         case TOperator::EOpAssign:
453             return "=";
454         case TOperator::EOpInitialize:
455             return "=";
456         // Compound assignments now exist: https://www.w3.org/TR/WGSL/#compound-assignment-sec
457         case TOperator::EOpAddAssign:
458             return "+=";
459         case TOperator::EOpSubAssign:
460             return "-=";
461         case TOperator::EOpMulAssign:
462             return "*=";
463         case TOperator::EOpDivAssign:
464             return "/=";
465         case TOperator::EOpIModAssign:
466             return "%=";
467         case TOperator::EOpBitShiftLeftAssign:
468             return "<<=";
469         case TOperator::EOpBitShiftRightAssign:
470             return ">>=";
471         case TOperator::EOpBitwiseAndAssign:
472             return "&=";
473         case TOperator::EOpBitwiseXorAssign:
474             return "^=";
475         case TOperator::EOpBitwiseOrAssign:
476             return "|=";
477         case TOperator::EOpAdd:
478             return "+";
479         case TOperator::EOpSub:
480             return "-";
481         case TOperator::EOpMul:
482             return "*";
483         case TOperator::EOpDiv:
484             return "/";
485         // TODO(anglebug.com/42267100): Works different from GLSL for negative numbers.
486         // https://github.com/gpuweb/gpuweb/discussions/2204#:~:text=not%20WGSL%3B%20etc.-,Inconsistent%20mod/%25%20operator,-At%20first%20glance
487         // GLSL does `x - y * floor(x/y)`, WGSL does x - y * trunc(x/y).
488         case TOperator::EOpIMod:
489         case TOperator::EOpMod:
490             return "%";
491         case TOperator::EOpBitShiftLeft:
492             return "<<";
493         case TOperator::EOpBitShiftRight:
494             return ">>";
495         case TOperator::EOpBitwiseAnd:
496             return "&";
497         case TOperator::EOpBitwiseXor:
498             return "^";
499         case TOperator::EOpBitwiseOr:
500             return "|";
501         case TOperator::EOpLessThan:
502             return "<";
503         case TOperator::EOpGreaterThan:
504             return ">";
505         case TOperator::EOpLessThanEqual:
506             return "<=";
507         case TOperator::EOpGreaterThanEqual:
508             return ">=";
509         // Component-wise comparisons are done with regular infix operators in WGSL:
510         // https://www.w3.org/TR/WGSL/#comparison-expr
511         case TOperator::EOpLessThanComponentWise:
512             return "<";
513         case TOperator::EOpLessThanEqualComponentWise:
514             return "<=";
515         case TOperator::EOpGreaterThanEqualComponentWise:
516             return ">=";
517         case TOperator::EOpGreaterThanComponentWise:
518             return ">";
519         case TOperator::EOpLogicalOr:
520             return "||";
521         // Logical XOR is only applied to boolean expressions so it's the same as "not equals".
522         // Neither short-circuits.
523         case TOperator::EOpLogicalXor:
524             return "!=";
525         case TOperator::EOpLogicalAnd:
526             return "&&";
527         case TOperator::EOpNegative:
528             return "-";
529         case TOperator::EOpPositive:
530             if (argType0->isMatrix())
531             {
532                 return "";
533             }
534             return "+";
535         case TOperator::EOpLogicalNot:
536             return "!";
537         // Component-wise not done with normal prefix unary operator in WGSL:
538         // https://www.w3.org/TR/WGSL/#logical-expr
539         case TOperator::EOpNotComponentWise:
540             return "!";
541         case TOperator::EOpBitwiseNot:
542             return "~";
543         // TODO(anglebug.com/42267100): increment operations cannot be used as expressions in WGSL.
544         case TOperator::EOpPostIncrement:
545             return "++";
546         case TOperator::EOpPostDecrement:
547             return "--";
548         case TOperator::EOpPreIncrement:
549         case TOperator::EOpPreDecrement:
550             // TODO(anglebug.com/42267100): pre increments and decrements do not exist in WGSL.
551             UNIMPLEMENTED();
552             return "TODO_operator";
553         case TOperator::EOpVectorTimesScalarAssign:
554             return "*=";
555         case TOperator::EOpVectorTimesMatrixAssign:
556             return "*=";
557         case TOperator::EOpMatrixTimesScalarAssign:
558             return "*=";
559         case TOperator::EOpMatrixTimesMatrixAssign:
560             return "*=";
561         case TOperator::EOpVectorTimesScalar:
562             return "*";
563         case TOperator::EOpVectorTimesMatrix:
564             return "*";
565         case TOperator::EOpMatrixTimesVector:
566             return "*";
567         case TOperator::EOpMatrixTimesScalar:
568             return "*";
569         case TOperator::EOpMatrixTimesMatrix:
570             return "*";
571         case TOperator::EOpEqualComponentWise:
572             return "==";
573         case TOperator::EOpNotEqualComponentWise:
574             return "!=";
575 
576         // TODO(anglebug.com/42267100): structs, matrices, and arrays are not comparable with WGSL's
577         // == or !=. Comparing vectors results in a component-wise comparison returning a boolean
578         // vector, which is different from GLSL (which use equal(vec, vec) for component-wise
579         // comparison)
580         case TOperator::EOpEqual:
581             if ((argType0->isVector() && argType1->isVector()) ||
582                 (argType0->getStruct() && argType1->getStruct()) ||
583                 (argType0->isArray() && argType1->isArray()) ||
584                 (argType0->isMatrix() && argType1->isMatrix()))
585 
586             {
587                 UNIMPLEMENTED();
588                 return "TODO_operator";
589             }
590 
591             return "==";
592 
593         case TOperator::EOpNotEqual:
594             if ((argType0->isVector() && argType1->isVector()) ||
595                 (argType0->getStruct() && argType1->getStruct()) ||
596                 (argType0->isArray() && argType1->isArray()) ||
597                 (argType0->isMatrix() && argType1->isMatrix()))
598             {
599                 UNIMPLEMENTED();
600                 return "TODO_operator";
601             }
602             return "!=";
603 
604         case TOperator::EOpKill:
605         case TOperator::EOpReturn:
606         case TOperator::EOpBreak:
607         case TOperator::EOpContinue:
608             // These should all be emitted in visitBranch().
609             UNREACHABLE();
610             return "UNREACHABLE_operator";
611         case TOperator::EOpRadians:
612             return "radians";
613         case TOperator::EOpDegrees:
614             return "degrees";
615         case TOperator::EOpAtan:
616             return argType1 == nullptr ? "atan" : "atan2";
617         case TOperator::EOpRefract:
618             return argType0->isVector() ? "refract" : "TODO_operator";
619         case TOperator::EOpDistance:
620             return "distance";
621         case TOperator::EOpLength:
622             return "length";
623         case TOperator::EOpDot:
624             return argType0->isVector() ? "dot" : "*";
625         case TOperator::EOpNormalize:
626             return argType0->isVector() ? "normalize" : "sign";
627         case TOperator::EOpFaceforward:
628             return argType0->isVector() ? "faceForward" : "TODO_Operator";
629         case TOperator::EOpReflect:
630             return argType0->isVector() ? "reflect" : "TODO_Operator";
631         case TOperator::EOpMatrixCompMult:
632             return "TODO_Operator";
633         case TOperator::EOpOuterProduct:
634             return "TODO_Operator";
635         case TOperator::EOpSign:
636             return "sign";
637 
638         case TOperator::EOpAbs:
639             return "abs";
640         case TOperator::EOpAll:
641             return "all";
642         case TOperator::EOpAny:
643             return "any";
644         case TOperator::EOpSin:
645             return "sin";
646         case TOperator::EOpCos:
647             return "cos";
648         case TOperator::EOpTan:
649             return "tan";
650         case TOperator::EOpAsin:
651             return "asin";
652         case TOperator::EOpAcos:
653             return "acos";
654         case TOperator::EOpSinh:
655             return "sinh";
656         case TOperator::EOpCosh:
657             return "cosh";
658         case TOperator::EOpTanh:
659             return "tanh";
660         case TOperator::EOpAsinh:
661             return "asinh";
662         case TOperator::EOpAcosh:
663             return "acosh";
664         case TOperator::EOpAtanh:
665             return "atanh";
666         case TOperator::EOpFma:
667             return "fma";
668         // TODO(anglebug.com/42267100): Won't accept pow(vec<f32>, f32).
669         // https://github.com/gpuweb/gpuweb/discussions/2204#:~:text=Similarly%20pow(vec3%3Cf32%3E%2C%20f32)%20works%20in%20GLSL%20but%20not%20WGSL
670         case TOperator::EOpPow:
671             return "pow";  // GLSL's pow excludes negative x
672         case TOperator::EOpExp:
673             return "exp";
674         case TOperator::EOpExp2:
675             return "exp2";
676         case TOperator::EOpLog:
677             return "log";
678         case TOperator::EOpLog2:
679             return "log2";
680         case TOperator::EOpSqrt:
681             return "sqrt";
682         case TOperator::EOpFloor:
683             return "floor";
684         case TOperator::EOpTrunc:
685             return "trunc";
686         case TOperator::EOpCeil:
687             return "ceil";
688         case TOperator::EOpFract:
689             return "fract";
690         case TOperator::EOpMin:
691             return "min";
692         case TOperator::EOpMax:
693             return "max";
694         case TOperator::EOpRound:
695             return "round";  // TODO(anglebug.com/42267100): this is wrong and must round away from
696                              // zero if there is a tie. This always rounds to the even number.
697         case TOperator::EOpRoundEven:
698             return "round";
699         // TODO(anglebug.com/42267100):
700         // https://github.com/gpuweb/gpuweb/discussions/2204#:~:text=clamp(vec2%3Cf32%3E%2C%20f32%2C%20f32)%20works%20in%20GLSL%20but%20not%20WGSL%3B%20etc.
701         // Need to expand clamp(vec<f32>, low : f32, high : f32) ->
702         // clamp(vec<f32>, vec<f32>(low), vec<f32>(high))
703         case TOperator::EOpClamp:
704             return "clamp";
705         case TOperator::EOpSaturate:
706             return "saturate";
707         case TOperator::EOpMix:
708             if (!argType1->isScalar() && argType2 && argType2->getBasicType() == EbtBool)
709             {
710                 return "TODO_Operator";
711             }
712             return "mix";
713         case TOperator::EOpStep:
714             return "step";
715         case TOperator::EOpSmoothstep:
716             return "smoothstep";
717         case TOperator::EOpModf:
718             UNIMPLEMENTED();  // TODO(anglebug.com/42267100): in WGSL this returns a struct, GLSL it
719                               // uses a return value and an outparam
720             return "modf";
721         case TOperator::EOpIsnan:
722         case TOperator::EOpIsinf:
723             UNIMPLEMENTED();  // TODO(anglebug.com/42267100): WGSL does not allow NaNs or infinity.
724                               // What to do about shaders that require this?
725             // Implementations are allowed to assume overflow, infinities, and NaNs are not present
726             // at runtime, however. https://www.w3.org/TR/WGSL/#floating-point-evaluation
727             return "TODO_Operator";
728         case TOperator::EOpLdexp:
729             // TODO(anglebug.com/42267100): won't accept first arg vector, second arg scalar
730             return "ldexp";
731         case TOperator::EOpFrexp:
732             return "frexp";  // TODO(anglebug.com/42267100): returns a struct
733         case TOperator::EOpInversesqrt:
734             return "inverseSqrt";
735         case TOperator::EOpCross:
736             return "cross";
737             // TODO(anglebug.com/42267100): are these the same? dpdxCoarse() vs dpdxFine()?
738         case TOperator::EOpDFdx:
739             return "dpdx";
740         case TOperator::EOpDFdy:
741             return "dpdy";
742         case TOperator::EOpFwidth:
743             return "fwidth";
744         case TOperator::EOpTranspose:
745             return "transpose";
746         case TOperator::EOpDeterminant:
747             return "determinant";
748 
749         case TOperator::EOpInverse:
750             return "TODO_Operator";  // No builtin invert().
751                                      // https://github.com/gpuweb/gpuweb/issues/4115
752 
753         // TODO(anglebug.com/42267100): these interpolateAt*() are not builtin
754         case TOperator::EOpInterpolateAtCentroid:
755             return "TODO_Operator";
756         case TOperator::EOpInterpolateAtSample:
757             return "TODO_Operator";
758         case TOperator::EOpInterpolateAtOffset:
759             return "TODO_Operator";
760         case TOperator::EOpInterpolateAtCenter:
761             return "TODO_Operator";
762 
763         case TOperator::EOpFloatBitsToInt:
764         case TOperator::EOpFloatBitsToUint:
765         case TOperator::EOpIntBitsToFloat:
766         case TOperator::EOpUintBitsToFloat:
767         {
768 #define BITCAST_SCALAR()                   \
769     do                                     \
770         switch (resultType.getBasicType()) \
771         {                                  \
772             case TBasicType::EbtInt:       \
773                 return "bitcast<i32>";     \
774             case TBasicType::EbtUInt:      \
775                 return "bitcast<u32>";     \
776             case TBasicType::EbtFloat:     \
777                 return "bitcast<f32>";     \
778             default:                       \
779                 UNIMPLEMENTED();           \
780                 return "TOperator_TODO";   \
781         }                                  \
782     while (false)
783 
784 #define BITCAST_VECTOR(vecSize)                        \
785     do                                                 \
786         switch (resultType.getBasicType())             \
787         {                                              \
788             case TBasicType::EbtInt:                   \
789                 return "bitcast<vec" vecSize "<i32>>"; \
790             case TBasicType::EbtUInt:                  \
791                 return "bitcast<vec" vecSize "<u32>>"; \
792             case TBasicType::EbtFloat:                 \
793                 return "bitcast<vec" vecSize "<f32>>"; \
794             default:                                   \
795                 UNIMPLEMENTED();                       \
796                 return "TOperator_TODO";               \
797         }                                              \
798     while (false)
799 
800             if (resultType.isScalar())
801             {
802                 BITCAST_SCALAR();
803             }
804             else if (resultType.isVector())
805             {
806                 switch (resultType.getNominalSize())
807                 {
808                     case 2:
809                         BITCAST_VECTOR("2");
810                     case 3:
811                         BITCAST_VECTOR("3");
812                     case 4:
813                         BITCAST_VECTOR("4");
814                     default:
815                         UNREACHABLE();
816                         return nullptr;
817                 }
818             }
819             else
820             {
821                 UNIMPLEMENTED();
822                 return "TOperator_TODO";
823             }
824 
825 #undef BITCAST_SCALAR
826 #undef BITCAST_VECTOR
827         }
828 
829         case TOperator::EOpPackUnorm2x16:
830             return "pack2x16unorm";
831         case TOperator::EOpPackSnorm2x16:
832             return "pack2x16snorm";
833 
834         case TOperator::EOpPackUnorm4x8:
835             return "pack4x8unorm";
836         case TOperator::EOpPackSnorm4x8:
837             return "pack4x8snorm";
838 
839         case TOperator::EOpUnpackUnorm2x16:
840             return "unpack2x16unorm";
841         case TOperator::EOpUnpackSnorm2x16:
842             return "unpack2x16snorm";
843 
844         case TOperator::EOpUnpackUnorm4x8:
845             return "unpack4x8unorm";
846         case TOperator::EOpUnpackSnorm4x8:
847             return "unpack4x8snorm";
848 
849         case TOperator::EOpPackHalf2x16:
850             return "pack2x16float";
851         case TOperator::EOpUnpackHalf2x16:
852             return "unpack2x16float";
853 
854         case TOperator::EOpBarrier:
855             UNREACHABLE();
856             return "TOperator_TODO";
857         case TOperator::EOpMemoryBarrier:
858             // TODO(anglebug.com/42267100): does this exist in WGPU? Device-scoped memory barrier?
859             // Maybe storageBarrier()?
860             UNREACHABLE();
861             return "TOperator_TODO";
862         case TOperator::EOpGroupMemoryBarrier:
863             return "workgroupBarrier";
864         case TOperator::EOpMemoryBarrierAtomicCounter:
865         case TOperator::EOpMemoryBarrierBuffer:
866         case TOperator::EOpMemoryBarrierShared:
867             UNREACHABLE();
868             return "TOperator_TODO";
869         case TOperator::EOpAtomicAdd:
870             return "atomicAdd";
871         case TOperator::EOpAtomicMin:
872             return "atomicMin";
873         case TOperator::EOpAtomicMax:
874             return "atomicMax";
875         case TOperator::EOpAtomicAnd:
876             return "atomicAnd";
877         case TOperator::EOpAtomicOr:
878             return "atomicOr";
879         case TOperator::EOpAtomicXor:
880             return "atomicXor";
881         case TOperator::EOpAtomicExchange:
882             return "atomicExchange";
883         case TOperator::EOpAtomicCompSwap:
884             return "atomicCompareExchangeWeak";  // TODO(anglebug.com/42267100): returns a struct.
885         case TOperator::EOpBitfieldExtract:
886         case TOperator::EOpBitfieldInsert:
887         case TOperator::EOpBitfieldReverse:
888         case TOperator::EOpBitCount:
889         case TOperator::EOpFindLSB:
890         case TOperator::EOpFindMSB:
891         case TOperator::EOpUaddCarry:
892         case TOperator::EOpUsubBorrow:
893         case TOperator::EOpUmulExtended:
894         case TOperator::EOpImulExtended:
895         case TOperator::EOpEmitVertex:
896         case TOperator::EOpEndPrimitive:
897         case TOperator::EOpArrayLength:
898             UNIMPLEMENTED();
899             return "TOperator_TODO";
900 
901         case TOperator::EOpNull:
902         case TOperator::EOpConstruct:
903         case TOperator::EOpCallFunctionInAST:
904         case TOperator::EOpCallInternalRawFunction:
905         case TOperator::EOpIndexDirect:
906         case TOperator::EOpIndexIndirect:
907         case TOperator::EOpIndexDirectStruct:
908         case TOperator::EOpIndexDirectInterfaceBlock:
909             UNREACHABLE();
910             return nullptr;
911         default:
912             // Any other built-in function.
913             return nullptr;
914     }
915 }
916 
IsSymbolicOperator(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1)917 bool IsSymbolicOperator(TOperator op,
918                         const TType &resultType,
919                         const TType *argType0,
920                         const TType *argType1)
921 {
922     const char *operatorString = GetOperatorString(op, resultType, argType0, argType1, nullptr);
923     if (operatorString == nullptr)
924     {
925         return false;
926     }
927     return !std::isalnum(operatorString[0]);
928 }
929 
getDirectField(const TIntermTyped & fieldsNode,TIntermTyped & indexNode)930 const TField &OutputWGSLTraverser::getDirectField(const TIntermTyped &fieldsNode,
931                                                   TIntermTyped &indexNode)
932 {
933     const TType &fieldsType = fieldsNode.getType();
934 
935     const TFieldListCollection *fieldListCollection = fieldsType.getStruct();
936     if (fieldListCollection == nullptr)
937     {
938         fieldListCollection = fieldsType.getInterfaceBlock();
939     }
940     ASSERT(fieldListCollection);
941 
942     const TIntermConstantUnion *indexNodeAsConstantUnion = indexNode.getAsConstantUnion();
943     ASSERT(indexNodeAsConstantUnion);
944     const TConstantUnion &index = *indexNodeAsConstantUnion->getConstantValue();
945 
946     ASSERT(index.getType() == TBasicType::EbtInt);
947 
948     const TFieldList &fieldList = fieldListCollection->fields();
949     const int indexVal          = index.getIConst();
950     const TField &field         = *fieldList[indexVal];
951 
952     return field;
953 }
954 
emitArrayIndex(TIntermTyped & leftNode,TIntermTyped & rightNode)955 void OutputWGSLTraverser::emitArrayIndex(TIntermTyped &leftNode, TIntermTyped &rightNode)
956 {
957     TType leftType = leftNode.getType();
958 
959     // Some arrays within the uniform address space have their element types wrapped in a struct
960     // when generating WGSL, so this unwraps the element (as an optimization of converting the
961     // entire array back to the unwrapped type).
962     bool needsUnwrapping                  = false;
963     bool isUniformMatrixNeedingConversion = false;
964     TIntermBinary *leftNodeBinary         = leftNode.getAsBinaryNode();
965     if (leftNodeBinary && leftNodeBinary->getOp() == TOperator::EOpIndexDirectStruct)
966     {
967         const TStructure *structure = leftNodeBinary->getLeft()->getType().getStruct();
968 
969         bool isInUniformAddressSpace =
970             mUniformBlockMetadata->structsInUniformAddressSpace.count(structure->uniqueId().get());
971 
972         needsUnwrapping =
973             structure && ElementTypeNeedsUniformWrapperStruct(isInUniformAddressSpace, &leftType);
974 
975         isUniformMatrixNeedingConversion = isInUniformAddressSpace && IsMatCx2(&leftType);
976 
977         ASSERT(!needsUnwrapping || !isUniformMatrixNeedingConversion);
978     }
979 
980     // Emit the left side, which should be of type array.
981     if (needsUnwrapping || isUniformMatrixNeedingConversion)
982     {
983         if (isUniformMatrixNeedingConversion)
984         {
985             // If this array index expression is yielding an std140 matCx2 (i.e.
986             // array<ANGLE_wrapped_vec2, C>), just convert the entire expression to a WGSL matCx2,
987             // instead of converting the entire array of std140 matCx2s into an array of WGSL
988             // matCx2s and then indexing into it.
989             TType baseType = leftType;
990             baseType.toArrayBaseType();
991             mSink << MakeMatCx2ConversionFunctionName(&baseType) << "(";
992             // Make sure the conversion function referenced here is actually generated in the
993             // resulting WGSL.
994             mWGSLGenerationMetadataForUniforms->outputMatCx2Conversion.insert(baseType);
995         }
996         emitStructIndexNoUnwrapping(leftNodeBinary);
997     }
998     else
999     {
1000         groupedTraverse(leftNode);
1001     }
1002 
1003     mSink << "[";
1004     const TConstantUnion *constIndex = rightNode.getConstantValue();
1005     // If the array index is a constant that we can statically verify is within array
1006     // bounds, just emit that constant.
1007     if (!leftType.isUnsizedArray() && constIndex != nullptr && constIndex->getType() == EbtInt &&
1008         constIndex->getIConst() >= 0 &&
1009         constIndex->getIConst() < static_cast<int>(leftType.isArray()
1010                                                        ? leftType.getOutermostArraySize()
1011                                                        : leftType.getNominalSize()))
1012     {
1013         emitSingleConstant(constIndex);
1014     }
1015     else
1016     {
1017         // If the array index is not a constant within the bounds of the array, clamp the
1018         // index.
1019         mSink << "clamp(";
1020         groupedTraverse(rightNode);
1021         mSink << ", 0, ";
1022         // Now find the array size and clamp it.
1023         if (leftType.isUnsizedArray())
1024         {
1025             // TODO(anglebug.com/42267100): This is a bug to traverse the `leftNode` a
1026             // second time if `leftNode` has side effects (and could also have performance
1027             // implications). This should be stored in a temporary variable. This might also
1028             // be a bug in the MSL shader compiler.
1029             mSink << "arrayLength(&";
1030             groupedTraverse(leftNode);
1031             mSink << ")";
1032         }
1033         else
1034         {
1035             uint32_t maxSize;
1036             if (leftType.isArray())
1037             {
1038                 maxSize = leftType.getOutermostArraySize() - 1;
1039             }
1040             else
1041             {
1042                 maxSize = leftType.getNominalSize() - 1;
1043             }
1044             mSink << maxSize;
1045         }
1046         // End the clamp() function.
1047         mSink << ")";
1048     }
1049     // End the array index operation.
1050     mSink << "]";
1051 
1052     if (needsUnwrapping)
1053     {
1054         mSink << "." << kWrappedStructFieldName;
1055     }
1056     else if (isUniformMatrixNeedingConversion)
1057     {
1058         // Close conversion function call
1059         mSink << ")";
1060     }
1061 }
1062 
emitStructIndex(TIntermBinary * binaryNode)1063 void OutputWGSLTraverser::emitStructIndex(TIntermBinary *binaryNode)
1064 {
1065     ASSERT(binaryNode->getOp() == TOperator::EOpIndexDirectStruct);
1066     TIntermTyped &leftNode  = *binaryNode->getLeft();
1067     const TType *binaryNodeType = &binaryNode->getType();
1068 
1069     const TStructure *structure = leftNode.getType().getStruct();
1070     ASSERT(structure);
1071 
1072     bool isInUniformAddressSpace =
1073         mUniformBlockMetadata->structsInUniformAddressSpace.count(structure->uniqueId().get());
1074 
1075     bool isUniformMatrixNeedingConversion = isInUniformAddressSpace && IsMatCx2(binaryNodeType);
1076 
1077     bool needsUnwrapping =
1078         ElementTypeNeedsUniformWrapperStruct(isInUniformAddressSpace, binaryNodeType);
1079     if (needsUnwrapping)
1080     {
1081         ASSERT(!isUniformMatrixNeedingConversion);
1082 
1083         mSink << MakeUnwrappingArrayConversionFunctionName(&binaryNode->getType()) << "(";
1084         // Make sure the conversion function referenced here is actually generated in the resulting
1085         // WGSL.
1086         mWGSLGenerationMetadataForUniforms->arrayElementTypesThatNeedUnwrappingConversions.insert(
1087             *binaryNodeType);
1088     }
1089     else if (isUniformMatrixNeedingConversion)
1090     {
1091         mSink << MakeMatCx2ConversionFunctionName(binaryNodeType) << "(";
1092         // Make sure the conversion function referenced here is actually generated in the resulting
1093         // WGSL.
1094         mWGSLGenerationMetadataForUniforms->outputMatCx2Conversion.insert(*binaryNodeType);
1095     }
1096     emitStructIndexNoUnwrapping(binaryNode);
1097     if (needsUnwrapping || isUniformMatrixNeedingConversion)
1098     {
1099         mSink << ")";
1100     }
1101 }
1102 
emitStructIndexNoUnwrapping(TIntermBinary * binaryNode)1103 void OutputWGSLTraverser::emitStructIndexNoUnwrapping(TIntermBinary *binaryNode)
1104 {
1105     ASSERT(binaryNode->getOp() == TOperator::EOpIndexDirectStruct);
1106     TIntermTyped &leftNode  = *binaryNode->getLeft();
1107     TIntermTyped &rightNode = *binaryNode->getRight();
1108 
1109     groupedTraverse(leftNode);
1110     mSink << ".";
1111     WriteNameOf(mSink, getDirectField(leftNode, rightNode));
1112 }
1113 
visitBinary(Visit,TIntermBinary * binaryNode)1114 bool OutputWGSLTraverser::visitBinary(Visit, TIntermBinary *binaryNode)
1115 {
1116     const TOperator op      = binaryNode->getOp();
1117     TIntermTyped &leftNode  = *binaryNode->getLeft();
1118     TIntermTyped &rightNode = *binaryNode->getRight();
1119 
1120     switch (op)
1121     {
1122         case TOperator::EOpIndexDirectStruct:
1123         case TOperator::EOpIndexDirectInterfaceBlock:
1124             emitStructIndex(binaryNode);
1125             break;
1126 
1127         case TOperator::EOpIndexDirect:
1128         case TOperator::EOpIndexIndirect:
1129             emitArrayIndex(leftNode, rightNode);
1130             break;
1131 
1132         default:
1133         {
1134             const TType &resultType = binaryNode->getType();
1135             const TType &leftType   = leftNode.getType();
1136             const TType &rightType  = rightNode.getType();
1137 
1138             // x * y, x ^ y, etc.
1139             if (IsSymbolicOperator(op, resultType, &leftType, &rightType))
1140             {
1141                 groupedTraverse(leftNode);
1142                 if (op != TOperator::EOpComma)
1143                 {
1144                     mSink << " ";
1145                 }
1146                 mSink << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << " ";
1147                 groupedTraverse(rightNode);
1148             }
1149             // E.g. builtin function calls
1150             else
1151             {
1152                 mSink << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << "(";
1153                 leftNode.traverse(this);
1154                 mSink << ", ";
1155                 rightNode.traverse(this);
1156                 mSink << ")";
1157             }
1158         }
1159     }
1160 
1161     return false;
1162 }
1163 
IsPostfix(TOperator op)1164 bool IsPostfix(TOperator op)
1165 {
1166     switch (op)
1167     {
1168         case TOperator::EOpPostIncrement:
1169         case TOperator::EOpPostDecrement:
1170             return true;
1171 
1172         default:
1173             return false;
1174     }
1175 }
1176 
visitUnary(Visit,TIntermUnary * unaryNode)1177 bool OutputWGSLTraverser::visitUnary(Visit, TIntermUnary *unaryNode)
1178 {
1179     const TOperator op      = unaryNode->getOp();
1180     const TType &resultType = unaryNode->getType();
1181 
1182     TIntermTyped &arg    = *unaryNode->getOperand();
1183     const TType &argType = arg.getType();
1184 
1185     const char *name = GetOperatorString(op, resultType, &argType, nullptr, nullptr);
1186 
1187     // Examples: -x, ~x, ~x
1188     if (IsSymbolicOperator(op, resultType, &argType, nullptr))
1189     {
1190         const bool postfix = IsPostfix(op);
1191         if (!postfix)
1192         {
1193             mSink << name;
1194         }
1195         groupedTraverse(arg);
1196         if (postfix)
1197         {
1198             mSink << name;
1199         }
1200     }
1201     else
1202     {
1203         mSink << name << "(";
1204         arg.traverse(this);
1205         mSink << ")";
1206     }
1207 
1208     return false;
1209 }
1210 
visitTernary(Visit,TIntermTernary * conditionalNode)1211 bool OutputWGSLTraverser::visitTernary(Visit, TIntermTernary *conditionalNode)
1212 {
1213     // WGSL does not have a ternary. https://github.com/gpuweb/gpuweb/issues/3747
1214     // The select() builtin is not short circuiting. Maybe we can get if () {} else {} as an
1215     // expression, which would also solve the comma operator problem.
1216     // TODO(anglebug.com/42267100): as mentioned above this is not correct if the operands have side
1217     // effects. Even if they don't have side effects it could have performance implications.
1218     // It also doesn't work with all types that ternaries do, e.g. arrays or structs.
1219     mSink << "select(";
1220     groupedTraverse(*conditionalNode->getFalseExpression());
1221     mSink << ", ";
1222     groupedTraverse(*conditionalNode->getTrueExpression());
1223     mSink << ", ";
1224     groupedTraverse(*conditionalNode->getCondition());
1225     mSink << ")";
1226 
1227     return false;
1228 }
1229 
visitIfElse(Visit,TIntermIfElse * ifThenElseNode)1230 bool OutputWGSLTraverser::visitIfElse(Visit, TIntermIfElse *ifThenElseNode)
1231 {
1232     TIntermTyped &condNode = *ifThenElseNode->getCondition();
1233     TIntermBlock *thenNode = ifThenElseNode->getTrueBlock();
1234     TIntermBlock *elseNode = ifThenElseNode->getFalseBlock();
1235 
1236     mSink << "if (";
1237     condNode.traverse(this);
1238     mSink << ")";
1239 
1240     if (thenNode)
1241     {
1242         mSink << "\n";
1243         thenNode->traverse(this);
1244     }
1245     else
1246     {
1247         mSink << " {}";
1248     }
1249 
1250     if (elseNode)
1251     {
1252         mSink << "\n";
1253         emitIndentation();
1254         mSink << "else\n";
1255         elseNode->traverse(this);
1256     }
1257 
1258     return false;
1259 }
1260 
visitSwitch(Visit,TIntermSwitch * switchNode)1261 bool OutputWGSLTraverser::visitSwitch(Visit, TIntermSwitch *switchNode)
1262 {
1263     TIntermBlock &stmtList = *switchNode->getStatementList();
1264 
1265     emitIndentation();
1266     mSink << "switch ";
1267     switchNode->getInit()->traverse(this);
1268     mSink << "\n";
1269 
1270     emitOpenBrace();
1271 
1272     // TODO(anglebug.com/42267100): Case statements that fall through need to combined into a single
1273     // case statement with multiple labels.
1274 
1275     const size_t stmtCount = stmtList.getChildCount();
1276     bool inCaseList        = false;
1277     size_t currStmt        = 0;
1278     while (currStmt < stmtCount)
1279     {
1280         TIntermNode &stmtNode = *stmtList.getChildNode(currStmt);
1281         TIntermCase *caseNode = stmtNode.getAsCaseNode();
1282         if (caseNode)
1283         {
1284             if (inCaseList)
1285             {
1286                 mSink << ", ";
1287             }
1288             else
1289             {
1290                 emitIndentation();
1291                 mSink << "case ";
1292                 inCaseList = true;
1293             }
1294             caseNode->traverse(this);
1295 
1296             // Process the next statement.
1297             currStmt++;
1298         }
1299         else
1300         {
1301             // The current statement is not a case statement, end the current case list and emit all
1302             // the code until the next case statement. WGSL requires braces around the case
1303             // statement's code.
1304             ASSERT(inCaseList);
1305             inCaseList = false;
1306             mSink << ":\n";
1307 
1308             // Count the statements until the next case (or the end of the switch) and emit them as
1309             // a block. This assumes that the current statement list will never fallthrough to the
1310             // next case statement.
1311             size_t nextCaseStmt = currStmt + 1;
1312             for (;
1313                  nextCaseStmt < stmtCount && !stmtList.getChildNode(nextCaseStmt)->getAsCaseNode();
1314                  nextCaseStmt++)
1315             {
1316             }
1317             angle::Span<TIntermNode *> stmtListView(&stmtList.getSequence()->at(currStmt),
1318                                                     nextCaseStmt - currStmt);
1319             emitBlock(stmtListView);
1320             mSink << "\n";
1321 
1322             // Skip to the next case statement.
1323             currStmt = nextCaseStmt;
1324         }
1325     }
1326 
1327     emitCloseBrace();
1328 
1329     return false;
1330 }
1331 
visitCase(Visit,TIntermCase * caseNode)1332 bool OutputWGSLTraverser::visitCase(Visit, TIntermCase *caseNode)
1333 {
1334     // "case" will have been emitted in the visitSwitch() override.
1335 
1336     if (caseNode->hasCondition())
1337     {
1338         TIntermTyped *condExpr = caseNode->getCondition();
1339         condExpr->traverse(this);
1340     }
1341     else
1342     {
1343         mSink << "default";
1344     }
1345 
1346     return false;
1347 }
1348 
emitFunctionReturn(const TFunction & func)1349 void OutputWGSLTraverser::emitFunctionReturn(const TFunction &func)
1350 {
1351     const TType &returnType = func.getReturnType();
1352     if (returnType.getBasicType() == EbtVoid)
1353     {
1354         return;
1355     }
1356     mSink << " -> ";
1357     emitType(returnType);
1358 }
1359 
1360 // TODO(anglebug.com/42267100): Function overloads are not supported in WGSL, so function names
1361 // should either be emitted mangled or overloaded functions should be renamed in the AST as a
1362 // pre-pass. As of Apr 2024, WGSL function overloads are "not coming soon"
1363 // (https://github.com/gpuweb/gpuweb/issues/876).
emitFunctionSignature(const TFunction & func)1364 void OutputWGSLTraverser::emitFunctionSignature(const TFunction &func)
1365 {
1366     mSink << "fn ";
1367 
1368     WriteNameOf(mSink, func);
1369     mSink << "(";
1370 
1371     bool emitComma          = false;
1372     const size_t paramCount = func.getParamCount();
1373     for (size_t i = 0; i < paramCount; ++i)
1374     {
1375         if (emitComma)
1376         {
1377             mSink << ", ";
1378         }
1379         emitComma = true;
1380 
1381         const TVariable &param = *func.getParam(i);
1382         emitFunctionParameter(func, param);
1383     }
1384 
1385     mSink << ")";
1386 
1387     emitFunctionReturn(func);
1388 }
1389 
emitFunctionParameter(const TFunction & func,const TVariable & param)1390 void OutputWGSLTraverser::emitFunctionParameter(const TFunction &func, const TVariable &param)
1391 {
1392     // TODO(anglebug.com/42267100): function parameters are immutable and will need to be renamed if
1393     // they are mutated.
1394     EmitVariableDeclarationConfig evdConfig;
1395     evdConfig.isParameter = true;
1396     emitVariableDeclaration({param.symbolType(), param.name(), param.getType()}, evdConfig);
1397 }
1398 
visitFunctionPrototype(TIntermFunctionPrototype * funcProtoNode)1399 void OutputWGSLTraverser::visitFunctionPrototype(TIntermFunctionPrototype *funcProtoNode)
1400 {
1401     const TFunction &func = *funcProtoNode->getFunction();
1402 
1403     emitIndentation();
1404     // TODO(anglebug.com/42267100): output correct signature for main() if main() is declared as a
1405     // function prototype, or perhaps just emit nothing.
1406     emitFunctionSignature(func);
1407 }
1408 
visitFunctionDefinition(Visit,TIntermFunctionDefinition * funcDefNode)1409 bool OutputWGSLTraverser::visitFunctionDefinition(Visit, TIntermFunctionDefinition *funcDefNode)
1410 {
1411     const TFunction &func = *funcDefNode->getFunction();
1412     TIntermBlock &body    = *funcDefNode->getBody();
1413 
1414     emitIndentation();
1415     emitFunctionSignature(func);
1416     mSink << "\n";
1417     body.traverse(this);
1418 
1419     return false;
1420 }
1421 
emitTextureBuiltin(const TOperator op,const TIntermSequence & args)1422 void OutputWGSLTraverser::emitTextureBuiltin(const TOperator op, const TIntermSequence &args)
1423 {
1424 
1425     ASSERT(BuiltInGroup::IsTexture(op));
1426 
1427     // The index in the GLSL function's argument list of each particular argument, e.g. bias.
1428     // `bias`, `lod`, `offset`, and `P` (the coordinates) are the common arguments to most texture
1429     // functions.
1430     size_t biasIndex   = 0;
1431     size_t lodIndex    = 0;
1432     size_t offsetIndex = 0;
1433     size_t pIndex      = 0;
1434 
1435     size_t dpdxIndex = 0;
1436     size_t dpdyIndex = 0;
1437 
1438     // TODO(anglebug.com/389145696): These are probably incorrect translations when sampling from
1439     // integer or unsigned integer samplers. Using texture() with a usampler
1440     // is similar to using texelFetch(), except wrap modes are respected. Possibly, the correct mip
1441     // levels are also selected.
1442 
1443     // The name of the equivalent texture function in WGSL.
1444     ImmutableString wgslFunctionName("");
1445     // GLSL stuffs 1, 2, or 3 arguments into a single vector. These represent the swizzles necessary
1446     // for extracting each argument from P to pass to the appropriate WGSL function.
1447     ImmutableString coordsSwizzle("");
1448     ImmutableString arrayIndexSwizzle("");
1449     ImmutableString depthRefSwizzle("");
1450     // For the projection forms of the texture builtins, the last coordinate will divide the other
1451     // three. This is just a swizzle for the last coordinate if the builtin call includes
1452     // projection.
1453     ImmutableString projectionDivisionSwizzle("");
1454 
1455     ImmutableString wgslTextureVarName("");
1456     ImmutableString wgslSamplerVarName("");
1457 
1458     constexpr char k2DCoordsSwizzle[] = ".xy";
1459     constexpr char k3DCoordsSwizzle[] = ".xyz";
1460 
1461     constexpr char kPossibleElems[] = "xyzw";
1462 
1463     // MonomorphizeUnsupportedFunctions() and RewriteStructSamplers() ensure that this is a
1464     // reference to the global sampler.
1465     TIntermSymbol *samplerNode = args[0]->getAsSymbolNode();
1466     // TODO(anglebug.com/389145696): this will fail if it's an array of samplers, which isn't yet
1467     // handled.
1468     if (!samplerNode)
1469     {
1470         UNIMPLEMENTED();
1471         mSink << "TODO_UNHANDLED_TEXTURE_FUNCTION()";
1472         return;
1473     }
1474     TBasicType samplerType = samplerNode->getType().getBasicType();
1475     ASSERT(IsSampler(samplerType));
1476 
1477     bool isProj = false;
1478 
1479     auto setWgslTextureVarName = [&]() {
1480         wgslTextureVarName =
1481             BuildConcatenatedImmutableString(kAngleTexturePrefix, samplerNode->getName());
1482     };
1483 
1484     auto setWgslSamplerVarName = [&]() {
1485         wgslSamplerVarName =
1486             BuildConcatenatedImmutableString(kAngleSamplerPrefix, samplerNode->getName());
1487     };
1488 
1489     auto setTextureSampleFunctionNameFromBias = [&]() {
1490         // TODO(anglebug.com/389145696): these are incorrect translations in vertex shaders, where
1491         // they should probably use textureLoad() (and textureDimensions()).
1492         if (IsShadowSampler(samplerType))
1493         {
1494             if (biasIndex != 0)
1495             {
1496                 // TODO(anglebug.com/389145696): WGSL doesn't support using bias with shadow
1497                 // samplers.
1498                 UNIMPLEMENTED();
1499                 wgslFunctionName = ImmutableString("TODO_CANNOT_USE_BIAS_WITH_SHADOW_SAMPLER");
1500             }
1501             else
1502             {
1503                 wgslFunctionName = ImmutableString("textureSampleCompare");
1504             }
1505         }
1506         else
1507         {
1508             if (biasIndex == 0)
1509             {
1510                 wgslFunctionName = ImmutableString("textureSample");
1511             }
1512             else
1513             {
1514 
1515                 wgslFunctionName = ImmutableString("textureSampleBias");
1516             }
1517         }
1518     };
1519 
1520     switch (op)
1521     {
1522         case EOpTextureSize:
1523         {
1524             lodIndex = 1;
1525             ASSERT(args.size() == 2);
1526 
1527             wgslFunctionName = ImmutableString("textureDimensions");
1528             setWgslTextureVarName();
1529         }
1530         break;
1531 
1532         case EOpTexelFetchOffset:
1533         case EOpTexelFetch:
1534         {
1535             pIndex   = 1;
1536             lodIndex = 2;
1537             if (args.size() == 4)
1538             {
1539                 offsetIndex = 3;
1540             }
1541             ASSERT(args.size() == 3 || args.size() == 4);
1542 
1543             wgslFunctionName = ImmutableString("textureLoad");
1544             setWgslTextureVarName();
1545         }
1546         break;
1547 
1548         // texture() use to be split into texture2D() and textureCube(). WGSL matches GLSL 3.0 and
1549         // combines them.
1550         case EOpTextureProj:
1551         case EOpTexture2DProj:
1552         case EOpTextureProjBias:
1553         case EOpTexture2DProjBias:
1554             isProj = true;
1555             [[fallthrough]];
1556         case EOpTexture:
1557         case EOpTexture2D:
1558         case EOpTextureCube:
1559         case EOpTextureBias:
1560         case EOpTexture2DBias:
1561         case EOpTextureCubeBias:
1562         {
1563             pIndex = 1;
1564             if (args.size() == 3)
1565             {
1566                 biasIndex = 2;
1567             }
1568             ASSERT(args.size() == 2 || args.size() == 3);
1569 
1570             setTextureSampleFunctionNameFromBias();
1571             setWgslTextureVarName();
1572             setWgslSamplerVarName();
1573         }
1574         break;
1575 
1576         case EOpTextureProjLod:
1577         case EOpTextureProjLodOffset:
1578         case EOpTexture2DProjLodVS:
1579         case EOpTexture2DProjLodEXTFS:
1580             isProj = true;
1581             [[fallthrough]];
1582         case EOpTextureLod:
1583         case EOpTexture2DLodVS:
1584         case EOpTextureCubeLodVS:
1585         case EOpTexture2DLodEXTFS:
1586         case EOpTextureCubeLodEXTFS:
1587         case EOpTextureLodOffset:
1588         {
1589             pIndex   = 1;
1590             lodIndex = 2;
1591             if (args.size() == 4)
1592             {
1593                 offsetIndex = 3;
1594             }
1595             ASSERT(args.size() == 3 || args.size() == 4);
1596 
1597             if (IsShadowSampler(samplerType))
1598             {
1599                 // TODO(anglebug.com/389145696): WGSL may not support explicit LOD with shadow
1600                 // samplers. textureSampleCompareLevel() only uses mip level 0.
1601                 UNIMPLEMENTED();
1602                 wgslFunctionName =
1603                     ImmutableString("TODO_CANNOT_USE_EXPLICIT_LOD_WITH_SHADOW_SAMPLER");
1604             }
1605             else
1606             {
1607                 wgslFunctionName = ImmutableString("textureSampleLevel");
1608             }
1609             setWgslTextureVarName();
1610             setWgslSamplerVarName();
1611         }
1612         break;
1613 
1614         case EOpTextureProjOffset:
1615         case EOpTextureProjOffsetBias:
1616             isProj = true;
1617             [[fallthrough]];
1618         case EOpTextureOffset:
1619         case EOpTextureOffsetBias:
1620         {
1621             pIndex      = 1;
1622             offsetIndex = 2;
1623             if (args.size() == 4)
1624             {
1625                 biasIndex = 3;
1626             }
1627             ASSERT(args.size() == 3 || args.size() == 4);
1628 
1629             setTextureSampleFunctionNameFromBias();
1630             setWgslTextureVarName();
1631             setWgslSamplerVarName();
1632         }
1633         break;
1634 
1635         case EOpTextureProjGrad:
1636         case EOpTextureProjGradOffset:
1637             isProj = true;
1638             [[fallthrough]];
1639         case EOpTextureGrad:
1640         case EOpTextureGradOffset:
1641         {
1642             pIndex    = 1;
1643             dpdxIndex = 2;
1644             dpdyIndex = 3;
1645             if (args.size() == 5)
1646             {
1647                 offsetIndex = 4;
1648             }
1649             ASSERT(args.size() == 4 || args.size() == 5);
1650 
1651             if (IsShadowSampler(samplerType))
1652             {
1653                 // TODO(anglebug.com/389145696): WGSL may not support explicit gradients with shadow
1654                 // samplers.
1655                 UNIMPLEMENTED();
1656                 wgslFunctionName =
1657                     ImmutableString("TODO_CANNOT_USE_EXPLICIT_GRAD_WITH_SHADOW_SAMPLER");
1658             }
1659             else
1660             {
1661                 wgslFunctionName = ImmutableString("textureSampleGrad");
1662             }
1663             setWgslTextureVarName();
1664             setWgslSamplerVarName();
1665         }
1666         break;
1667 
1668         default:
1669             UNIMPLEMENTED();
1670             mSink << "TODO_UNHANDLED_TEXTURE_FUNCTION()";
1671             return;
1672     }
1673 
1674     mSink << wgslFunctionName << "(";
1675 
1676     ASSERT(!wgslTextureVarName.empty());
1677     mSink << wgslTextureVarName;
1678 
1679     if (!wgslSamplerVarName.empty())
1680     {
1681         mSink << ", " << wgslSamplerVarName;
1682 
1683         // If using a projection division, set the swizzle that extracts the last argument from the
1684         // p vector.
1685         if (isProj)
1686         {
1687             ASSERT(pIndex == 1);
1688             const uint8_t vecSize = args[pIndex]->getAsTyped()->getNominalSize();
1689             ASSERT(vecSize == 3 || vecSize == 4);
1690             projectionDivisionSwizzle =
1691                 BuildConcatenatedImmutableString('.', kPossibleElems[vecSize - 1]);
1692         }
1693 
1694         // If sampling from an array, set the swizzle that extracts the array layer number from the
1695         // p vector.
1696         if (IsSampler2DArray(samplerType))
1697         {
1698             arrayIndexSwizzle = ImmutableString(".z");
1699         }
1700 
1701         // If sampling from a shadow samplers, set the swizzle that extracts the D_ref argument from
1702         // the p vector.
1703         if (IsShadowSampler(samplerType))
1704         {
1705             size_t elemIndex = 0;
1706             if (IsSampler2D(samplerType))
1707             {
1708                 elemIndex = 2;
1709             }
1710             else if (IsSampler2DArray(samplerType) || IsSampler3D(samplerType) ||
1711                      IsSamplerCube(samplerType))
1712             {
1713                 elemIndex = 3;
1714             }
1715 
1716             depthRefSwizzle = BuildConcatenatedImmutableString('.', kPossibleElems[elemIndex]);
1717         }
1718 
1719         // Finally, set the swizzle for extracting coordinates from the p vector.
1720         if (IsSampler2D(samplerType) || IsSampler2DArray(samplerType))
1721         {
1722             coordsSwizzle = ImmutableString(k2DCoordsSwizzle);
1723         }
1724         else if (IsSampler3D(samplerType) || IsSamplerCube(samplerType))
1725         {
1726             coordsSwizzle = ImmutableString(k3DCoordsSwizzle);
1727         }
1728     }
1729 
1730     // TODO(anglebug.com/389145696): traversing the pArg multiple times is an error if it ever
1731     // contains side effects (e.g. a function call). There is also a problem if this traverses
1732     // function arguments in a different order, arguments with side effects that effect arguments
1733     // that come later may be reordered incorrectly. ESSL specs defined function argument evaluation
1734     // as left-to-right.
1735     auto traversePArg = [&]() {
1736         mSink << "(";
1737         ASSERT(pIndex != 0);
1738         args[pIndex]->traverse(this);
1739         mSink << ")";
1740     };
1741 
1742     auto outputProjectionDivisionIfNecessary = [&]() {
1743         if (projectionDivisionSwizzle.empty())
1744         {
1745             return;
1746         }
1747         mSink << " / ";
1748         traversePArg();
1749         mSink << projectionDivisionSwizzle;
1750     };
1751 
1752     // The arguments to the WGSL function always appear in a certain (partial) order, so output them
1753     // in that order.
1754     //
1755     // The order is always
1756     // - texture
1757     // - sampler
1758     // - coordinates
1759     // - array layer index
1760     // - depth_ref, bias, explicit level of detail (never appear together)
1761     // - dfdx
1762     // - dfdy
1763     // - offset
1764     //
1765     // See the texture builtin functions in the WGSL spec:
1766     // https://www.w3.org/TR/WGSL/#texture-builtin-functions
1767     //
1768     // For example
1769     // @must_use fn textureSampleLevel(t: texture_2d_array<f32>,
1770     //                             s: sampler,
1771     //                             coords: vec2<f32>,
1772     //                             array_index: A,
1773     //                             level: f32,
1774     //                             offset: vec2<i32>) -> vec4<f32>
1775 
1776     if (pIndex != 0)
1777     {
1778         mSink << ", ";
1779         traversePArg();
1780         mSink << coordsSwizzle;
1781         outputProjectionDivisionIfNecessary();
1782     }
1783 
1784     if (!arrayIndexSwizzle.empty())
1785     {
1786         mSink << ", ";
1787         traversePArg();
1788         mSink << arrayIndexSwizzle;
1789     }
1790 
1791     if (!depthRefSwizzle.empty())
1792     {
1793         mSink << ", ";
1794         traversePArg();
1795         mSink << depthRefSwizzle;
1796         outputProjectionDivisionIfNecessary();
1797     }
1798 
1799     if (biasIndex != 0)
1800     {
1801         mSink << ", ";
1802         args[biasIndex]->traverse(this);
1803     }
1804 
1805     if (lodIndex != 0)
1806     {
1807         mSink << ", ";
1808         args[lodIndex]->traverse(this);
1809     }
1810 
1811     if (dpdxIndex != 0)
1812     {
1813         mSink << ", ";
1814         args[dpdxIndex]->traverse(this);
1815     }
1816 
1817     if (dpdyIndex != 0)
1818     {
1819         mSink << ", ";
1820         args[dpdyIndex]->traverse(this);
1821     }
1822 
1823     if (offsetIndex != 0)
1824     {
1825         mSink << ", ";
1826         // Both GLSL and WGSL require this to be a const expression.
1827         args[offsetIndex]->traverse(this);
1828     }
1829 
1830     mSink << ")";
1831 }
1832 
visitAggregate(Visit,TIntermAggregate * aggregateNode)1833 bool OutputWGSLTraverser::visitAggregate(Visit, TIntermAggregate *aggregateNode)
1834 {
1835     const TIntermSequence &args = *aggregateNode->getSequence();
1836 
1837     auto emitArgList = [&]() {
1838         mSink << "(";
1839 
1840         bool emitComma = false;
1841         for (TIntermNode *arg : args)
1842         {
1843             if (emitComma)
1844             {
1845                 mSink << ", ";
1846             }
1847             emitComma = true;
1848             arg->traverse(this);
1849         }
1850 
1851         mSink << ")";
1852     };
1853 
1854     const TType &retType = aggregateNode->getType();
1855 
1856     if (aggregateNode->isConstructor())
1857     {
1858 
1859         emitType(retType);
1860         emitArgList();
1861 
1862         return false;
1863     }
1864     else
1865     {
1866         const TOperator op = aggregateNode->getOp();
1867         switch (op)
1868         {
1869             case TOperator::EOpCallFunctionInAST:
1870                 WriteNameOf(mSink, *aggregateNode->getFunction());
1871                 emitArgList();
1872                 return false;
1873 
1874             default:
1875                 // Do not allow raw function calls, i.e. calls to functions
1876                 // not present in the AST.
1877                 ASSERT(op != TOperator::EOpCallInternalRawFunction);
1878                 auto getArgType = [&](size_t index) -> const TType * {
1879                     if (index < args.size())
1880                     {
1881                         TIntermTyped *arg = args[index]->getAsTyped();
1882                         ASSERT(arg);
1883                         return &arg->getType();
1884                     }
1885                     return nullptr;
1886                 };
1887 
1888                 const TType *argType0 = getArgType(0);
1889                 const TType *argType1 = getArgType(1);
1890                 const TType *argType2 = getArgType(2);
1891 
1892                 const char *opName = GetOperatorString(op, retType, argType0, argType1, argType2);
1893 
1894                 if (IsSymbolicOperator(op, retType, argType0, argType1))
1895                 {
1896                     switch (args.size())
1897                     {
1898                         case 1:
1899                         {
1900                             TIntermNode &operandNode = *aggregateNode->getChildNode(0);
1901                             if (IsPostfix(op))
1902                             {
1903                                 mSink << opName;
1904                                 groupedTraverse(operandNode);
1905                             }
1906                             else
1907                             {
1908                                 groupedTraverse(operandNode);
1909                                 mSink << opName;
1910                             }
1911                             return false;
1912                         }
1913 
1914                         case 2:
1915                         {
1916                             // symbolic operators with 2 args are emitted with infix notation.
1917                             TIntermNode &leftNode  = *aggregateNode->getChildNode(0);
1918                             TIntermNode &rightNode = *aggregateNode->getChildNode(1);
1919                             groupedTraverse(leftNode);
1920                             mSink << " " << opName << " ";
1921                             groupedTraverse(rightNode);
1922                             return false;
1923                         }
1924 
1925                         default:
1926                             UNREACHABLE();
1927                             return false;
1928                     }
1929                 }
1930                 else
1931                 {
1932                     // Rewrite the calls to sampler functions.
1933                     if (BuiltInGroup::IsTexture(op))
1934                     {
1935                         emitTextureBuiltin(op, args);
1936                         return false;
1937                     }
1938                     // If the operator is not symbolic then it is a builtin that uses function call
1939                     // syntax: builtin(arg1, arg2, ..);
1940                     mSink << (opName == nullptr ? "TODO_Operator" : opName);
1941                     emitArgList();
1942                     return false;
1943                 }
1944         }
1945     }
1946 }
1947 
emitBlock(angle::Span<TIntermNode * > nodes)1948 bool OutputWGSLTraverser::emitBlock(angle::Span<TIntermNode *> nodes)
1949 {
1950     ASSERT(mIndentLevel >= -1);
1951     const bool isGlobalScope = mIndentLevel == -1;
1952 
1953     if (isGlobalScope)
1954     {
1955         ++mIndentLevel;
1956     }
1957     else
1958     {
1959         emitOpenBrace();
1960     }
1961 
1962     TIntermNode *prevStmtNode = nullptr;
1963 
1964     const size_t stmtCount = nodes.size();
1965     for (size_t i = 0; i < stmtCount; ++i)
1966     {
1967         TIntermNode &stmtNode = *nodes[i];
1968 
1969         if (isGlobalScope && prevStmtNode && (NewlinePad(*prevStmtNode) || NewlinePad(stmtNode)))
1970         {
1971             mSink << "\n";
1972         }
1973         const bool isCase = stmtNode.getAsCaseNode();
1974         mIndentLevel -= isCase;
1975         emitIndentation();
1976         mIndentLevel += isCase;
1977         stmtNode.traverse(this);
1978         if (RequiresSemicolonTerminator(stmtNode))
1979         {
1980             mSink << ";";
1981         }
1982         mSink << "\n";
1983 
1984         prevStmtNode = &stmtNode;
1985     }
1986 
1987     if (isGlobalScope)
1988     {
1989         ASSERT(mIndentLevel == 0);
1990         --mIndentLevel;
1991     }
1992     else
1993     {
1994         emitCloseBrace();
1995     }
1996 
1997     return false;
1998 }
1999 
visitBlock(Visit,TIntermBlock * blockNode)2000 bool OutputWGSLTraverser::visitBlock(Visit, TIntermBlock *blockNode)
2001 {
2002     return emitBlock(
2003         angle::Span(blockNode->getSequence()->data(), blockNode->getSequence()->size()));
2004 }
2005 
visitGlobalQualifierDeclaration(Visit,TIntermGlobalQualifierDeclaration *)2006 bool OutputWGSLTraverser::visitGlobalQualifierDeclaration(Visit,
2007                                                           TIntermGlobalQualifierDeclaration *)
2008 {
2009     return false;
2010 }
2011 
emitStructDeclaration(const TType & type)2012 void OutputWGSLTraverser::emitStructDeclaration(const TType &type)
2013 {
2014     ASSERT(type.getBasicType() == TBasicType::EbtStruct);
2015     ASSERT(type.isStructSpecifier());
2016 
2017     mSink << "struct ";
2018     emitBareTypeName(type);
2019 
2020     mSink << "\n";
2021     emitOpenBrace();
2022 
2023     const TStructure &structure = *type.getStruct();
2024     bool isInUniformAddressSpace =
2025         mUniformBlockMetadata->structsInUniformAddressSpace.count(structure.uniqueId().get()) != 0;
2026 
2027     bool alignTo16InUniformAddressSpace = true;
2028     for (const TField *field : structure.fields())
2029     {
2030         const TType *fieldType = field->type();
2031 
2032         emitIndentation();
2033         // If this struct is used in the uniform address space, it must obey the uniform address
2034         // space's layout constaints (https://www.w3.org/TR/WGSL/#address-space-layout-constraints).
2035         // WGSL's address space layout constraints nearly match std140, and the places they don't
2036         // are handled elsewhere.
2037         if (isInUniformAddressSpace)
2038         {
2039             // Here, the field must be aligned to 16 if:
2040             // 1. The field is a struct or array (note that matCx2 is represented as an array of
2041             // vec2)
2042             // 2. The previous field is a struct
2043             // 3. The field is the first in the struct (for convenience).
2044             if (field->type()->getStruct() || fieldType->isArray() || IsMatCx2(fieldType))
2045             {
2046                 alignTo16InUniformAddressSpace = true;
2047             }
2048             if (alignTo16InUniformAddressSpace)
2049             {
2050                 mSink << "@align(16) ";
2051             }
2052 
2053             // If this field is a struct, the next member should be aligned to 16.
2054             alignTo16InUniformAddressSpace = fieldType->getStruct();
2055 
2056             // If the field is an array whose stride is not aligned to 16, the element type must be
2057             // emitted with a wrapper struct. Record that the wrapper struct needs to be emitted.
2058             // Note that if the array element type is already of struct type, it doesn't need
2059             // another wrapper struct, it will automatically be aligned to 16 because its first
2060             // member is aligned to 16 (implemented above).
2061             if (ElementTypeNeedsUniformWrapperStruct(/*inUniformAddressSpace=*/true, fieldType))
2062             {
2063                 TType innerType = *fieldType;
2064                 innerType.toArrayElementType();
2065                 // Multidimensional arrays not currently supported in uniforms in the WebGPU backend
2066                 ASSERT(!innerType.isArray());
2067                 mWGSLGenerationMetadataForUniforms->arrayElementTypesInUniforms.insert(innerType);
2068             }
2069         }
2070 
2071         // TODO(anglebug.com/42267100): emit qualifiers.
2072         EmitVariableDeclarationConfig evdConfig;
2073         evdConfig.typeConfig.addressSpace =
2074             isInUniformAddressSpace ? WgslAddressSpace::Uniform : WgslAddressSpace::NonUniform;
2075         evdConfig.disableStructSpecifier = true;
2076         emitVariableDeclaration({field->symbolType(), field->name(), *fieldType}, evdConfig);
2077         mSink << ",\n";
2078     }
2079 
2080     emitCloseBrace();
2081 }
2082 
emitVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)2083 void OutputWGSLTraverser::emitVariableDeclaration(const VarDecl &decl,
2084                                                   const EmitVariableDeclarationConfig &evdConfig)
2085 {
2086     const TBasicType basicType = decl.type.getBasicType();
2087 
2088     if (decl.type.getQualifier() == EvqUniform)
2089     {
2090         // Uniforms are declared in a pre-pass, and don't need to be outputted here.
2091         return;
2092     }
2093 
2094     if (basicType == TBasicType::EbtStruct && decl.type.isStructSpecifier() &&
2095         !evdConfig.disableStructSpecifier)
2096     {
2097         // TODO(anglebug.com/42267100): in WGSL structs probably can't be declared in
2098         // function parameters or in uniform declarations or in variable declarations, or
2099         // anonymously either within other structs or within a variable declaration. Handle
2100         // these with the same AST pre-passes as other shader translators.
2101         ASSERT(!evdConfig.isParameter);
2102         emitStructDeclaration(decl.type);
2103         if (decl.symbolType != SymbolType::Empty)
2104         {
2105             mSink << " ";
2106             emitNameOf(decl);
2107         }
2108         return;
2109     }
2110 
2111     ASSERT(basicType == TBasicType::EbtStruct || decl.symbolType != SymbolType::Empty ||
2112            evdConfig.isParameter);
2113 
2114     if (evdConfig.needsVar)
2115     {
2116         // "const" and "let" probably don't need to be ever emitted because they are more for
2117         // readability, and the GLSL compiler constant folds most (all?) the consts anyway.
2118         mSink << "var";
2119         // TODO(anglebug.com/42267100): <workgroup> or <storage>?
2120         if (evdConfig.isGlobalScope)
2121         {
2122             if (decl.type.getQualifier() == EvqUniform)
2123             {
2124                 ASSERT(IsOpaqueType(decl.type.getBasicType()));
2125                 mSink << "<uniform>";
2126             }
2127             else
2128             {
2129                 mSink << "<private>";
2130             }
2131         }
2132         mSink << " ";
2133     }
2134     else
2135     {
2136         ASSERT(!evdConfig.isGlobalScope);
2137     }
2138 
2139     if (decl.symbolType != SymbolType::Empty)
2140     {
2141         emitNameOf(decl);
2142     }
2143     mSink << " : ";
2144     WriteWgslType(mSink, decl.type, evdConfig.typeConfig);
2145 }
2146 
visitDeclaration(Visit,TIntermDeclaration * declNode)2147 bool OutputWGSLTraverser::visitDeclaration(Visit, TIntermDeclaration *declNode)
2148 {
2149     ASSERT(declNode->getChildCount() == 1);
2150     TIntermNode &node = *declNode->getChildNode(0);
2151 
2152     EmitVariableDeclarationConfig evdConfig;
2153     evdConfig.needsVar      = true;
2154     evdConfig.isGlobalScope = mIndentLevel == 0;
2155 
2156     if (TIntermSymbol *symbolNode = node.getAsSymbolNode())
2157     {
2158         const TVariable &var = symbolNode->variable();
2159         if (mRewritePipelineVarOutput->IsInputVar(var.uniqueId()) ||
2160             mRewritePipelineVarOutput->IsOutputVar(var.uniqueId()))
2161         {
2162             // Some variables, like shader inputs/outputs/builtins, are declared in the WGSL source
2163             // outside of the traverser.
2164             return false;
2165         }
2166         emitVariableDeclaration({var.symbolType(), var.name(), var.getType()}, evdConfig);
2167     }
2168     else if (TIntermBinary *initNode = node.getAsBinaryNode())
2169     {
2170         ASSERT(initNode->getOp() == TOperator::EOpInitialize);
2171         TIntermSymbol *leftSymbolNode = initNode->getLeft()->getAsSymbolNode();
2172         TIntermTyped *valueNode       = initNode->getRight()->getAsTyped();
2173         ASSERT(leftSymbolNode && valueNode);
2174 
2175         const TVariable &var = leftSymbolNode->variable();
2176         if (mRewritePipelineVarOutput->IsInputVar(var.uniqueId()) ||
2177             mRewritePipelineVarOutput->IsOutputVar(var.uniqueId()))
2178         {
2179             // Some variables, like shader inputs/outputs/builtins, are declared in the WGSL source
2180             // outside of the traverser.
2181             return false;
2182         }
2183 
2184         emitVariableDeclaration({var.symbolType(), var.name(), var.getType()}, evdConfig);
2185         mSink << " = ";
2186         groupedTraverse(*valueNode);
2187     }
2188     else
2189     {
2190         UNREACHABLE();
2191     }
2192 
2193     return false;
2194 }
2195 
visitLoop(Visit,TIntermLoop * loopNode)2196 bool OutputWGSLTraverser::visitLoop(Visit, TIntermLoop *loopNode)
2197 {
2198     const TLoopType loopType = loopNode->getType();
2199 
2200     switch (loopType)
2201     {
2202         case TLoopType::ELoopFor:
2203             return emitForLoop(loopNode);
2204         case TLoopType::ELoopWhile:
2205             return emitWhileLoop(loopNode);
2206         case TLoopType::ELoopDoWhile:
2207             return emulateDoWhileLoop(loopNode);
2208     }
2209 }
2210 
emitForLoop(TIntermLoop * loopNode)2211 bool OutputWGSLTraverser::emitForLoop(TIntermLoop *loopNode)
2212 {
2213     ASSERT(loopNode->getType() == TLoopType::ELoopFor);
2214 
2215     TIntermNode *initNode  = loopNode->getInit();
2216     TIntermTyped *condNode = loopNode->getCondition();
2217     TIntermTyped *exprNode = loopNode->getExpression();
2218 
2219     mSink << "for (";
2220 
2221     if (initNode)
2222     {
2223         initNode->traverse(this);
2224     }
2225     else
2226     {
2227         mSink << " ";
2228     }
2229 
2230     mSink << "; ";
2231 
2232     if (condNode)
2233     {
2234         condNode->traverse(this);
2235     }
2236 
2237     mSink << "; ";
2238 
2239     if (exprNode)
2240     {
2241         exprNode->traverse(this);
2242     }
2243 
2244     mSink << ")\n";
2245 
2246     loopNode->getBody()->traverse(this);
2247 
2248     return false;
2249 }
2250 
emitWhileLoop(TIntermLoop * loopNode)2251 bool OutputWGSLTraverser::emitWhileLoop(TIntermLoop *loopNode)
2252 {
2253     ASSERT(loopNode->getType() == TLoopType::ELoopWhile);
2254 
2255     TIntermNode *initNode  = loopNode->getInit();
2256     TIntermTyped *condNode = loopNode->getCondition();
2257     TIntermTyped *exprNode = loopNode->getExpression();
2258     ASSERT(condNode);
2259     ASSERT(!initNode && !exprNode);
2260 
2261     emitIndentation();
2262     mSink << "while (";
2263     condNode->traverse(this);
2264     mSink << ")\n";
2265     loopNode->getBody()->traverse(this);
2266 
2267     return false;
2268 }
2269 
emulateDoWhileLoop(TIntermLoop * loopNode)2270 bool OutputWGSLTraverser::emulateDoWhileLoop(TIntermLoop *loopNode)
2271 {
2272     ASSERT(loopNode->getType() == TLoopType::ELoopDoWhile);
2273 
2274     TIntermNode *initNode  = loopNode->getInit();
2275     TIntermTyped *condNode = loopNode->getCondition();
2276     TIntermTyped *exprNode = loopNode->getExpression();
2277     ASSERT(condNode);
2278     ASSERT(!initNode && !exprNode);
2279 
2280     emitIndentation();
2281     // Write an infinite loop.
2282     mSink << "loop {\n";
2283     mIndentLevel++;
2284     loopNode->getBody()->traverse(this);
2285     mSink << "\n";
2286     emitIndentation();
2287     // At the end of the loop, break if the loop condition dos not still hold.
2288     mSink << "if (!(";
2289     condNode->traverse(this);
2290     mSink << ") { break; }\n";
2291     mIndentLevel--;
2292     emitIndentation();
2293     mSink << "}";
2294 
2295     return false;
2296 }
2297 
visitBranch(Visit,TIntermBranch * branchNode)2298 bool OutputWGSLTraverser::visitBranch(Visit, TIntermBranch *branchNode)
2299 {
2300     const TOperator flowOp = branchNode->getFlowOp();
2301     TIntermTyped *exprNode = branchNode->getExpression();
2302 
2303     emitIndentation();
2304 
2305     switch (flowOp)
2306     {
2307         case TOperator::EOpKill:
2308         {
2309             ASSERT(exprNode == nullptr);
2310             mSink << "discard";
2311         }
2312         break;
2313 
2314         case TOperator::EOpReturn:
2315         {
2316             mSink << "return";
2317             if (exprNode)
2318             {
2319                 mSink << " ";
2320                 exprNode->traverse(this);
2321             }
2322         }
2323         break;
2324 
2325         case TOperator::EOpBreak:
2326         {
2327             ASSERT(exprNode == nullptr);
2328             mSink << "break";
2329         }
2330         break;
2331 
2332         case TOperator::EOpContinue:
2333         {
2334             ASSERT(exprNode == nullptr);
2335             mSink << "continue";
2336         }
2337         break;
2338 
2339         default:
2340         {
2341             UNREACHABLE();
2342         }
2343     }
2344 
2345     return false;
2346 }
2347 
visitPreprocessorDirective(TIntermPreprocessorDirective * node)2348 void OutputWGSLTraverser::visitPreprocessorDirective(TIntermPreprocessorDirective *node)
2349 {
2350     // No preprocessor directives expected at this point.
2351     UNREACHABLE();
2352 }
2353 
emitBareTypeName(const TType & type)2354 void OutputWGSLTraverser::emitBareTypeName(const TType &type)
2355 {
2356     WriteWgslBareTypeName(mSink, type, {});
2357 }
2358 
emitType(const TType & type)2359 void OutputWGSLTraverser::emitType(const TType &type)
2360 {
2361     WriteWgslType(mSink, type, {});
2362 }
2363 
2364 }  // namespace
2365 
TranslatorWGSL(sh::GLenum type,ShShaderSpec spec,ShShaderOutput output)2366 TranslatorWGSL::TranslatorWGSL(sh::GLenum type, ShShaderSpec spec, ShShaderOutput output)
2367     : TCompiler(type, spec, output)
2368 {}
2369 
preTranslateTreeModifications(TIntermBlock * root)2370 bool TranslatorWGSL::preTranslateTreeModifications(TIntermBlock *root)
2371 {
2372 
2373     int aggregateTypesUsedForUniforms = 0;
2374     for (const auto &uniform : getUniforms())
2375     {
2376         if (uniform.isStruct() || uniform.isArrayOfArrays())
2377         {
2378             ++aggregateTypesUsedForUniforms;
2379         }
2380     }
2381 
2382     // Samplers are legal as function parameters, but samplers within structs or arrays are not
2383     // allowed in WGSL
2384     // (https://www.w3.org/TR/WGSL/#function-call-expr:~:text=A%20function%20parameter,a%20sampler%20type).
2385     // TODO(anglebug.com/389145696): handle arrays of samplers here.
2386 
2387     // If there are any function calls that take array-of-array of opaque uniform parameters, or
2388     // other opaque uniforms that need special handling in WebGPU, monomorphize the functions by
2389     // removing said parameters and replacing them in the function body with the call arguments.
2390     //
2391     // This  dramatically simplifies future transformations w.r.t to samplers in structs, array of
2392     //   arrays of opaque types, atomic counters etc.
2393     UnsupportedFunctionArgsBitSet args{UnsupportedFunctionArgs::StructContainingSamplers,
2394                                        UnsupportedFunctionArgs::ArrayOfArrayOfSamplerOrImage,
2395                                        UnsupportedFunctionArgs::AtomicCounter,
2396                                        UnsupportedFunctionArgs::Image};
2397     if (!MonomorphizeUnsupportedFunctions(this, root, &getSymbolTable(), args))
2398     {
2399         return false;
2400     }
2401 
2402     if (aggregateTypesUsedForUniforms > 0)
2403     {
2404         if (!SeparateStructFromUniformDeclarations(this, root, &getSymbolTable()))
2405         {
2406             return false;
2407         }
2408 
2409         int removedUniformsCount;
2410 
2411         // Requires MonomorphizeUnsupportedFunctions() to have been run already.
2412         if (!RewriteStructSamplers(this, root, &getSymbolTable(), &removedUniformsCount))
2413         {
2414             return false;
2415         }
2416     }
2417 
2418     // Replace array of array of opaque uniforms with a flattened array.  This is run after
2419     // MonomorphizeUnsupportedFunctions and RewriteStructSamplers so that it's not possible for an
2420     // array of array of opaque type to be partially subscripted and passed to a function.
2421     // TODO(anglebug.com/389145696): Even single-level arrays of samplers are not allowed in WGSL.
2422     if (!RewriteArrayOfArrayOfOpaqueUniforms(this, root, &getSymbolTable()))
2423     {
2424         return false;
2425     }
2426 
2427     return true;
2428 }
2429 
translate(TIntermBlock * root,const ShCompileOptions & compileOptions,PerformanceDiagnostics * perfDiagnostics)2430 bool TranslatorWGSL::translate(TIntermBlock *root,
2431                                const ShCompileOptions &compileOptions,
2432                                PerformanceDiagnostics *perfDiagnostics)
2433 {
2434     if (kOutputTreeBeforeTranslation)
2435     {
2436         OutputTree(root, getInfoSink().info);
2437         std::cout << getInfoSink().info.c_str();
2438     }
2439 
2440     if (!preTranslateTreeModifications(root))
2441     {
2442         return false;
2443     }
2444     enableValidateNoMoreTransformations();
2445 
2446     RewritePipelineVarOutput rewritePipelineVarOutput(getShaderType());
2447     WGSLGenerationMetadataForUniforms wgslGenerationMetadataForUniforms;
2448 
2449     // WGSL's main() will need to take parameters or return values if any glsl (input/output)
2450     // builtin variables are used.
2451     if (!GenerateMainFunctionAndIOStructs(*this, *root, rewritePipelineVarOutput))
2452     {
2453         return false;
2454     }
2455 
2456     TInfoSinkBase &sink = getInfoSink().obj;
2457     // Start writing the output structs that will be referred to by the `traverser`'s output.'
2458     if (!rewritePipelineVarOutput.OutputStructs(sink))
2459     {
2460         return false;
2461     }
2462 
2463     if (!OutputUniformBlocksAndSamplers(this, root))
2464     {
2465         return false;
2466     }
2467 
2468     UniformBlockMetadata uniformBlockMetadata;
2469     if (!RecordUniformBlockMetadata(root, uniformBlockMetadata))
2470     {
2471         return false;
2472     }
2473 
2474     // Generate the body of the WGSL including the GLSL main() function.
2475     TInfoSinkBase traverserOutput;
2476     OutputWGSLTraverser traverser(&traverserOutput, &rewritePipelineVarOutput,
2477                                   &uniformBlockMetadata, &wgslGenerationMetadataForUniforms);
2478     root->traverse(&traverser);
2479 
2480     sink << "\n";
2481     OutputUniformWrapperStructsAndConversions(sink, wgslGenerationMetadataForUniforms);
2482 
2483     // The traverser output needs to be in the code after uniform wrapper structs are emitted above,
2484     // since the traverser code references the wrapper struct types.
2485     sink << traverserOutput.str();
2486 
2487     // Write the actual WGSL main function, wgslMain(), which calls the GLSL main function.
2488     if (!rewritePipelineVarOutput.OutputMainFunction(sink))
2489     {
2490         return false;
2491     }
2492 
2493     if (kOutputTranslatedShader)
2494     {
2495         std::cout << sink.str();
2496     }
2497 
2498     return true;
2499 }
2500 
shouldFlattenPragmaStdglInvariantAll()2501 bool TranslatorWGSL::shouldFlattenPragmaStdglInvariantAll()
2502 {
2503     // Not neccesary for WGSL transformation.
2504     return false;
2505 }
2506 }  // namespace sh
2507