1 //
2 // Copyright 2024 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
#include "compiler/translator/wgsl/TranslatorWGSL.h"

#include <cctype>
#include <cmath>
#include <iostream>
#include <limits>
#include <variant>

#include "GLSLANG/ShaderLang.h"
#include "common/log_utils.h"
#include "common/span.h"
#include "compiler/translator/BaseTypes.h"
#include "compiler/translator/Common.h"
#include "compiler/translator/Diagnostics.h"
#include "compiler/translator/ImmutableString.h"
#include "compiler/translator/ImmutableStringBuilder.h"
#include "compiler/translator/InfoSink.h"
#include "compiler/translator/IntermNode.h"
#include "compiler/translator/Operator_autogen.h"
#include "compiler/translator/OutputTree.h"
#include "compiler/translator/StaticType.h"
#include "compiler/translator/SymbolUniqueId.h"
#include "compiler/translator/Types.h"
#include "compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.h"
#include "compiler/translator/tree_ops/RewriteArrayOfArrayOfOpaqueUniforms.h"
#include "compiler/translator/tree_ops/RewriteStructSamplers.h"
#include "compiler/translator/tree_ops/SeparateStructFromUniformDeclarations.h"
#include "compiler/translator/tree_util/BuiltIn_autogen.h"
#include "compiler/translator/tree_util/FindMain.h"
#include "compiler/translator/tree_util/IntermNode_util.h"
#include "compiler/translator/tree_util/IntermTraverse.h"
#include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
#include "compiler/translator/wgsl/OutputUniformBlocks.h"
#include "compiler/translator/wgsl/RewritePipelineVariables.h"
#include "compiler/translator/wgsl/Utils.h"
39
40 namespace sh
41 {
42 namespace
43 {
44
// Debugging toggles, both off by default. Their uses are outside this excerpt; judging by the
// names, one outputs the AST before translation and the other outputs the translated shader.
constexpr bool kOutputTreeBeforeTranslation{false};
constexpr bool kOutputTranslatedShader{false};
47
48 struct VarDecl
49 {
50 const SymbolType symbolType = SymbolType::Empty;
51 const ImmutableString &symbolName;
52 const TType &type;
53 };
54
IsDefaultUniform(const TType & type)55 bool IsDefaultUniform(const TType &type)
56 {
57 return type.getQualifier() == EvqUniform && type.getInterfaceBlock() == nullptr &&
58 !IsOpaqueType(type.getBasicType());
59 }
60
61 // When emitting a list of statements, this determines whether a semicolon follows the statement.
RequiresSemicolonTerminator(TIntermNode & node)62 bool RequiresSemicolonTerminator(TIntermNode &node)
63 {
64 if (node.getAsBlock())
65 {
66 return false;
67 }
68 if (node.getAsLoopNode())
69 {
70 return false;
71 }
72 if (node.getAsSwitchNode())
73 {
74 return false;
75 }
76 if (node.getAsIfElseNode())
77 {
78 return false;
79 }
80 if (node.getAsFunctionDefinition())
81 {
82 return false;
83 }
84 if (node.getAsCaseNode())
85 {
86 return false;
87 }
88
89 return true;
90 }
91
92 // For pretty formatting of the resulting WGSL text.
NewlinePad(TIntermNode & node)93 bool NewlinePad(TIntermNode &node)
94 {
95 if (node.getAsFunctionDefinition())
96 {
97 return true;
98 }
99 if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
100 {
101 ASSERT(declNode->getChildCount() == 1);
102 TIntermNode &childNode = *declNode->getChildNode(0);
103 if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
104 {
105 const TVariable &var = symbolNode->variable();
106 return var.getType().isStructSpecifier();
107 }
108 return false;
109 }
110 return false;
111 }
112
113 // A traverser that generates WGSL as it walks the AST.
114 class OutputWGSLTraverser : public TIntermTraverser
115 {
116 public:
117 OutputWGSLTraverser(TInfoSinkBase *sink,
118 RewritePipelineVarOutput *rewritePipelineVarOutput,
119 UniformBlockMetadata *uniformBlockMetadata,
120 WGSLGenerationMetadataForUniforms *arrayElementTypesInUniforms);
121 ~OutputWGSLTraverser() override;
122
123 protected:
124 void visitSymbol(TIntermSymbol *node) override;
125 void visitConstantUnion(TIntermConstantUnion *node) override;
126 bool visitSwizzle(Visit visit, TIntermSwizzle *node) override;
127 bool visitBinary(Visit visit, TIntermBinary *node) override;
128 bool visitUnary(Visit visit, TIntermUnary *node) override;
129 bool visitTernary(Visit visit, TIntermTernary *node) override;
130 bool visitIfElse(Visit visit, TIntermIfElse *node) override;
131 bool visitSwitch(Visit visit, TIntermSwitch *node) override;
132 bool visitCase(Visit visit, TIntermCase *node) override;
133 void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
134 bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
135 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
136 bool visitBlock(Visit visit, TIntermBlock *node) override;
137 bool visitGlobalQualifierDeclaration(Visit visit,
138 TIntermGlobalQualifierDeclaration *node) override;
139 bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
140 bool visitLoop(Visit visit, TIntermLoop *node) override;
141 bool visitBranch(Visit visit, TIntermBranch *node) override;
142 void visitPreprocessorDirective(TIntermPreprocessorDirective *node) override;
143
144 private:
145 struct EmitVariableDeclarationConfig
146 {
147 EmitTypeConfig typeConfig;
148 bool isParameter = false;
149 bool disableStructSpecifier = false;
150 bool needsVar = false;
151 bool isGlobalScope = false;
152 };
153
154 void groupedTraverse(TIntermNode &node);
155 void emitNameOf(const VarDecl &decl);
156 void emitBareTypeName(const TType &type);
157 void emitType(const TType &type);
158 void emitSingleConstant(const TConstantUnion *const constUnion);
159 const TConstantUnion *emitConstantUnionArray(const TConstantUnion *const constUnion,
160 const size_t size);
161 const TConstantUnion *emitConstantUnion(const TType &type,
162 const TConstantUnion *constUnionBegin);
163 const TField &getDirectField(const TIntermTyped &fieldsNode, TIntermTyped &indexNode);
164 void emitIndentation();
165 void emitOpenBrace();
166 void emitCloseBrace();
167 bool emitBlock(angle::Span<TIntermNode *> nodes);
168 void emitFunctionSignature(const TFunction &func);
169 void emitFunctionReturn(const TFunction &func);
170 void emitFunctionParameter(const TFunction &func, const TVariable ¶m);
171 void emitStructDeclaration(const TType &type);
172 void emitVariableDeclaration(const VarDecl &decl,
173 const EmitVariableDeclarationConfig &evdConfig);
174 void emitArrayIndex(TIntermTyped &leftNode, TIntermTyped &rightNode);
175 void emitStructIndex(TIntermBinary *binaryNode);
176 void emitStructIndexNoUnwrapping(TIntermBinary *binaryNode);
177 void emitTextureBuiltin(const TOperator op, const TIntermSequence &args);
178
179 bool emitForLoop(TIntermLoop *);
180 bool emitWhileLoop(TIntermLoop *);
181 bool emulateDoWhileLoop(TIntermLoop *);
182
183 TInfoSinkBase &mSink;
184 const RewritePipelineVarOutput *mRewritePipelineVarOutput;
185 const UniformBlockMetadata *mUniformBlockMetadata;
186 WGSLGenerationMetadataForUniforms *mWGSLGenerationMetadataForUniforms;
187
188 int mIndentLevel = -1;
189 int mLastIndentationPos = -1;
190 };
191
OutputWGSLTraverser(TInfoSinkBase * sink,RewritePipelineVarOutput * rewritePipelineVarOutput,UniformBlockMetadata * uniformBlockMetadata,WGSLGenerationMetadataForUniforms * wgslGenerationMetadataForUniforms)192 OutputWGSLTraverser::OutputWGSLTraverser(
193 TInfoSinkBase *sink,
194 RewritePipelineVarOutput *rewritePipelineVarOutput,
195 UniformBlockMetadata *uniformBlockMetadata,
196 WGSLGenerationMetadataForUniforms *wgslGenerationMetadataForUniforms)
197 : TIntermTraverser(true, false, false),
198 mSink(*sink),
199 mRewritePipelineVarOutput(rewritePipelineVarOutput),
200 mUniformBlockMetadata(uniformBlockMetadata),
201 mWGSLGenerationMetadataForUniforms(wgslGenerationMetadataForUniforms)
202 {}
203
204 OutputWGSLTraverser::~OutputWGSLTraverser() = default;
205
groupedTraverse(TIntermNode & node)206 void OutputWGSLTraverser::groupedTraverse(TIntermNode &node)
207 {
208 // TODO(anglebug.com/42267100): to make generated code more readable, do not always
209 // emit parentheses like WGSL is some Lisp dialect.
210 const bool emitParens = true;
211
212 if (emitParens)
213 {
214 mSink << "(";
215 }
216
217 node.traverse(this);
218
219 if (emitParens)
220 {
221 mSink << ")";
222 }
223 }
224
emitNameOf(const VarDecl & decl)225 void OutputWGSLTraverser::emitNameOf(const VarDecl &decl)
226 {
227 WriteNameOf(mSink, decl.symbolType, decl.symbolName);
228 }
229
emitIndentation()230 void OutputWGSLTraverser::emitIndentation()
231 {
232 ASSERT(mIndentLevel >= 0);
233
234 if (mLastIndentationPos == mSink.size())
235 {
236 return; // Line is already indented.
237 }
238
239 for (int i = 0; i < mIndentLevel; ++i)
240 {
241 mSink << " ";
242 }
243
244 mLastIndentationPos = mSink.size();
245 }
246
emitOpenBrace()247 void OutputWGSLTraverser::emitOpenBrace()
248 {
249 ASSERT(mIndentLevel >= 0);
250
251 emitIndentation();
252 mSink << "{\n";
253 ++mIndentLevel;
254 }
255
emitCloseBrace()256 void OutputWGSLTraverser::emitCloseBrace()
257 {
258 ASSERT(mIndentLevel >= 1);
259
260 --mIndentLevel;
261 emitIndentation();
262 mSink << "}";
263 }
264
visitSymbol(TIntermSymbol * symbolNode)265 void OutputWGSLTraverser::visitSymbol(TIntermSymbol *symbolNode)
266 {
267
268 const TVariable &var = symbolNode->variable();
269 const TType &type = var.getType();
270 ASSERT(var.symbolType() != SymbolType::Empty);
271
272 if (type.getBasicType() == TBasicType::EbtVoid)
273 {
274 UNREACHABLE();
275 }
276 else
277 {
278 // Accesses of pipeline variables should be rewritten as struct accesses.
279 if (mRewritePipelineVarOutput->IsInputVar(var.uniqueId()))
280 {
281 mSink << kBuiltinInputStructName << "." << var.name();
282 }
283 else if (mRewritePipelineVarOutput->IsOutputVar(var.uniqueId()))
284 {
285 mSink << kBuiltinOutputStructName << "." << var.name();
286 }
287 // Accesses of basic uniforms need to be converted to struct accesses.
288 else if (IsDefaultUniform(type))
289 {
290 mSink << kDefaultUniformBlockVarName << "." << var.name();
291 }
292 else
293 {
294 WriteNameOf(mSink, var);
295 }
296
297 if (var.symbolType() == SymbolType::BuiltIn)
298 {
299 ASSERT(mRewritePipelineVarOutput->IsInputVar(var.uniqueId()) ||
300 mRewritePipelineVarOutput->IsOutputVar(var.uniqueId()) ||
301 var.uniqueId() == BuiltInId::gl_DepthRange);
302 // TODO(anglebug.com/42267100): support gl_DepthRange.
303 // Match the name of the struct field in `mRewritePipelineVarOutput`.
304 mSink << "_";
305 }
306 }
307 }
308
emitSingleConstant(const TConstantUnion * const constUnion)309 void OutputWGSLTraverser::emitSingleConstant(const TConstantUnion *const constUnion)
310 {
311 switch (constUnion->getType())
312 {
313 case TBasicType::EbtBool:
314 {
315 mSink << (constUnion->getBConst() ? "true" : "false");
316 }
317 break;
318
319 case TBasicType::EbtFloat:
320 {
321 float value = constUnion->getFConst();
322 if (std::isnan(value))
323 {
324 UNIMPLEMENTED();
325 // TODO(anglebug.com/42267100): this is not a valid constant in WGPU.
326 // You can't even do something like bitcast<f32>(0xffffffffu).
327 // The WGSL compiler still complains. I think this is because
328 // WGSL supports implementations compiling with -ffastmath and
329 // therefore nans and infinities are assumed to not exist.
330 // See also https://github.com/gpuweb/gpuweb/issues/3749.
331 mSink << "NAN_INVALID";
332 }
333 else if (std::isinf(value))
334 {
335 UNIMPLEMENTED();
336 // see above.
337 mSink << "INFINITY_INVALID";
338 }
339 else
340 {
341 mSink << value << "f";
342 }
343 }
344 break;
345
346 case TBasicType::EbtInt:
347 {
348 mSink << constUnion->getIConst() << "i";
349 }
350 break;
351
352 case TBasicType::EbtUInt:
353 {
354 mSink << constUnion->getUConst() << "u";
355 }
356 break;
357
358 default:
359 {
360 UNIMPLEMENTED();
361 }
362 }
363 }
364
// Emits `size` consecutive scalar constants starting at `constUnion`, separated by ", ".
// Returns a pointer one past the last constant consumed (equal to `constUnion` when size is 0).
const TConstantUnion *OutputWGSLTraverser::emitConstantUnionArray(
    const TConstantUnion *const constUnion,
    const size_t size)
{
    const TConstantUnion *const end = constUnion + size;
    for (const TConstantUnion *cursor = constUnion; cursor != end; ++cursor)
    {
        if (cursor != constUnion)
        {
            mSink << ", ";
        }
        emitSingleConstant(cursor);
    }
    return end;
}
381
// Emits the WGSL expression for a constant of type `type` whose flattened scalar values start at
// `constUnionBegin`. Returns a pointer just past the last scalar value consumed.
//
// - Arrays: a type constructor with one argument per outermost element, with each element
//   emitted recursively. WGSL array constructors take elements, not a flat scalar list.
// - Structs: a type constructor with one argument per field, emitted recursively.
// - Vectors/matrices: a type constructor applied to the flattened components.
// - Scalars: a bare literal.
const TConstantUnion *OutputWGSLTraverser::emitConstantUnion(const TType &type,
                                                             const TConstantUnion *constUnionBegin)
{
    const TConstantUnion *constUnionCurr = constUnionBegin;

    if (type.isArray())
    {
        // This must be handled before the struct case: an array-of-structs type also has a
        // non-null getStruct(). Emitting the flattened scalars here would produce e.g.
        // array<vec2<f32>, 2>(1, 2, 3, 4), which is not valid WGSL.
        emitType(type);
        mSink << "(";
        TType elementType(type);
        elementType.toArrayElementType();
        const unsigned int arraySize = type.getOutermostArraySize();
        for (unsigned int i = 0; i < arraySize; ++i)
        {
            if (i != 0)
            {
                mSink << ", ";
            }
            constUnionCurr = emitConstantUnion(elementType, constUnionCurr);
        }
        mSink << ")";
        return constUnionCurr;
    }

    const TStructure *structure = type.getStruct();
    if (structure)
    {
        emitType(type);
        // Structs are constructed with parentheses in WGSL.
        mSink << "(";
        // Emit the constructor parameters. Both GLSL and WGSL require there to be the same number
        // of parameters as struct fields.
        const TFieldList &fields = structure->fields();
        for (size_t i = 0; i < fields.size(); ++i)
        {
            const TType *fieldType = fields[i]->type();
            constUnionCurr         = emitConstantUnion(*fieldType, constUnionCurr);
            if (i != fields.size() - 1)
            {
                mSink << ", ";
            }
        }
        mSink << ")";
    }
    else
    {
        size_t size = type.getObjectSize();
        // Vectors and matrices (size > 1) need an explicit type constructor; scalars do not.
        bool writeType = size > 1;
        if (writeType)
        {
            emitType(type);
            mSink << "(";
        }
        constUnionCurr = emitConstantUnionArray(constUnionCurr, size);
        if (writeType)
        {
            mSink << ")";
        }
    }
    return constUnionCurr;
}
425
visitConstantUnion(TIntermConstantUnion * constValueNode)426 void OutputWGSLTraverser::visitConstantUnion(TIntermConstantUnion *constValueNode)
427 {
428 emitConstantUnion(constValueNode->getType(), constValueNode->getConstantValue());
429 }
430
visitSwizzle(Visit,TIntermSwizzle * swizzleNode)431 bool OutputWGSLTraverser::visitSwizzle(Visit, TIntermSwizzle *swizzleNode)
432 {
433 groupedTraverse(*swizzleNode->getOperand());
434 mSink << "." << swizzleNode->getOffsetsAsXYZW();
435
436 return false;
437 }
438
GetOperatorString(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1,const TType * argType2)439 const char *GetOperatorString(TOperator op,
440 const TType &resultType,
441 const TType *argType0,
442 const TType *argType1,
443 const TType *argType2)
444 {
445 switch (op)
446 {
447 case TOperator::EOpComma:
448 // WGSL does not have a comma operator or any other way to implement "statement list as
449 // an expression", so nested expressions will have to be pulled out into statements.
450 UNIMPLEMENTED();
451 return "TODO_operator";
452 case TOperator::EOpAssign:
453 return "=";
454 case TOperator::EOpInitialize:
455 return "=";
456 // Compound assignments now exist: https://www.w3.org/TR/WGSL/#compound-assignment-sec
457 case TOperator::EOpAddAssign:
458 return "+=";
459 case TOperator::EOpSubAssign:
460 return "-=";
461 case TOperator::EOpMulAssign:
462 return "*=";
463 case TOperator::EOpDivAssign:
464 return "/=";
465 case TOperator::EOpIModAssign:
466 return "%=";
467 case TOperator::EOpBitShiftLeftAssign:
468 return "<<=";
469 case TOperator::EOpBitShiftRightAssign:
470 return ">>=";
471 case TOperator::EOpBitwiseAndAssign:
472 return "&=";
473 case TOperator::EOpBitwiseXorAssign:
474 return "^=";
475 case TOperator::EOpBitwiseOrAssign:
476 return "|=";
477 case TOperator::EOpAdd:
478 return "+";
479 case TOperator::EOpSub:
480 return "-";
481 case TOperator::EOpMul:
482 return "*";
483 case TOperator::EOpDiv:
484 return "/";
485 // TODO(anglebug.com/42267100): Works different from GLSL for negative numbers.
486 // https://github.com/gpuweb/gpuweb/discussions/2204#:~:text=not%20WGSL%3B%20etc.-,Inconsistent%20mod/%25%20operator,-At%20first%20glance
487 // GLSL does `x - y * floor(x/y)`, WGSL does x - y * trunc(x/y).
488 case TOperator::EOpIMod:
489 case TOperator::EOpMod:
490 return "%";
491 case TOperator::EOpBitShiftLeft:
492 return "<<";
493 case TOperator::EOpBitShiftRight:
494 return ">>";
495 case TOperator::EOpBitwiseAnd:
496 return "&";
497 case TOperator::EOpBitwiseXor:
498 return "^";
499 case TOperator::EOpBitwiseOr:
500 return "|";
501 case TOperator::EOpLessThan:
502 return "<";
503 case TOperator::EOpGreaterThan:
504 return ">";
505 case TOperator::EOpLessThanEqual:
506 return "<=";
507 case TOperator::EOpGreaterThanEqual:
508 return ">=";
509 // Component-wise comparisons are done with regular infix operators in WGSL:
510 // https://www.w3.org/TR/WGSL/#comparison-expr
511 case TOperator::EOpLessThanComponentWise:
512 return "<";
513 case TOperator::EOpLessThanEqualComponentWise:
514 return "<=";
515 case TOperator::EOpGreaterThanEqualComponentWise:
516 return ">=";
517 case TOperator::EOpGreaterThanComponentWise:
518 return ">";
519 case TOperator::EOpLogicalOr:
520 return "||";
521 // Logical XOR is only applied to boolean expressions so it's the same as "not equals".
522 // Neither short-circuits.
523 case TOperator::EOpLogicalXor:
524 return "!=";
525 case TOperator::EOpLogicalAnd:
526 return "&&";
527 case TOperator::EOpNegative:
528 return "-";
529 case TOperator::EOpPositive:
530 if (argType0->isMatrix())
531 {
532 return "";
533 }
534 return "+";
535 case TOperator::EOpLogicalNot:
536 return "!";
537 // Component-wise not done with normal prefix unary operator in WGSL:
538 // https://www.w3.org/TR/WGSL/#logical-expr
539 case TOperator::EOpNotComponentWise:
540 return "!";
541 case TOperator::EOpBitwiseNot:
542 return "~";
543 // TODO(anglebug.com/42267100): increment operations cannot be used as expressions in WGSL.
544 case TOperator::EOpPostIncrement:
545 return "++";
546 case TOperator::EOpPostDecrement:
547 return "--";
548 case TOperator::EOpPreIncrement:
549 case TOperator::EOpPreDecrement:
550 // TODO(anglebug.com/42267100): pre increments and decrements do not exist in WGSL.
551 UNIMPLEMENTED();
552 return "TODO_operator";
553 case TOperator::EOpVectorTimesScalarAssign:
554 return "*=";
555 case TOperator::EOpVectorTimesMatrixAssign:
556 return "*=";
557 case TOperator::EOpMatrixTimesScalarAssign:
558 return "*=";
559 case TOperator::EOpMatrixTimesMatrixAssign:
560 return "*=";
561 case TOperator::EOpVectorTimesScalar:
562 return "*";
563 case TOperator::EOpVectorTimesMatrix:
564 return "*";
565 case TOperator::EOpMatrixTimesVector:
566 return "*";
567 case TOperator::EOpMatrixTimesScalar:
568 return "*";
569 case TOperator::EOpMatrixTimesMatrix:
570 return "*";
571 case TOperator::EOpEqualComponentWise:
572 return "==";
573 case TOperator::EOpNotEqualComponentWise:
574 return "!=";
575
576 // TODO(anglebug.com/42267100): structs, matrices, and arrays are not comparable with WGSL's
577 // == or !=. Comparing vectors results in a component-wise comparison returning a boolean
578 // vector, which is different from GLSL (which use equal(vec, vec) for component-wise
579 // comparison)
580 case TOperator::EOpEqual:
581 if ((argType0->isVector() && argType1->isVector()) ||
582 (argType0->getStruct() && argType1->getStruct()) ||
583 (argType0->isArray() && argType1->isArray()) ||
584 (argType0->isMatrix() && argType1->isMatrix()))
585
586 {
587 UNIMPLEMENTED();
588 return "TODO_operator";
589 }
590
591 return "==";
592
593 case TOperator::EOpNotEqual:
594 if ((argType0->isVector() && argType1->isVector()) ||
595 (argType0->getStruct() && argType1->getStruct()) ||
596 (argType0->isArray() && argType1->isArray()) ||
597 (argType0->isMatrix() && argType1->isMatrix()))
598 {
599 UNIMPLEMENTED();
600 return "TODO_operator";
601 }
602 return "!=";
603
604 case TOperator::EOpKill:
605 case TOperator::EOpReturn:
606 case TOperator::EOpBreak:
607 case TOperator::EOpContinue:
608 // These should all be emitted in visitBranch().
609 UNREACHABLE();
610 return "UNREACHABLE_operator";
611 case TOperator::EOpRadians:
612 return "radians";
613 case TOperator::EOpDegrees:
614 return "degrees";
615 case TOperator::EOpAtan:
616 return argType1 == nullptr ? "atan" : "atan2";
617 case TOperator::EOpRefract:
618 return argType0->isVector() ? "refract" : "TODO_operator";
619 case TOperator::EOpDistance:
620 return "distance";
621 case TOperator::EOpLength:
622 return "length";
623 case TOperator::EOpDot:
624 return argType0->isVector() ? "dot" : "*";
625 case TOperator::EOpNormalize:
626 return argType0->isVector() ? "normalize" : "sign";
627 case TOperator::EOpFaceforward:
628 return argType0->isVector() ? "faceForward" : "TODO_Operator";
629 case TOperator::EOpReflect:
630 return argType0->isVector() ? "reflect" : "TODO_Operator";
631 case TOperator::EOpMatrixCompMult:
632 return "TODO_Operator";
633 case TOperator::EOpOuterProduct:
634 return "TODO_Operator";
635 case TOperator::EOpSign:
636 return "sign";
637
638 case TOperator::EOpAbs:
639 return "abs";
640 case TOperator::EOpAll:
641 return "all";
642 case TOperator::EOpAny:
643 return "any";
644 case TOperator::EOpSin:
645 return "sin";
646 case TOperator::EOpCos:
647 return "cos";
648 case TOperator::EOpTan:
649 return "tan";
650 case TOperator::EOpAsin:
651 return "asin";
652 case TOperator::EOpAcos:
653 return "acos";
654 case TOperator::EOpSinh:
655 return "sinh";
656 case TOperator::EOpCosh:
657 return "cosh";
658 case TOperator::EOpTanh:
659 return "tanh";
660 case TOperator::EOpAsinh:
661 return "asinh";
662 case TOperator::EOpAcosh:
663 return "acosh";
664 case TOperator::EOpAtanh:
665 return "atanh";
666 case TOperator::EOpFma:
667 return "fma";
668 // TODO(anglebug.com/42267100): Won't accept pow(vec<f32>, f32).
669 // https://github.com/gpuweb/gpuweb/discussions/2204#:~:text=Similarly%20pow(vec3%3Cf32%3E%2C%20f32)%20works%20in%20GLSL%20but%20not%20WGSL
670 case TOperator::EOpPow:
671 return "pow"; // GLSL's pow excludes negative x
672 case TOperator::EOpExp:
673 return "exp";
674 case TOperator::EOpExp2:
675 return "exp2";
676 case TOperator::EOpLog:
677 return "log";
678 case TOperator::EOpLog2:
679 return "log2";
680 case TOperator::EOpSqrt:
681 return "sqrt";
682 case TOperator::EOpFloor:
683 return "floor";
684 case TOperator::EOpTrunc:
685 return "trunc";
686 case TOperator::EOpCeil:
687 return "ceil";
688 case TOperator::EOpFract:
689 return "fract";
690 case TOperator::EOpMin:
691 return "min";
692 case TOperator::EOpMax:
693 return "max";
694 case TOperator::EOpRound:
695 return "round"; // TODO(anglebug.com/42267100): this is wrong and must round away from
696 // zero if there is a tie. This always rounds to the even number.
697 case TOperator::EOpRoundEven:
698 return "round";
699 // TODO(anglebug.com/42267100):
700 // https://github.com/gpuweb/gpuweb/discussions/2204#:~:text=clamp(vec2%3Cf32%3E%2C%20f32%2C%20f32)%20works%20in%20GLSL%20but%20not%20WGSL%3B%20etc.
701 // Need to expand clamp(vec<f32>, low : f32, high : f32) ->
702 // clamp(vec<f32>, vec<f32>(low), vec<f32>(high))
703 case TOperator::EOpClamp:
704 return "clamp";
705 case TOperator::EOpSaturate:
706 return "saturate";
707 case TOperator::EOpMix:
708 if (!argType1->isScalar() && argType2 && argType2->getBasicType() == EbtBool)
709 {
710 return "TODO_Operator";
711 }
712 return "mix";
713 case TOperator::EOpStep:
714 return "step";
715 case TOperator::EOpSmoothstep:
716 return "smoothstep";
717 case TOperator::EOpModf:
718 UNIMPLEMENTED(); // TODO(anglebug.com/42267100): in WGSL this returns a struct, GLSL it
719 // uses a return value and an outparam
720 return "modf";
721 case TOperator::EOpIsnan:
722 case TOperator::EOpIsinf:
723 UNIMPLEMENTED(); // TODO(anglebug.com/42267100): WGSL does not allow NaNs or infinity.
724 // What to do about shaders that require this?
725 // Implementations are allowed to assume overflow, infinities, and NaNs are not present
726 // at runtime, however. https://www.w3.org/TR/WGSL/#floating-point-evaluation
727 return "TODO_Operator";
728 case TOperator::EOpLdexp:
729 // TODO(anglebug.com/42267100): won't accept first arg vector, second arg scalar
730 return "ldexp";
731 case TOperator::EOpFrexp:
732 return "frexp"; // TODO(anglebug.com/42267100): returns a struct
733 case TOperator::EOpInversesqrt:
734 return "inverseSqrt";
735 case TOperator::EOpCross:
736 return "cross";
737 // TODO(anglebug.com/42267100): are these the same? dpdxCoarse() vs dpdxFine()?
738 case TOperator::EOpDFdx:
739 return "dpdx";
740 case TOperator::EOpDFdy:
741 return "dpdy";
742 case TOperator::EOpFwidth:
743 return "fwidth";
744 case TOperator::EOpTranspose:
745 return "transpose";
746 case TOperator::EOpDeterminant:
747 return "determinant";
748
749 case TOperator::EOpInverse:
750 return "TODO_Operator"; // No builtin invert().
751 // https://github.com/gpuweb/gpuweb/issues/4115
752
753 // TODO(anglebug.com/42267100): these interpolateAt*() are not builtin
754 case TOperator::EOpInterpolateAtCentroid:
755 return "TODO_Operator";
756 case TOperator::EOpInterpolateAtSample:
757 return "TODO_Operator";
758 case TOperator::EOpInterpolateAtOffset:
759 return "TODO_Operator";
760 case TOperator::EOpInterpolateAtCenter:
761 return "TODO_Operator";
762
763 case TOperator::EOpFloatBitsToInt:
764 case TOperator::EOpFloatBitsToUint:
765 case TOperator::EOpIntBitsToFloat:
766 case TOperator::EOpUintBitsToFloat:
767 {
768 #define BITCAST_SCALAR() \
769 do \
770 switch (resultType.getBasicType()) \
771 { \
772 case TBasicType::EbtInt: \
773 return "bitcast<i32>"; \
774 case TBasicType::EbtUInt: \
775 return "bitcast<u32>"; \
776 case TBasicType::EbtFloat: \
777 return "bitcast<f32>"; \
778 default: \
779 UNIMPLEMENTED(); \
780 return "TOperator_TODO"; \
781 } \
782 while (false)
783
784 #define BITCAST_VECTOR(vecSize) \
785 do \
786 switch (resultType.getBasicType()) \
787 { \
788 case TBasicType::EbtInt: \
789 return "bitcast<vec" vecSize "<i32>>"; \
790 case TBasicType::EbtUInt: \
791 return "bitcast<vec" vecSize "<u32>>"; \
792 case TBasicType::EbtFloat: \
793 return "bitcast<vec" vecSize "<f32>>"; \
794 default: \
795 UNIMPLEMENTED(); \
796 return "TOperator_TODO"; \
797 } \
798 while (false)
799
800 if (resultType.isScalar())
801 {
802 BITCAST_SCALAR();
803 }
804 else if (resultType.isVector())
805 {
806 switch (resultType.getNominalSize())
807 {
808 case 2:
809 BITCAST_VECTOR("2");
810 case 3:
811 BITCAST_VECTOR("3");
812 case 4:
813 BITCAST_VECTOR("4");
814 default:
815 UNREACHABLE();
816 return nullptr;
817 }
818 }
819 else
820 {
821 UNIMPLEMENTED();
822 return "TOperator_TODO";
823 }
824
825 #undef BITCAST_SCALAR
826 #undef BITCAST_VECTOR
827 }
828
829 case TOperator::EOpPackUnorm2x16:
830 return "pack2x16unorm";
831 case TOperator::EOpPackSnorm2x16:
832 return "pack2x16snorm";
833
834 case TOperator::EOpPackUnorm4x8:
835 return "pack4x8unorm";
836 case TOperator::EOpPackSnorm4x8:
837 return "pack4x8snorm";
838
839 case TOperator::EOpUnpackUnorm2x16:
840 return "unpack2x16unorm";
841 case TOperator::EOpUnpackSnorm2x16:
842 return "unpack2x16snorm";
843
844 case TOperator::EOpUnpackUnorm4x8:
845 return "unpack4x8unorm";
846 case TOperator::EOpUnpackSnorm4x8:
847 return "unpack4x8snorm";
848
849 case TOperator::EOpPackHalf2x16:
850 return "pack2x16float";
851 case TOperator::EOpUnpackHalf2x16:
852 return "unpack2x16float";
853
854 case TOperator::EOpBarrier:
855 UNREACHABLE();
856 return "TOperator_TODO";
857 case TOperator::EOpMemoryBarrier:
858 // TODO(anglebug.com/42267100): does this exist in WGPU? Device-scoped memory barrier?
859 // Maybe storageBarrier()?
860 UNREACHABLE();
861 return "TOperator_TODO";
862 case TOperator::EOpGroupMemoryBarrier:
863 return "workgroupBarrier";
864 case TOperator::EOpMemoryBarrierAtomicCounter:
865 case TOperator::EOpMemoryBarrierBuffer:
866 case TOperator::EOpMemoryBarrierShared:
867 UNREACHABLE();
868 return "TOperator_TODO";
869 case TOperator::EOpAtomicAdd:
870 return "atomicAdd";
871 case TOperator::EOpAtomicMin:
872 return "atomicMin";
873 case TOperator::EOpAtomicMax:
874 return "atomicMax";
875 case TOperator::EOpAtomicAnd:
876 return "atomicAnd";
877 case TOperator::EOpAtomicOr:
878 return "atomicOr";
879 case TOperator::EOpAtomicXor:
880 return "atomicXor";
881 case TOperator::EOpAtomicExchange:
882 return "atomicExchange";
883 case TOperator::EOpAtomicCompSwap:
884 return "atomicCompareExchangeWeak"; // TODO(anglebug.com/42267100): returns a struct.
885 case TOperator::EOpBitfieldExtract:
886 case TOperator::EOpBitfieldInsert:
887 case TOperator::EOpBitfieldReverse:
888 case TOperator::EOpBitCount:
889 case TOperator::EOpFindLSB:
890 case TOperator::EOpFindMSB:
891 case TOperator::EOpUaddCarry:
892 case TOperator::EOpUsubBorrow:
893 case TOperator::EOpUmulExtended:
894 case TOperator::EOpImulExtended:
895 case TOperator::EOpEmitVertex:
896 case TOperator::EOpEndPrimitive:
897 case TOperator::EOpArrayLength:
898 UNIMPLEMENTED();
899 return "TOperator_TODO";
900
901 case TOperator::EOpNull:
902 case TOperator::EOpConstruct:
903 case TOperator::EOpCallFunctionInAST:
904 case TOperator::EOpCallInternalRawFunction:
905 case TOperator::EOpIndexDirect:
906 case TOperator::EOpIndexIndirect:
907 case TOperator::EOpIndexDirectStruct:
908 case TOperator::EOpIndexDirectInterfaceBlock:
909 UNREACHABLE();
910 return nullptr;
911 default:
912 // Any other built-in function.
913 return nullptr;
914 }
915 }
916
IsSymbolicOperator(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1)917 bool IsSymbolicOperator(TOperator op,
918 const TType &resultType,
919 const TType *argType0,
920 const TType *argType1)
921 {
922 const char *operatorString = GetOperatorString(op, resultType, argType0, argType1, nullptr);
923 if (operatorString == nullptr)
924 {
925 return false;
926 }
927 return !std::isalnum(operatorString[0]);
928 }
929
getDirectField(const TIntermTyped & fieldsNode,TIntermTyped & indexNode)930 const TField &OutputWGSLTraverser::getDirectField(const TIntermTyped &fieldsNode,
931 TIntermTyped &indexNode)
932 {
933 const TType &fieldsType = fieldsNode.getType();
934
935 const TFieldListCollection *fieldListCollection = fieldsType.getStruct();
936 if (fieldListCollection == nullptr)
937 {
938 fieldListCollection = fieldsType.getInterfaceBlock();
939 }
940 ASSERT(fieldListCollection);
941
942 const TIntermConstantUnion *indexNodeAsConstantUnion = indexNode.getAsConstantUnion();
943 ASSERT(indexNodeAsConstantUnion);
944 const TConstantUnion &index = *indexNodeAsConstantUnion->getConstantValue();
945
946 ASSERT(index.getType() == TBasicType::EbtInt);
947
948 const TFieldList &fieldList = fieldListCollection->fields();
949 const int indexVal = index.getIConst();
950 const TField &field = *fieldList[indexVal];
951
952 return field;
953 }
954
emitArrayIndex(TIntermTyped & leftNode,TIntermTyped & rightNode)955 void OutputWGSLTraverser::emitArrayIndex(TIntermTyped &leftNode, TIntermTyped &rightNode)
956 {
957 TType leftType = leftNode.getType();
958
959 // Some arrays within the uniform address space have their element types wrapped in a struct
960 // when generating WGSL, so this unwraps the element (as an optimization of converting the
961 // entire array back to the unwrapped type).
962 bool needsUnwrapping = false;
963 bool isUniformMatrixNeedingConversion = false;
964 TIntermBinary *leftNodeBinary = leftNode.getAsBinaryNode();
965 if (leftNodeBinary && leftNodeBinary->getOp() == TOperator::EOpIndexDirectStruct)
966 {
967 const TStructure *structure = leftNodeBinary->getLeft()->getType().getStruct();
968
969 bool isInUniformAddressSpace =
970 mUniformBlockMetadata->structsInUniformAddressSpace.count(structure->uniqueId().get());
971
972 needsUnwrapping =
973 structure && ElementTypeNeedsUniformWrapperStruct(isInUniformAddressSpace, &leftType);
974
975 isUniformMatrixNeedingConversion = isInUniformAddressSpace && IsMatCx2(&leftType);
976
977 ASSERT(!needsUnwrapping || !isUniformMatrixNeedingConversion);
978 }
979
980 // Emit the left side, which should be of type array.
981 if (needsUnwrapping || isUniformMatrixNeedingConversion)
982 {
983 if (isUniformMatrixNeedingConversion)
984 {
985 // If this array index expression is yielding an std140 matCx2 (i.e.
986 // array<ANGLE_wrapped_vec2, C>), just convert the entire expression to a WGSL matCx2,
987 // instead of converting the entire array of std140 matCx2s into an array of WGSL
988 // matCx2s and then indexing into it.
989 TType baseType = leftType;
990 baseType.toArrayBaseType();
991 mSink << MakeMatCx2ConversionFunctionName(&baseType) << "(";
992 // Make sure the conversion function referenced here is actually generated in the
993 // resulting WGSL.
994 mWGSLGenerationMetadataForUniforms->outputMatCx2Conversion.insert(baseType);
995 }
996 emitStructIndexNoUnwrapping(leftNodeBinary);
997 }
998 else
999 {
1000 groupedTraverse(leftNode);
1001 }
1002
1003 mSink << "[";
1004 const TConstantUnion *constIndex = rightNode.getConstantValue();
1005 // If the array index is a constant that we can statically verify is within array
1006 // bounds, just emit that constant.
1007 if (!leftType.isUnsizedArray() && constIndex != nullptr && constIndex->getType() == EbtInt &&
1008 constIndex->getIConst() >= 0 &&
1009 constIndex->getIConst() < static_cast<int>(leftType.isArray()
1010 ? leftType.getOutermostArraySize()
1011 : leftType.getNominalSize()))
1012 {
1013 emitSingleConstant(constIndex);
1014 }
1015 else
1016 {
1017 // If the array index is not a constant within the bounds of the array, clamp the
1018 // index.
1019 mSink << "clamp(";
1020 groupedTraverse(rightNode);
1021 mSink << ", 0, ";
1022 // Now find the array size and clamp it.
1023 if (leftType.isUnsizedArray())
1024 {
1025 // TODO(anglebug.com/42267100): This is a bug to traverse the `leftNode` a
1026 // second time if `leftNode` has side effects (and could also have performance
1027 // implications). This should be stored in a temporary variable. This might also
1028 // be a bug in the MSL shader compiler.
1029 mSink << "arrayLength(&";
1030 groupedTraverse(leftNode);
1031 mSink << ")";
1032 }
1033 else
1034 {
1035 uint32_t maxSize;
1036 if (leftType.isArray())
1037 {
1038 maxSize = leftType.getOutermostArraySize() - 1;
1039 }
1040 else
1041 {
1042 maxSize = leftType.getNominalSize() - 1;
1043 }
1044 mSink << maxSize;
1045 }
1046 // End the clamp() function.
1047 mSink << ")";
1048 }
1049 // End the array index operation.
1050 mSink << "]";
1051
1052 if (needsUnwrapping)
1053 {
1054 mSink << "." << kWrappedStructFieldName;
1055 }
1056 else if (isUniformMatrixNeedingConversion)
1057 {
1058 // Close conversion function call
1059 mSink << ")";
1060 }
1061 }
1062
emitStructIndex(TIntermBinary * binaryNode)1063 void OutputWGSLTraverser::emitStructIndex(TIntermBinary *binaryNode)
1064 {
1065 ASSERT(binaryNode->getOp() == TOperator::EOpIndexDirectStruct);
1066 TIntermTyped &leftNode = *binaryNode->getLeft();
1067 const TType *binaryNodeType = &binaryNode->getType();
1068
1069 const TStructure *structure = leftNode.getType().getStruct();
1070 ASSERT(structure);
1071
1072 bool isInUniformAddressSpace =
1073 mUniformBlockMetadata->structsInUniformAddressSpace.count(structure->uniqueId().get());
1074
1075 bool isUniformMatrixNeedingConversion = isInUniformAddressSpace && IsMatCx2(binaryNodeType);
1076
1077 bool needsUnwrapping =
1078 ElementTypeNeedsUniformWrapperStruct(isInUniformAddressSpace, binaryNodeType);
1079 if (needsUnwrapping)
1080 {
1081 ASSERT(!isUniformMatrixNeedingConversion);
1082
1083 mSink << MakeUnwrappingArrayConversionFunctionName(&binaryNode->getType()) << "(";
1084 // Make sure the conversion function referenced here is actually generated in the resulting
1085 // WGSL.
1086 mWGSLGenerationMetadataForUniforms->arrayElementTypesThatNeedUnwrappingConversions.insert(
1087 *binaryNodeType);
1088 }
1089 else if (isUniformMatrixNeedingConversion)
1090 {
1091 mSink << MakeMatCx2ConversionFunctionName(binaryNodeType) << "(";
1092 // Make sure the conversion function referenced here is actually generated in the resulting
1093 // WGSL.
1094 mWGSLGenerationMetadataForUniforms->outputMatCx2Conversion.insert(*binaryNodeType);
1095 }
1096 emitStructIndexNoUnwrapping(binaryNode);
1097 if (needsUnwrapping || isUniformMatrixNeedingConversion)
1098 {
1099 mSink << ")";
1100 }
1101 }
1102
emitStructIndexNoUnwrapping(TIntermBinary * binaryNode)1103 void OutputWGSLTraverser::emitStructIndexNoUnwrapping(TIntermBinary *binaryNode)
1104 {
1105 ASSERT(binaryNode->getOp() == TOperator::EOpIndexDirectStruct);
1106 TIntermTyped &leftNode = *binaryNode->getLeft();
1107 TIntermTyped &rightNode = *binaryNode->getRight();
1108
1109 groupedTraverse(leftNode);
1110 mSink << ".";
1111 WriteNameOf(mSink, getDirectField(leftNode, rightNode));
1112 }
1113
visitBinary(Visit,TIntermBinary * binaryNode)1114 bool OutputWGSLTraverser::visitBinary(Visit, TIntermBinary *binaryNode)
1115 {
1116 const TOperator op = binaryNode->getOp();
1117 TIntermTyped &leftNode = *binaryNode->getLeft();
1118 TIntermTyped &rightNode = *binaryNode->getRight();
1119
1120 switch (op)
1121 {
1122 case TOperator::EOpIndexDirectStruct:
1123 case TOperator::EOpIndexDirectInterfaceBlock:
1124 emitStructIndex(binaryNode);
1125 break;
1126
1127 case TOperator::EOpIndexDirect:
1128 case TOperator::EOpIndexIndirect:
1129 emitArrayIndex(leftNode, rightNode);
1130 break;
1131
1132 default:
1133 {
1134 const TType &resultType = binaryNode->getType();
1135 const TType &leftType = leftNode.getType();
1136 const TType &rightType = rightNode.getType();
1137
1138 // x * y, x ^ y, etc.
1139 if (IsSymbolicOperator(op, resultType, &leftType, &rightType))
1140 {
1141 groupedTraverse(leftNode);
1142 if (op != TOperator::EOpComma)
1143 {
1144 mSink << " ";
1145 }
1146 mSink << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << " ";
1147 groupedTraverse(rightNode);
1148 }
1149 // E.g. builtin function calls
1150 else
1151 {
1152 mSink << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << "(";
1153 leftNode.traverse(this);
1154 mSink << ", ";
1155 rightNode.traverse(this);
1156 mSink << ")";
1157 }
1158 }
1159 }
1160
1161 return false;
1162 }
1163
IsPostfix(TOperator op)1164 bool IsPostfix(TOperator op)
1165 {
1166 switch (op)
1167 {
1168 case TOperator::EOpPostIncrement:
1169 case TOperator::EOpPostDecrement:
1170 return true;
1171
1172 default:
1173 return false;
1174 }
1175 }
1176
visitUnary(Visit,TIntermUnary * unaryNode)1177 bool OutputWGSLTraverser::visitUnary(Visit, TIntermUnary *unaryNode)
1178 {
1179 const TOperator op = unaryNode->getOp();
1180 const TType &resultType = unaryNode->getType();
1181
1182 TIntermTyped &arg = *unaryNode->getOperand();
1183 const TType &argType = arg.getType();
1184
1185 const char *name = GetOperatorString(op, resultType, &argType, nullptr, nullptr);
1186
1187 // Examples: -x, ~x, ~x
1188 if (IsSymbolicOperator(op, resultType, &argType, nullptr))
1189 {
1190 const bool postfix = IsPostfix(op);
1191 if (!postfix)
1192 {
1193 mSink << name;
1194 }
1195 groupedTraverse(arg);
1196 if (postfix)
1197 {
1198 mSink << name;
1199 }
1200 }
1201 else
1202 {
1203 mSink << name << "(";
1204 arg.traverse(this);
1205 mSink << ")";
1206 }
1207
1208 return false;
1209 }
1210
visitTernary(Visit,TIntermTernary * conditionalNode)1211 bool OutputWGSLTraverser::visitTernary(Visit, TIntermTernary *conditionalNode)
1212 {
1213 // WGSL does not have a ternary. https://github.com/gpuweb/gpuweb/issues/3747
1214 // The select() builtin is not short circuiting. Maybe we can get if () {} else {} as an
1215 // expression, which would also solve the comma operator problem.
1216 // TODO(anglebug.com/42267100): as mentioned above this is not correct if the operands have side
1217 // effects. Even if they don't have side effects it could have performance implications.
1218 // It also doesn't work with all types that ternaries do, e.g. arrays or structs.
1219 mSink << "select(";
1220 groupedTraverse(*conditionalNode->getFalseExpression());
1221 mSink << ", ";
1222 groupedTraverse(*conditionalNode->getTrueExpression());
1223 mSink << ", ";
1224 groupedTraverse(*conditionalNode->getCondition());
1225 mSink << ")";
1226
1227 return false;
1228 }
1229
visitIfElse(Visit,TIntermIfElse * ifThenElseNode)1230 bool OutputWGSLTraverser::visitIfElse(Visit, TIntermIfElse *ifThenElseNode)
1231 {
1232 TIntermTyped &condNode = *ifThenElseNode->getCondition();
1233 TIntermBlock *thenNode = ifThenElseNode->getTrueBlock();
1234 TIntermBlock *elseNode = ifThenElseNode->getFalseBlock();
1235
1236 mSink << "if (";
1237 condNode.traverse(this);
1238 mSink << ")";
1239
1240 if (thenNode)
1241 {
1242 mSink << "\n";
1243 thenNode->traverse(this);
1244 }
1245 else
1246 {
1247 mSink << " {}";
1248 }
1249
1250 if (elseNode)
1251 {
1252 mSink << "\n";
1253 emitIndentation();
1254 mSink << "else\n";
1255 elseNode->traverse(this);
1256 }
1257
1258 return false;
1259 }
1260
visitSwitch(Visit,TIntermSwitch * switchNode)1261 bool OutputWGSLTraverser::visitSwitch(Visit, TIntermSwitch *switchNode)
1262 {
1263 TIntermBlock &stmtList = *switchNode->getStatementList();
1264
1265 emitIndentation();
1266 mSink << "switch ";
1267 switchNode->getInit()->traverse(this);
1268 mSink << "\n";
1269
1270 emitOpenBrace();
1271
1272 // TODO(anglebug.com/42267100): Case statements that fall through need to combined into a single
1273 // case statement with multiple labels.
1274
1275 const size_t stmtCount = stmtList.getChildCount();
1276 bool inCaseList = false;
1277 size_t currStmt = 0;
1278 while (currStmt < stmtCount)
1279 {
1280 TIntermNode &stmtNode = *stmtList.getChildNode(currStmt);
1281 TIntermCase *caseNode = stmtNode.getAsCaseNode();
1282 if (caseNode)
1283 {
1284 if (inCaseList)
1285 {
1286 mSink << ", ";
1287 }
1288 else
1289 {
1290 emitIndentation();
1291 mSink << "case ";
1292 inCaseList = true;
1293 }
1294 caseNode->traverse(this);
1295
1296 // Process the next statement.
1297 currStmt++;
1298 }
1299 else
1300 {
1301 // The current statement is not a case statement, end the current case list and emit all
1302 // the code until the next case statement. WGSL requires braces around the case
1303 // statement's code.
1304 ASSERT(inCaseList);
1305 inCaseList = false;
1306 mSink << ":\n";
1307
1308 // Count the statements until the next case (or the end of the switch) and emit them as
1309 // a block. This assumes that the current statement list will never fallthrough to the
1310 // next case statement.
1311 size_t nextCaseStmt = currStmt + 1;
1312 for (;
1313 nextCaseStmt < stmtCount && !stmtList.getChildNode(nextCaseStmt)->getAsCaseNode();
1314 nextCaseStmt++)
1315 {
1316 }
1317 angle::Span<TIntermNode *> stmtListView(&stmtList.getSequence()->at(currStmt),
1318 nextCaseStmt - currStmt);
1319 emitBlock(stmtListView);
1320 mSink << "\n";
1321
1322 // Skip to the next case statement.
1323 currStmt = nextCaseStmt;
1324 }
1325 }
1326
1327 emitCloseBrace();
1328
1329 return false;
1330 }
1331
visitCase(Visit,TIntermCase * caseNode)1332 bool OutputWGSLTraverser::visitCase(Visit, TIntermCase *caseNode)
1333 {
1334 // "case" will have been emitted in the visitSwitch() override.
1335
1336 if (caseNode->hasCondition())
1337 {
1338 TIntermTyped *condExpr = caseNode->getCondition();
1339 condExpr->traverse(this);
1340 }
1341 else
1342 {
1343 mSink << "default";
1344 }
1345
1346 return false;
1347 }
1348
emitFunctionReturn(const TFunction & func)1349 void OutputWGSLTraverser::emitFunctionReturn(const TFunction &func)
1350 {
1351 const TType &returnType = func.getReturnType();
1352 if (returnType.getBasicType() == EbtVoid)
1353 {
1354 return;
1355 }
1356 mSink << " -> ";
1357 emitType(returnType);
1358 }
1359
1360 // TODO(anglebug.com/42267100): Function overloads are not supported in WGSL, so function names
1361 // should either be emitted mangled or overloaded functions should be renamed in the AST as a
1362 // pre-pass. As of Apr 2024, WGSL function overloads are "not coming soon"
1363 // (https://github.com/gpuweb/gpuweb/issues/876).
emitFunctionSignature(const TFunction & func)1364 void OutputWGSLTraverser::emitFunctionSignature(const TFunction &func)
1365 {
1366 mSink << "fn ";
1367
1368 WriteNameOf(mSink, func);
1369 mSink << "(";
1370
1371 bool emitComma = false;
1372 const size_t paramCount = func.getParamCount();
1373 for (size_t i = 0; i < paramCount; ++i)
1374 {
1375 if (emitComma)
1376 {
1377 mSink << ", ";
1378 }
1379 emitComma = true;
1380
1381 const TVariable ¶m = *func.getParam(i);
1382 emitFunctionParameter(func, param);
1383 }
1384
1385 mSink << ")";
1386
1387 emitFunctionReturn(func);
1388 }
1389
emitFunctionParameter(const TFunction & func,const TVariable & param)1390 void OutputWGSLTraverser::emitFunctionParameter(const TFunction &func, const TVariable ¶m)
1391 {
1392 // TODO(anglebug.com/42267100): function parameters are immutable and will need to be renamed if
1393 // they are mutated.
1394 EmitVariableDeclarationConfig evdConfig;
1395 evdConfig.isParameter = true;
1396 emitVariableDeclaration({param.symbolType(), param.name(), param.getType()}, evdConfig);
1397 }
1398
visitFunctionPrototype(TIntermFunctionPrototype * funcProtoNode)1399 void OutputWGSLTraverser::visitFunctionPrototype(TIntermFunctionPrototype *funcProtoNode)
1400 {
1401 const TFunction &func = *funcProtoNode->getFunction();
1402
1403 emitIndentation();
1404 // TODO(anglebug.com/42267100): output correct signature for main() if main() is declared as a
1405 // function prototype, or perhaps just emit nothing.
1406 emitFunctionSignature(func);
1407 }
1408
visitFunctionDefinition(Visit,TIntermFunctionDefinition * funcDefNode)1409 bool OutputWGSLTraverser::visitFunctionDefinition(Visit, TIntermFunctionDefinition *funcDefNode)
1410 {
1411 const TFunction &func = *funcDefNode->getFunction();
1412 TIntermBlock &body = *funcDefNode->getBody();
1413
1414 emitIndentation();
1415 emitFunctionSignature(func);
1416 mSink << "\n";
1417 body.traverse(this);
1418
1419 return false;
1420 }
1421
emitTextureBuiltin(const TOperator op,const TIntermSequence & args)1422 void OutputWGSLTraverser::emitTextureBuiltin(const TOperator op, const TIntermSequence &args)
1423 {
1424
1425 ASSERT(BuiltInGroup::IsTexture(op));
1426
1427 // The index in the GLSL function's argument list of each particular argument, e.g. bias.
1428 // `bias`, `lod`, `offset`, and `P` (the coordinates) are the common arguments to most texture
1429 // functions.
1430 size_t biasIndex = 0;
1431 size_t lodIndex = 0;
1432 size_t offsetIndex = 0;
1433 size_t pIndex = 0;
1434
1435 size_t dpdxIndex = 0;
1436 size_t dpdyIndex = 0;
1437
1438 // TODO(anglebug.com/389145696): These are probably incorrect translations when sampling from
1439 // integer or unsigned integer samplers. Using texture() with a usampler
1440 // is similar to using texelFetch(), except wrap modes are respected. Possibly, the correct mip
1441 // levels are also selected.
1442
1443 // The name of the equivalent texture function in WGSL.
1444 ImmutableString wgslFunctionName("");
1445 // GLSL stuffs 1, 2, or 3 arguments into a single vector. These represent the swizzles necessary
1446 // for extracting each argument from P to pass to the appropriate WGSL function.
1447 ImmutableString coordsSwizzle("");
1448 ImmutableString arrayIndexSwizzle("");
1449 ImmutableString depthRefSwizzle("");
1450 // For the projection forms of the texture builtins, the last coordinate will divide the other
1451 // three. This is just a swizzle for the last coordinate if the builtin call includes
1452 // projection.
1453 ImmutableString projectionDivisionSwizzle("");
1454
1455 ImmutableString wgslTextureVarName("");
1456 ImmutableString wgslSamplerVarName("");
1457
1458 constexpr char k2DCoordsSwizzle[] = ".xy";
1459 constexpr char k3DCoordsSwizzle[] = ".xyz";
1460
1461 constexpr char kPossibleElems[] = "xyzw";
1462
1463 // MonomorphizeUnsupportedFunctions() and RewriteStructSamplers() ensure that this is a
1464 // reference to the global sampler.
1465 TIntermSymbol *samplerNode = args[0]->getAsSymbolNode();
1466 // TODO(anglebug.com/389145696): this will fail if it's an array of samplers, which isn't yet
1467 // handled.
1468 if (!samplerNode)
1469 {
1470 UNIMPLEMENTED();
1471 mSink << "TODO_UNHANDLED_TEXTURE_FUNCTION()";
1472 return;
1473 }
1474 TBasicType samplerType = samplerNode->getType().getBasicType();
1475 ASSERT(IsSampler(samplerType));
1476
1477 bool isProj = false;
1478
1479 auto setWgslTextureVarName = [&]() {
1480 wgslTextureVarName =
1481 BuildConcatenatedImmutableString(kAngleTexturePrefix, samplerNode->getName());
1482 };
1483
1484 auto setWgslSamplerVarName = [&]() {
1485 wgslSamplerVarName =
1486 BuildConcatenatedImmutableString(kAngleSamplerPrefix, samplerNode->getName());
1487 };
1488
1489 auto setTextureSampleFunctionNameFromBias = [&]() {
1490 // TODO(anglebug.com/389145696): these are incorrect translations in vertex shaders, where
1491 // they should probably use textureLoad() (and textureDimensions()).
1492 if (IsShadowSampler(samplerType))
1493 {
1494 if (biasIndex != 0)
1495 {
1496 // TODO(anglebug.com/389145696): WGSL doesn't support using bias with shadow
1497 // samplers.
1498 UNIMPLEMENTED();
1499 wgslFunctionName = ImmutableString("TODO_CANNOT_USE_BIAS_WITH_SHADOW_SAMPLER");
1500 }
1501 else
1502 {
1503 wgslFunctionName = ImmutableString("textureSampleCompare");
1504 }
1505 }
1506 else
1507 {
1508 if (biasIndex == 0)
1509 {
1510 wgslFunctionName = ImmutableString("textureSample");
1511 }
1512 else
1513 {
1514
1515 wgslFunctionName = ImmutableString("textureSampleBias");
1516 }
1517 }
1518 };
1519
1520 switch (op)
1521 {
1522 case EOpTextureSize:
1523 {
1524 lodIndex = 1;
1525 ASSERT(args.size() == 2);
1526
1527 wgslFunctionName = ImmutableString("textureDimensions");
1528 setWgslTextureVarName();
1529 }
1530 break;
1531
1532 case EOpTexelFetchOffset:
1533 case EOpTexelFetch:
1534 {
1535 pIndex = 1;
1536 lodIndex = 2;
1537 if (args.size() == 4)
1538 {
1539 offsetIndex = 3;
1540 }
1541 ASSERT(args.size() == 3 || args.size() == 4);
1542
1543 wgslFunctionName = ImmutableString("textureLoad");
1544 setWgslTextureVarName();
1545 }
1546 break;
1547
1548 // texture() use to be split into texture2D() and textureCube(). WGSL matches GLSL 3.0 and
1549 // combines them.
1550 case EOpTextureProj:
1551 case EOpTexture2DProj:
1552 case EOpTextureProjBias:
1553 case EOpTexture2DProjBias:
1554 isProj = true;
1555 [[fallthrough]];
1556 case EOpTexture:
1557 case EOpTexture2D:
1558 case EOpTextureCube:
1559 case EOpTextureBias:
1560 case EOpTexture2DBias:
1561 case EOpTextureCubeBias:
1562 {
1563 pIndex = 1;
1564 if (args.size() == 3)
1565 {
1566 biasIndex = 2;
1567 }
1568 ASSERT(args.size() == 2 || args.size() == 3);
1569
1570 setTextureSampleFunctionNameFromBias();
1571 setWgslTextureVarName();
1572 setWgslSamplerVarName();
1573 }
1574 break;
1575
1576 case EOpTextureProjLod:
1577 case EOpTextureProjLodOffset:
1578 case EOpTexture2DProjLodVS:
1579 case EOpTexture2DProjLodEXTFS:
1580 isProj = true;
1581 [[fallthrough]];
1582 case EOpTextureLod:
1583 case EOpTexture2DLodVS:
1584 case EOpTextureCubeLodVS:
1585 case EOpTexture2DLodEXTFS:
1586 case EOpTextureCubeLodEXTFS:
1587 case EOpTextureLodOffset:
1588 {
1589 pIndex = 1;
1590 lodIndex = 2;
1591 if (args.size() == 4)
1592 {
1593 offsetIndex = 3;
1594 }
1595 ASSERT(args.size() == 3 || args.size() == 4);
1596
1597 if (IsShadowSampler(samplerType))
1598 {
1599 // TODO(anglebug.com/389145696): WGSL may not support explicit LOD with shadow
1600 // samplers. textureSampleCompareLevel() only uses mip level 0.
1601 UNIMPLEMENTED();
1602 wgslFunctionName =
1603 ImmutableString("TODO_CANNOT_USE_EXPLICIT_LOD_WITH_SHADOW_SAMPLER");
1604 }
1605 else
1606 {
1607 wgslFunctionName = ImmutableString("textureSampleLevel");
1608 }
1609 setWgslTextureVarName();
1610 setWgslSamplerVarName();
1611 }
1612 break;
1613
1614 case EOpTextureProjOffset:
1615 case EOpTextureProjOffsetBias:
1616 isProj = true;
1617 [[fallthrough]];
1618 case EOpTextureOffset:
1619 case EOpTextureOffsetBias:
1620 {
1621 pIndex = 1;
1622 offsetIndex = 2;
1623 if (args.size() == 4)
1624 {
1625 biasIndex = 3;
1626 }
1627 ASSERT(args.size() == 3 || args.size() == 4);
1628
1629 setTextureSampleFunctionNameFromBias();
1630 setWgslTextureVarName();
1631 setWgslSamplerVarName();
1632 }
1633 break;
1634
1635 case EOpTextureProjGrad:
1636 case EOpTextureProjGradOffset:
1637 isProj = true;
1638 [[fallthrough]];
1639 case EOpTextureGrad:
1640 case EOpTextureGradOffset:
1641 {
1642 pIndex = 1;
1643 dpdxIndex = 2;
1644 dpdyIndex = 3;
1645 if (args.size() == 5)
1646 {
1647 offsetIndex = 4;
1648 }
1649 ASSERT(args.size() == 4 || args.size() == 5);
1650
1651 if (IsShadowSampler(samplerType))
1652 {
1653 // TODO(anglebug.com/389145696): WGSL may not support explicit gradients with shadow
1654 // samplers.
1655 UNIMPLEMENTED();
1656 wgslFunctionName =
1657 ImmutableString("TODO_CANNOT_USE_EXPLICIT_GRAD_WITH_SHADOW_SAMPLER");
1658 }
1659 else
1660 {
1661 wgslFunctionName = ImmutableString("textureSampleGrad");
1662 }
1663 setWgslTextureVarName();
1664 setWgslSamplerVarName();
1665 }
1666 break;
1667
1668 default:
1669 UNIMPLEMENTED();
1670 mSink << "TODO_UNHANDLED_TEXTURE_FUNCTION()";
1671 return;
1672 }
1673
1674 mSink << wgslFunctionName << "(";
1675
1676 ASSERT(!wgslTextureVarName.empty());
1677 mSink << wgslTextureVarName;
1678
1679 if (!wgslSamplerVarName.empty())
1680 {
1681 mSink << ", " << wgslSamplerVarName;
1682
1683 // If using a projection division, set the swizzle that extracts the last argument from the
1684 // p vector.
1685 if (isProj)
1686 {
1687 ASSERT(pIndex == 1);
1688 const uint8_t vecSize = args[pIndex]->getAsTyped()->getNominalSize();
1689 ASSERT(vecSize == 3 || vecSize == 4);
1690 projectionDivisionSwizzle =
1691 BuildConcatenatedImmutableString('.', kPossibleElems[vecSize - 1]);
1692 }
1693
1694 // If sampling from an array, set the swizzle that extracts the array layer number from the
1695 // p vector.
1696 if (IsSampler2DArray(samplerType))
1697 {
1698 arrayIndexSwizzle = ImmutableString(".z");
1699 }
1700
1701 // If sampling from a shadow samplers, set the swizzle that extracts the D_ref argument from
1702 // the p vector.
1703 if (IsShadowSampler(samplerType))
1704 {
1705 size_t elemIndex = 0;
1706 if (IsSampler2D(samplerType))
1707 {
1708 elemIndex = 2;
1709 }
1710 else if (IsSampler2DArray(samplerType) || IsSampler3D(samplerType) ||
1711 IsSamplerCube(samplerType))
1712 {
1713 elemIndex = 3;
1714 }
1715
1716 depthRefSwizzle = BuildConcatenatedImmutableString('.', kPossibleElems[elemIndex]);
1717 }
1718
1719 // Finally, set the swizzle for extracting coordinates from the p vector.
1720 if (IsSampler2D(samplerType) || IsSampler2DArray(samplerType))
1721 {
1722 coordsSwizzle = ImmutableString(k2DCoordsSwizzle);
1723 }
1724 else if (IsSampler3D(samplerType) || IsSamplerCube(samplerType))
1725 {
1726 coordsSwizzle = ImmutableString(k3DCoordsSwizzle);
1727 }
1728 }
1729
1730 // TODO(anglebug.com/389145696): traversing the pArg multiple times is an error if it ever
1731 // contains side effects (e.g. a function call). There is also a problem if this traverses
1732 // function arguments in a different order, arguments with side effects that effect arguments
1733 // that come later may be reordered incorrectly. ESSL specs defined function argument evaluation
1734 // as left-to-right.
1735 auto traversePArg = [&]() {
1736 mSink << "(";
1737 ASSERT(pIndex != 0);
1738 args[pIndex]->traverse(this);
1739 mSink << ")";
1740 };
1741
1742 auto outputProjectionDivisionIfNecessary = [&]() {
1743 if (projectionDivisionSwizzle.empty())
1744 {
1745 return;
1746 }
1747 mSink << " / ";
1748 traversePArg();
1749 mSink << projectionDivisionSwizzle;
1750 };
1751
1752 // The arguments to the WGSL function always appear in a certain (partial) order, so output them
1753 // in that order.
1754 //
1755 // The order is always
1756 // - texture
1757 // - sampler
1758 // - coordinates
1759 // - array layer index
1760 // - depth_ref, bias, explicit level of detail (never appear together)
1761 // - dfdx
1762 // - dfdy
1763 // - offset
1764 //
1765 // See the texture builtin functions in the WGSL spec:
1766 // https://www.w3.org/TR/WGSL/#texture-builtin-functions
1767 //
1768 // For example
1769 // @must_use fn textureSampleLevel(t: texture_2d_array<f32>,
1770 // s: sampler,
1771 // coords: vec2<f32>,
1772 // array_index: A,
1773 // level: f32,
1774 // offset: vec2<i32>) -> vec4<f32>
1775
1776 if (pIndex != 0)
1777 {
1778 mSink << ", ";
1779 traversePArg();
1780 mSink << coordsSwizzle;
1781 outputProjectionDivisionIfNecessary();
1782 }
1783
1784 if (!arrayIndexSwizzle.empty())
1785 {
1786 mSink << ", ";
1787 traversePArg();
1788 mSink << arrayIndexSwizzle;
1789 }
1790
1791 if (!depthRefSwizzle.empty())
1792 {
1793 mSink << ", ";
1794 traversePArg();
1795 mSink << depthRefSwizzle;
1796 outputProjectionDivisionIfNecessary();
1797 }
1798
1799 if (biasIndex != 0)
1800 {
1801 mSink << ", ";
1802 args[biasIndex]->traverse(this);
1803 }
1804
1805 if (lodIndex != 0)
1806 {
1807 mSink << ", ";
1808 args[lodIndex]->traverse(this);
1809 }
1810
1811 if (dpdxIndex != 0)
1812 {
1813 mSink << ", ";
1814 args[dpdxIndex]->traverse(this);
1815 }
1816
1817 if (dpdyIndex != 0)
1818 {
1819 mSink << ", ";
1820 args[dpdyIndex]->traverse(this);
1821 }
1822
1823 if (offsetIndex != 0)
1824 {
1825 mSink << ", ";
1826 // Both GLSL and WGSL require this to be a const expression.
1827 args[offsetIndex]->traverse(this);
1828 }
1829
1830 mSink << ")";
1831 }
1832
visitAggregate(Visit,TIntermAggregate * aggregateNode)1833 bool OutputWGSLTraverser::visitAggregate(Visit, TIntermAggregate *aggregateNode)
1834 {
1835 const TIntermSequence &args = *aggregateNode->getSequence();
1836
1837 auto emitArgList = [&]() {
1838 mSink << "(";
1839
1840 bool emitComma = false;
1841 for (TIntermNode *arg : args)
1842 {
1843 if (emitComma)
1844 {
1845 mSink << ", ";
1846 }
1847 emitComma = true;
1848 arg->traverse(this);
1849 }
1850
1851 mSink << ")";
1852 };
1853
1854 const TType &retType = aggregateNode->getType();
1855
1856 if (aggregateNode->isConstructor())
1857 {
1858
1859 emitType(retType);
1860 emitArgList();
1861
1862 return false;
1863 }
1864 else
1865 {
1866 const TOperator op = aggregateNode->getOp();
1867 switch (op)
1868 {
1869 case TOperator::EOpCallFunctionInAST:
1870 WriteNameOf(mSink, *aggregateNode->getFunction());
1871 emitArgList();
1872 return false;
1873
1874 default:
1875 // Do not allow raw function calls, i.e. calls to functions
1876 // not present in the AST.
1877 ASSERT(op != TOperator::EOpCallInternalRawFunction);
1878 auto getArgType = [&](size_t index) -> const TType * {
1879 if (index < args.size())
1880 {
1881 TIntermTyped *arg = args[index]->getAsTyped();
1882 ASSERT(arg);
1883 return &arg->getType();
1884 }
1885 return nullptr;
1886 };
1887
1888 const TType *argType0 = getArgType(0);
1889 const TType *argType1 = getArgType(1);
1890 const TType *argType2 = getArgType(2);
1891
1892 const char *opName = GetOperatorString(op, retType, argType0, argType1, argType2);
1893
1894 if (IsSymbolicOperator(op, retType, argType0, argType1))
1895 {
1896 switch (args.size())
1897 {
1898 case 1:
1899 {
1900 TIntermNode &operandNode = *aggregateNode->getChildNode(0);
1901 if (IsPostfix(op))
1902 {
1903 mSink << opName;
1904 groupedTraverse(operandNode);
1905 }
1906 else
1907 {
1908 groupedTraverse(operandNode);
1909 mSink << opName;
1910 }
1911 return false;
1912 }
1913
1914 case 2:
1915 {
1916 // symbolic operators with 2 args are emitted with infix notation.
1917 TIntermNode &leftNode = *aggregateNode->getChildNode(0);
1918 TIntermNode &rightNode = *aggregateNode->getChildNode(1);
1919 groupedTraverse(leftNode);
1920 mSink << " " << opName << " ";
1921 groupedTraverse(rightNode);
1922 return false;
1923 }
1924
1925 default:
1926 UNREACHABLE();
1927 return false;
1928 }
1929 }
1930 else
1931 {
1932 // Rewrite the calls to sampler functions.
1933 if (BuiltInGroup::IsTexture(op))
1934 {
1935 emitTextureBuiltin(op, args);
1936 return false;
1937 }
1938 // If the operator is not symbolic then it is a builtin that uses function call
1939 // syntax: builtin(arg1, arg2, ..);
1940 mSink << (opName == nullptr ? "TODO_Operator" : opName);
1941 emitArgList();
1942 return false;
1943 }
1944 }
1945 }
1946 }
1947
emitBlock(angle::Span<TIntermNode * > nodes)1948 bool OutputWGSLTraverser::emitBlock(angle::Span<TIntermNode *> nodes)
1949 {
1950 ASSERT(mIndentLevel >= -1);
1951 const bool isGlobalScope = mIndentLevel == -1;
1952
1953 if (isGlobalScope)
1954 {
1955 ++mIndentLevel;
1956 }
1957 else
1958 {
1959 emitOpenBrace();
1960 }
1961
1962 TIntermNode *prevStmtNode = nullptr;
1963
1964 const size_t stmtCount = nodes.size();
1965 for (size_t i = 0; i < stmtCount; ++i)
1966 {
1967 TIntermNode &stmtNode = *nodes[i];
1968
1969 if (isGlobalScope && prevStmtNode && (NewlinePad(*prevStmtNode) || NewlinePad(stmtNode)))
1970 {
1971 mSink << "\n";
1972 }
1973 const bool isCase = stmtNode.getAsCaseNode();
1974 mIndentLevel -= isCase;
1975 emitIndentation();
1976 mIndentLevel += isCase;
1977 stmtNode.traverse(this);
1978 if (RequiresSemicolonTerminator(stmtNode))
1979 {
1980 mSink << ";";
1981 }
1982 mSink << "\n";
1983
1984 prevStmtNode = &stmtNode;
1985 }
1986
1987 if (isGlobalScope)
1988 {
1989 ASSERT(mIndentLevel == 0);
1990 --mIndentLevel;
1991 }
1992 else
1993 {
1994 emitCloseBrace();
1995 }
1996
1997 return false;
1998 }
1999
visitBlock(Visit,TIntermBlock * blockNode)2000 bool OutputWGSLTraverser::visitBlock(Visit, TIntermBlock *blockNode)
2001 {
2002 return emitBlock(
2003 angle::Span(blockNode->getSequence()->data(), blockNode->getSequence()->size()));
2004 }
2005
visitGlobalQualifierDeclaration(Visit,TIntermGlobalQualifierDeclaration *)2006 bool OutputWGSLTraverser::visitGlobalQualifierDeclaration(Visit,
2007 TIntermGlobalQualifierDeclaration *)
2008 {
2009 return false;
2010 }
2011
emitStructDeclaration(const TType & type)2012 void OutputWGSLTraverser::emitStructDeclaration(const TType &type)
2013 {
2014 ASSERT(type.getBasicType() == TBasicType::EbtStruct);
2015 ASSERT(type.isStructSpecifier());
2016
2017 mSink << "struct ";
2018 emitBareTypeName(type);
2019
2020 mSink << "\n";
2021 emitOpenBrace();
2022
2023 const TStructure &structure = *type.getStruct();
2024 bool isInUniformAddressSpace =
2025 mUniformBlockMetadata->structsInUniformAddressSpace.count(structure.uniqueId().get()) != 0;
2026
2027 bool alignTo16InUniformAddressSpace = true;
2028 for (const TField *field : structure.fields())
2029 {
2030 const TType *fieldType = field->type();
2031
2032 emitIndentation();
2033 // If this struct is used in the uniform address space, it must obey the uniform address
2034 // space's layout constaints (https://www.w3.org/TR/WGSL/#address-space-layout-constraints).
2035 // WGSL's address space layout constraints nearly match std140, and the places they don't
2036 // are handled elsewhere.
2037 if (isInUniformAddressSpace)
2038 {
2039 // Here, the field must be aligned to 16 if:
2040 // 1. The field is a struct or array (note that matCx2 is represented as an array of
2041 // vec2)
2042 // 2. The previous field is a struct
2043 // 3. The field is the first in the struct (for convenience).
2044 if (field->type()->getStruct() || fieldType->isArray() || IsMatCx2(fieldType))
2045 {
2046 alignTo16InUniformAddressSpace = true;
2047 }
2048 if (alignTo16InUniformAddressSpace)
2049 {
2050 mSink << "@align(16) ";
2051 }
2052
2053 // If this field is a struct, the next member should be aligned to 16.
2054 alignTo16InUniformAddressSpace = fieldType->getStruct();
2055
2056 // If the field is an array whose stride is not aligned to 16, the element type must be
2057 // emitted with a wrapper struct. Record that the wrapper struct needs to be emitted.
2058 // Note that if the array element type is already of struct type, it doesn't need
2059 // another wrapper struct, it will automatically be aligned to 16 because its first
2060 // member is aligned to 16 (implemented above).
2061 if (ElementTypeNeedsUniformWrapperStruct(/*inUniformAddressSpace=*/true, fieldType))
2062 {
2063 TType innerType = *fieldType;
2064 innerType.toArrayElementType();
2065 // Multidimensional arrays not currently supported in uniforms in the WebGPU backend
2066 ASSERT(!innerType.isArray());
2067 mWGSLGenerationMetadataForUniforms->arrayElementTypesInUniforms.insert(innerType);
2068 }
2069 }
2070
2071 // TODO(anglebug.com/42267100): emit qualifiers.
2072 EmitVariableDeclarationConfig evdConfig;
2073 evdConfig.typeConfig.addressSpace =
2074 isInUniformAddressSpace ? WgslAddressSpace::Uniform : WgslAddressSpace::NonUniform;
2075 evdConfig.disableStructSpecifier = true;
2076 emitVariableDeclaration({field->symbolType(), field->name(), *fieldType}, evdConfig);
2077 mSink << ",\n";
2078 }
2079
2080 emitCloseBrace();
2081 }
2082
emitVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)2083 void OutputWGSLTraverser::emitVariableDeclaration(const VarDecl &decl,
2084 const EmitVariableDeclarationConfig &evdConfig)
2085 {
2086 const TBasicType basicType = decl.type.getBasicType();
2087
2088 if (decl.type.getQualifier() == EvqUniform)
2089 {
2090 // Uniforms are declared in a pre-pass, and don't need to be outputted here.
2091 return;
2092 }
2093
2094 if (basicType == TBasicType::EbtStruct && decl.type.isStructSpecifier() &&
2095 !evdConfig.disableStructSpecifier)
2096 {
2097 // TODO(anglebug.com/42267100): in WGSL structs probably can't be declared in
2098 // function parameters or in uniform declarations or in variable declarations, or
2099 // anonymously either within other structs or within a variable declaration. Handle
2100 // these with the same AST pre-passes as other shader translators.
2101 ASSERT(!evdConfig.isParameter);
2102 emitStructDeclaration(decl.type);
2103 if (decl.symbolType != SymbolType::Empty)
2104 {
2105 mSink << " ";
2106 emitNameOf(decl);
2107 }
2108 return;
2109 }
2110
2111 ASSERT(basicType == TBasicType::EbtStruct || decl.symbolType != SymbolType::Empty ||
2112 evdConfig.isParameter);
2113
2114 if (evdConfig.needsVar)
2115 {
2116 // "const" and "let" probably don't need to be ever emitted because they are more for
2117 // readability, and the GLSL compiler constant folds most (all?) the consts anyway.
2118 mSink << "var";
2119 // TODO(anglebug.com/42267100): <workgroup> or <storage>?
2120 if (evdConfig.isGlobalScope)
2121 {
2122 if (decl.type.getQualifier() == EvqUniform)
2123 {
2124 ASSERT(IsOpaqueType(decl.type.getBasicType()));
2125 mSink << "<uniform>";
2126 }
2127 else
2128 {
2129 mSink << "<private>";
2130 }
2131 }
2132 mSink << " ";
2133 }
2134 else
2135 {
2136 ASSERT(!evdConfig.isGlobalScope);
2137 }
2138
2139 if (decl.symbolType != SymbolType::Empty)
2140 {
2141 emitNameOf(decl);
2142 }
2143 mSink << " : ";
2144 WriteWgslType(mSink, decl.type, evdConfig.typeConfig);
2145 }
2146
visitDeclaration(Visit,TIntermDeclaration * declNode)2147 bool OutputWGSLTraverser::visitDeclaration(Visit, TIntermDeclaration *declNode)
2148 {
2149 ASSERT(declNode->getChildCount() == 1);
2150 TIntermNode &node = *declNode->getChildNode(0);
2151
2152 EmitVariableDeclarationConfig evdConfig;
2153 evdConfig.needsVar = true;
2154 evdConfig.isGlobalScope = mIndentLevel == 0;
2155
2156 if (TIntermSymbol *symbolNode = node.getAsSymbolNode())
2157 {
2158 const TVariable &var = symbolNode->variable();
2159 if (mRewritePipelineVarOutput->IsInputVar(var.uniqueId()) ||
2160 mRewritePipelineVarOutput->IsOutputVar(var.uniqueId()))
2161 {
2162 // Some variables, like shader inputs/outputs/builtins, are declared in the WGSL source
2163 // outside of the traverser.
2164 return false;
2165 }
2166 emitVariableDeclaration({var.symbolType(), var.name(), var.getType()}, evdConfig);
2167 }
2168 else if (TIntermBinary *initNode = node.getAsBinaryNode())
2169 {
2170 ASSERT(initNode->getOp() == TOperator::EOpInitialize);
2171 TIntermSymbol *leftSymbolNode = initNode->getLeft()->getAsSymbolNode();
2172 TIntermTyped *valueNode = initNode->getRight()->getAsTyped();
2173 ASSERT(leftSymbolNode && valueNode);
2174
2175 const TVariable &var = leftSymbolNode->variable();
2176 if (mRewritePipelineVarOutput->IsInputVar(var.uniqueId()) ||
2177 mRewritePipelineVarOutput->IsOutputVar(var.uniqueId()))
2178 {
2179 // Some variables, like shader inputs/outputs/builtins, are declared in the WGSL source
2180 // outside of the traverser.
2181 return false;
2182 }
2183
2184 emitVariableDeclaration({var.symbolType(), var.name(), var.getType()}, evdConfig);
2185 mSink << " = ";
2186 groupedTraverse(*valueNode);
2187 }
2188 else
2189 {
2190 UNREACHABLE();
2191 }
2192
2193 return false;
2194 }
2195
visitLoop(Visit,TIntermLoop * loopNode)2196 bool OutputWGSLTraverser::visitLoop(Visit, TIntermLoop *loopNode)
2197 {
2198 const TLoopType loopType = loopNode->getType();
2199
2200 switch (loopType)
2201 {
2202 case TLoopType::ELoopFor:
2203 return emitForLoop(loopNode);
2204 case TLoopType::ELoopWhile:
2205 return emitWhileLoop(loopNode);
2206 case TLoopType::ELoopDoWhile:
2207 return emulateDoWhileLoop(loopNode);
2208 }
2209 }
2210
emitForLoop(TIntermLoop * loopNode)2211 bool OutputWGSLTraverser::emitForLoop(TIntermLoop *loopNode)
2212 {
2213 ASSERT(loopNode->getType() == TLoopType::ELoopFor);
2214
2215 TIntermNode *initNode = loopNode->getInit();
2216 TIntermTyped *condNode = loopNode->getCondition();
2217 TIntermTyped *exprNode = loopNode->getExpression();
2218
2219 mSink << "for (";
2220
2221 if (initNode)
2222 {
2223 initNode->traverse(this);
2224 }
2225 else
2226 {
2227 mSink << " ";
2228 }
2229
2230 mSink << "; ";
2231
2232 if (condNode)
2233 {
2234 condNode->traverse(this);
2235 }
2236
2237 mSink << "; ";
2238
2239 if (exprNode)
2240 {
2241 exprNode->traverse(this);
2242 }
2243
2244 mSink << ")\n";
2245
2246 loopNode->getBody()->traverse(this);
2247
2248 return false;
2249 }
2250
emitWhileLoop(TIntermLoop * loopNode)2251 bool OutputWGSLTraverser::emitWhileLoop(TIntermLoop *loopNode)
2252 {
2253 ASSERT(loopNode->getType() == TLoopType::ELoopWhile);
2254
2255 TIntermNode *initNode = loopNode->getInit();
2256 TIntermTyped *condNode = loopNode->getCondition();
2257 TIntermTyped *exprNode = loopNode->getExpression();
2258 ASSERT(condNode);
2259 ASSERT(!initNode && !exprNode);
2260
2261 emitIndentation();
2262 mSink << "while (";
2263 condNode->traverse(this);
2264 mSink << ")\n";
2265 loopNode->getBody()->traverse(this);
2266
2267 return false;
2268 }
2269
emulateDoWhileLoop(TIntermLoop * loopNode)2270 bool OutputWGSLTraverser::emulateDoWhileLoop(TIntermLoop *loopNode)
2271 {
2272 ASSERT(loopNode->getType() == TLoopType::ELoopDoWhile);
2273
2274 TIntermNode *initNode = loopNode->getInit();
2275 TIntermTyped *condNode = loopNode->getCondition();
2276 TIntermTyped *exprNode = loopNode->getExpression();
2277 ASSERT(condNode);
2278 ASSERT(!initNode && !exprNode);
2279
2280 emitIndentation();
2281 // Write an infinite loop.
2282 mSink << "loop {\n";
2283 mIndentLevel++;
2284 loopNode->getBody()->traverse(this);
2285 mSink << "\n";
2286 emitIndentation();
2287 // At the end of the loop, break if the loop condition dos not still hold.
2288 mSink << "if (!(";
2289 condNode->traverse(this);
2290 mSink << ") { break; }\n";
2291 mIndentLevel--;
2292 emitIndentation();
2293 mSink << "}";
2294
2295 return false;
2296 }
2297
visitBranch(Visit,TIntermBranch * branchNode)2298 bool OutputWGSLTraverser::visitBranch(Visit, TIntermBranch *branchNode)
2299 {
2300 const TOperator flowOp = branchNode->getFlowOp();
2301 TIntermTyped *exprNode = branchNode->getExpression();
2302
2303 emitIndentation();
2304
2305 switch (flowOp)
2306 {
2307 case TOperator::EOpKill:
2308 {
2309 ASSERT(exprNode == nullptr);
2310 mSink << "discard";
2311 }
2312 break;
2313
2314 case TOperator::EOpReturn:
2315 {
2316 mSink << "return";
2317 if (exprNode)
2318 {
2319 mSink << " ";
2320 exprNode->traverse(this);
2321 }
2322 }
2323 break;
2324
2325 case TOperator::EOpBreak:
2326 {
2327 ASSERT(exprNode == nullptr);
2328 mSink << "break";
2329 }
2330 break;
2331
2332 case TOperator::EOpContinue:
2333 {
2334 ASSERT(exprNode == nullptr);
2335 mSink << "continue";
2336 }
2337 break;
2338
2339 default:
2340 {
2341 UNREACHABLE();
2342 }
2343 }
2344
2345 return false;
2346 }
2347
visitPreprocessorDirective(TIntermPreprocessorDirective * node)2348 void OutputWGSLTraverser::visitPreprocessorDirective(TIntermPreprocessorDirective *node)
2349 {
2350 // No preprocessor directives expected at this point.
2351 UNREACHABLE();
2352 }
2353
emitBareTypeName(const TType & type)2354 void OutputWGSLTraverser::emitBareTypeName(const TType &type)
2355 {
2356 WriteWgslBareTypeName(mSink, type, {});
2357 }
2358
emitType(const TType & type)2359 void OutputWGSLTraverser::emitType(const TType &type)
2360 {
2361 WriteWgslType(mSink, type, {});
2362 }
2363
2364 } // namespace
2365
TranslatorWGSL(sh::GLenum type,ShShaderSpec spec,ShShaderOutput output)2366 TranslatorWGSL::TranslatorWGSL(sh::GLenum type, ShShaderSpec spec, ShShaderOutput output)
2367 : TCompiler(type, spec, output)
2368 {}
2369
preTranslateTreeModifications(TIntermBlock * root)2370 bool TranslatorWGSL::preTranslateTreeModifications(TIntermBlock *root)
2371 {
2372
2373 int aggregateTypesUsedForUniforms = 0;
2374 for (const auto &uniform : getUniforms())
2375 {
2376 if (uniform.isStruct() || uniform.isArrayOfArrays())
2377 {
2378 ++aggregateTypesUsedForUniforms;
2379 }
2380 }
2381
2382 // Samplers are legal as function parameters, but samplers within structs or arrays are not
2383 // allowed in WGSL
2384 // (https://www.w3.org/TR/WGSL/#function-call-expr:~:text=A%20function%20parameter,a%20sampler%20type).
2385 // TODO(anglebug.com/389145696): handle arrays of samplers here.
2386
2387 // If there are any function calls that take array-of-array of opaque uniform parameters, or
2388 // other opaque uniforms that need special handling in WebGPU, monomorphize the functions by
2389 // removing said parameters and replacing them in the function body with the call arguments.
2390 //
2391 // This dramatically simplifies future transformations w.r.t to samplers in structs, array of
2392 // arrays of opaque types, atomic counters etc.
2393 UnsupportedFunctionArgsBitSet args{UnsupportedFunctionArgs::StructContainingSamplers,
2394 UnsupportedFunctionArgs::ArrayOfArrayOfSamplerOrImage,
2395 UnsupportedFunctionArgs::AtomicCounter,
2396 UnsupportedFunctionArgs::Image};
2397 if (!MonomorphizeUnsupportedFunctions(this, root, &getSymbolTable(), args))
2398 {
2399 return false;
2400 }
2401
2402 if (aggregateTypesUsedForUniforms > 0)
2403 {
2404 if (!SeparateStructFromUniformDeclarations(this, root, &getSymbolTable()))
2405 {
2406 return false;
2407 }
2408
2409 int removedUniformsCount;
2410
2411 // Requires MonomorphizeUnsupportedFunctions() to have been run already.
2412 if (!RewriteStructSamplers(this, root, &getSymbolTable(), &removedUniformsCount))
2413 {
2414 return false;
2415 }
2416 }
2417
2418 // Replace array of array of opaque uniforms with a flattened array. This is run after
2419 // MonomorphizeUnsupportedFunctions and RewriteStructSamplers so that it's not possible for an
2420 // array of array of opaque type to be partially subscripted and passed to a function.
2421 // TODO(anglebug.com/389145696): Even single-level arrays of samplers are not allowed in WGSL.
2422 if (!RewriteArrayOfArrayOfOpaqueUniforms(this, root, &getSymbolTable()))
2423 {
2424 return false;
2425 }
2426
2427 return true;
2428 }
2429
translate(TIntermBlock * root,const ShCompileOptions & compileOptions,PerformanceDiagnostics * perfDiagnostics)2430 bool TranslatorWGSL::translate(TIntermBlock *root,
2431 const ShCompileOptions &compileOptions,
2432 PerformanceDiagnostics *perfDiagnostics)
2433 {
2434 if (kOutputTreeBeforeTranslation)
2435 {
2436 OutputTree(root, getInfoSink().info);
2437 std::cout << getInfoSink().info.c_str();
2438 }
2439
2440 if (!preTranslateTreeModifications(root))
2441 {
2442 return false;
2443 }
2444 enableValidateNoMoreTransformations();
2445
2446 RewritePipelineVarOutput rewritePipelineVarOutput(getShaderType());
2447 WGSLGenerationMetadataForUniforms wgslGenerationMetadataForUniforms;
2448
2449 // WGSL's main() will need to take parameters or return values if any glsl (input/output)
2450 // builtin variables are used.
2451 if (!GenerateMainFunctionAndIOStructs(*this, *root, rewritePipelineVarOutput))
2452 {
2453 return false;
2454 }
2455
2456 TInfoSinkBase &sink = getInfoSink().obj;
2457 // Start writing the output structs that will be referred to by the `traverser`'s output.'
2458 if (!rewritePipelineVarOutput.OutputStructs(sink))
2459 {
2460 return false;
2461 }
2462
2463 if (!OutputUniformBlocksAndSamplers(this, root))
2464 {
2465 return false;
2466 }
2467
2468 UniformBlockMetadata uniformBlockMetadata;
2469 if (!RecordUniformBlockMetadata(root, uniformBlockMetadata))
2470 {
2471 return false;
2472 }
2473
2474 // Generate the body of the WGSL including the GLSL main() function.
2475 TInfoSinkBase traverserOutput;
2476 OutputWGSLTraverser traverser(&traverserOutput, &rewritePipelineVarOutput,
2477 &uniformBlockMetadata, &wgslGenerationMetadataForUniforms);
2478 root->traverse(&traverser);
2479
2480 sink << "\n";
2481 OutputUniformWrapperStructsAndConversions(sink, wgslGenerationMetadataForUniforms);
2482
2483 // The traverser output needs to be in the code after uniform wrapper structs are emitted above,
2484 // since the traverser code references the wrapper struct types.
2485 sink << traverserOutput.str();
2486
2487 // Write the actual WGSL main function, wgslMain(), which calls the GLSL main function.
2488 if (!rewritePipelineVarOutput.OutputMainFunction(sink))
2489 {
2490 return false;
2491 }
2492
2493 if (kOutputTranslatedShader)
2494 {
2495 std::cout << sink.str();
2496 }
2497
2498 return true;
2499 }
2500
shouldFlattenPragmaStdglInvariantAll()2501 bool TranslatorWGSL::shouldFlattenPragmaStdglInvariantAll()
2502 {
2503 // Not neccesary for WGSL transformation.
2504 return false;
2505 }
2506 } // namespace sh
2507