//
// Copyright 2024 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
6 
7 #include "compiler/translator/wgsl/TranslatorWGSL.h"
8 
9 #include "GLSLANG/ShaderLang.h"
10 #include "common/log_utils.h"
11 #include "compiler/translator/BaseTypes.h"
12 #include "compiler/translator/Diagnostics.h"
13 #include "compiler/translator/ImmutableString.h"
14 #include "compiler/translator/InfoSink.h"
15 #include "compiler/translator/IntermNode.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 
18 namespace sh
19 {
20 namespace
21 {
22 
23 // When emitting a list of statements, this determines whether a semicolon follows the statement.
RequiresSemicolonTerminator(TIntermNode & node)24 bool RequiresSemicolonTerminator(TIntermNode &node)
25 {
26     if (node.getAsBlock())
27     {
28         return false;
29     }
30     if (node.getAsLoopNode())
31     {
32         return false;
33     }
34     if (node.getAsSwitchNode())
35     {
36         return false;
37     }
38     if (node.getAsIfElseNode())
39     {
40         return false;
41     }
42     if (node.getAsFunctionDefinition())
43     {
44         return false;
45     }
46     if (node.getAsCaseNode())
47     {
48         return false;
49     }
50 
51     return true;
52 }
53 
54 // For pretty formatting of the resulting WGSL text.
NewlinePad(TIntermNode & node)55 bool NewlinePad(TIntermNode &node)
56 {
57     if (node.getAsFunctionDefinition())
58     {
59         return true;
60     }
61     if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
62     {
63         ASSERT(declNode->getChildCount() == 1);
64         TIntermNode &childNode = *declNode->getChildNode(0);
65         if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
66         {
67             const TVariable &var = symbolNode->variable();
68             return var.getType().isStructSpecifier();
69         }
70         return false;
71     }
72     return false;
73 }
74 
75 // A traverser that generates WGSL as it walks the AST.
76 class OutputWGSLTraverser : public TIntermTraverser
77 {
78   public:
79     OutputWGSLTraverser(TCompiler *compiler);
80     ~OutputWGSLTraverser() override;
81 
82   protected:
83     void visitSymbol(TIntermSymbol *node) override;
84     void visitConstantUnion(TIntermConstantUnion *node) override;
85     bool visitSwizzle(Visit visit, TIntermSwizzle *node) override;
86     bool visitBinary(Visit visit, TIntermBinary *node) override;
87     bool visitUnary(Visit visit, TIntermUnary *node) override;
88     bool visitTernary(Visit visit, TIntermTernary *node) override;
89     bool visitIfElse(Visit visit, TIntermIfElse *node) override;
90     bool visitSwitch(Visit visit, TIntermSwitch *node) override;
91     bool visitCase(Visit visit, TIntermCase *node) override;
92     void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
93     bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
94     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
95     bool visitBlock(Visit visit, TIntermBlock *node) override;
96     bool visitGlobalQualifierDeclaration(Visit visit,
97                                          TIntermGlobalQualifierDeclaration *node) override;
98     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
99     bool visitLoop(Visit visit, TIntermLoop *node) override;
100     bool visitBranch(Visit visit, TIntermBranch *node) override;
101     void visitPreprocessorDirective(TIntermPreprocessorDirective *node) override;
102 
103   private:
104     void groupedTraverse(TIntermNode &node);
105     template <typename T>
106     void emitNameOf(const T &namedObject);
107     void emitBareTypeName(const TType &type);
108     void emitType(const TType &type);
109     void emitIndentation();
110     void emitOpenBrace();
111     void emitCloseBrace();
112     void emitFunctionSignature(const TFunction &func);
113     void emitFunctionReturn(const TFunction &func);
114     void emitFunctionParameter(const TFunction &func, const TVariable &param);
115 
116     TInfoSinkBase &mSink;
117 
118     int mIndentLevel        = -1;
119     int mLastIndentationPos = -1;
120 };
121 
OutputWGSLTraverser(TCompiler * compiler)122 OutputWGSLTraverser::OutputWGSLTraverser(TCompiler *compiler)
123     : TIntermTraverser(true, false, false), mSink(compiler->getInfoSink().obj)
124 {}
125 
126 OutputWGSLTraverser::~OutputWGSLTraverser() = default;
127 
groupedTraverse(TIntermNode & node)128 void OutputWGSLTraverser::groupedTraverse(TIntermNode &node)
129 {
130     // TODO(anglebug.com/8662): to make generated code more readable, do not always
131     // emit parentheses like WGSL is some Lisp dialect.
132     const bool emitParens = true;
133 
134     if (emitParens)
135     {
136         mSink << "(";
137     }
138 
139     node.traverse(this);
140 
141     if (emitParens)
142     {
143         mSink << ")";
144     }
145 }
146 
147 // Can be used with TSymbol or TField. Must have a .name() and a .symbolType().
148 template <typename T>
emitNameOf(const T & namedObject)149 void OutputWGSLTraverser::emitNameOf(const T &namedObject)
150 {
151     switch (namedObject.symbolType())
152     {
153         case SymbolType::BuiltIn:
154         {
155             mSink << namedObject.name();
156         }
157         break;
158         case SymbolType::UserDefined:
159         {
160             mSink << kUserDefinedNamePrefix << namedObject.name();
161         }
162         break;
163         case SymbolType::AngleInternal:
164         case SymbolType::Empty:
165             // TODO(anglebug.com/8662): support these if necessary
166             UNREACHABLE();
167     }
168 }
169 
emitIndentation()170 void OutputWGSLTraverser::emitIndentation()
171 {
172     ASSERT(mIndentLevel >= 0);
173 
174     if (mLastIndentationPos == mSink.size())
175     {
176         return;  // Line is already indented.
177     }
178 
179     for (int i = 0; i < mIndentLevel; ++i)
180     {
181         mSink << "  ";
182     }
183 
184     mLastIndentationPos = mSink.size();
185 }
186 
emitOpenBrace()187 void OutputWGSLTraverser::emitOpenBrace()
188 {
189     ASSERT(mIndentLevel >= 0);
190 
191     emitIndentation();
192     mSink << "{\n";
193     ++mIndentLevel;
194 }
195 
emitCloseBrace()196 void OutputWGSLTraverser::emitCloseBrace()
197 {
198     ASSERT(mIndentLevel >= 1);
199 
200     --mIndentLevel;
201     emitIndentation();
202     mSink << "}";
203 }
204 
visitSymbol(TIntermSymbol * symbolNode)205 void OutputWGSLTraverser::visitSymbol(TIntermSymbol *symbolNode)
206 {
207 
208     const TVariable &var = symbolNode->variable();
209     const TType &type    = var.getType();
210     ASSERT(var.symbolType() != SymbolType::Empty);
211 
212     if (type.getBasicType() == TBasicType::EbtVoid)
213     {
214         UNREACHABLE();
215     }
216     else
217     {
218         emitNameOf(var);
219     }
220 }
221 
visitConstantUnion(TIntermConstantUnion * constValueNode)222 void OutputWGSLTraverser::visitConstantUnion(TIntermConstantUnion *constValueNode)
223 {
224     // TODO(anglebug.com/8662): support emitting constants..
225 }
226 
visitSwizzle(Visit,TIntermSwizzle * swizzleNode)227 bool OutputWGSLTraverser::visitSwizzle(Visit, TIntermSwizzle *swizzleNode)
228 {
229     // TODO(anglebug.com/8662): support swizzle statements.
230     return false;
231 }
232 
visitBinary(Visit,TIntermBinary * binaryNode)233 bool OutputWGSLTraverser::visitBinary(Visit, TIntermBinary *binaryNode)
234 {
235     // TODO(anglebug.com/8662): support binary statements.
236     return false;
237 }
238 
visitUnary(Visit,TIntermUnary * unaryNode)239 bool OutputWGSLTraverser::visitUnary(Visit, TIntermUnary *unaryNode)
240 {
241     // TODO(anglebug.com/8662): support unary statements.
242     return false;
243 }
244 
visitTernary(Visit,TIntermTernary * conditionalNode)245 bool OutputWGSLTraverser::visitTernary(Visit, TIntermTernary *conditionalNode)
246 {
247     // TODO(anglebug.com/8662): support ternaries.
248     return false;
249 }
250 
visitIfElse(Visit,TIntermIfElse * ifThenElseNode)251 bool OutputWGSLTraverser::visitIfElse(Visit, TIntermIfElse *ifThenElseNode)
252 {
253     // TODO(anglebug.com/8662): support basic control flow.
254     return false;
255 }
256 
visitSwitch(Visit,TIntermSwitch * switchNode)257 bool OutputWGSLTraverser::visitSwitch(Visit, TIntermSwitch *switchNode)
258 {
259     // TODO(anglebug.com/8662): support switch statements.
260     return false;
261 }
262 
visitCase(Visit,TIntermCase * caseNode)263 bool OutputWGSLTraverser::visitCase(Visit, TIntermCase *caseNode)
264 {
265     // TODO(anglebug.com/8662): support switch statements.
266     return false;
267 }
268 
emitFunctionReturn(const TFunction & func)269 void OutputWGSLTraverser::emitFunctionReturn(const TFunction &func)
270 {
271     const TType &returnType = func.getReturnType();
272     if (returnType.getBasicType() == EbtVoid)
273     {
274         return;
275     }
276     mSink << " -> ";
277     emitType(returnType);
278 }
279 
280 // TODO(anglebug.com/42267100): Function overloads are not supported in WGSL, so function names
281 // should either be emitted mangled or overloaded functions should be renamed in the AST as a
282 // pre-pass. As of Apr 2024, WGSL function overloads are "not coming soon"
283 // (https://github.com/gpuweb/gpuweb/issues/876).
emitFunctionSignature(const TFunction & func)284 void OutputWGSLTraverser::emitFunctionSignature(const TFunction &func)
285 {
286     // TODO(anglebug.com/42267100): main functions should be renamed and labeled with @vertex or
287     // @fragment.
288     mSink << "fn ";
289 
290     emitNameOf(func);
291     mSink << "(";
292 
293     bool emitComma          = false;
294     const size_t paramCount = func.getParamCount();
295     for (size_t i = 0; i < paramCount; ++i)
296     {
297         if (emitComma)
298         {
299             mSink << ", ";
300         }
301         emitComma = true;
302 
303         const TVariable &param = *func.getParam(i);
304         emitFunctionParameter(func, param);
305     }
306 
307     mSink << ")";
308 
309     emitFunctionReturn(func);
310 }
311 
emitFunctionParameter(const TFunction & func,const TVariable & param)312 void OutputWGSLTraverser::emitFunctionParameter(const TFunction &func, const TVariable &param)
313 {
314     // TODO(anglebug.com/8662): actually emit function parameters.
315 
316     mSink << "FAKE_FUNCTION_PARAMETER";
317 }
318 
visitFunctionPrototype(TIntermFunctionPrototype * funcProtoNode)319 void OutputWGSLTraverser::visitFunctionPrototype(TIntermFunctionPrototype *funcProtoNode)
320 {
321     const TFunction &func = *funcProtoNode->getFunction();
322 
323     emitIndentation();
324     emitFunctionSignature(func);
325 }
326 
visitFunctionDefinition(Visit,TIntermFunctionDefinition * funcDefNode)327 bool OutputWGSLTraverser::visitFunctionDefinition(Visit, TIntermFunctionDefinition *funcDefNode)
328 {
329     const TFunction &func = *funcDefNode->getFunction();
330     TIntermBlock &body    = *funcDefNode->getBody();
331     emitIndentation();
332     emitFunctionSignature(func);
333     mSink << "\n";
334     body.traverse(this);
335     return false;
336 }
337 
visitAggregate(Visit,TIntermAggregate * aggregateNode)338 bool OutputWGSLTraverser::visitAggregate(Visit, TIntermAggregate *aggregateNode)
339 {
340     // TODO(anglebug.com/8662): support aggregate statements.
341     return false;
342 }
343 
visitBlock(Visit,TIntermBlock * blockNode)344 bool OutputWGSLTraverser::visitBlock(Visit, TIntermBlock *blockNode)
345 {
346     ASSERT(mIndentLevel >= -1);
347     const bool isGlobalScope = mIndentLevel == -1;
348 
349     if (isGlobalScope)
350     {
351         ++mIndentLevel;
352     }
353     else
354     {
355         emitOpenBrace();
356     }
357 
358     TIntermNode *prevStmtNode = nullptr;
359 
360     const size_t stmtCount = blockNode->getChildCount();
361     for (size_t i = 0; i < stmtCount; ++i)
362     {
363         TIntermNode &stmtNode = *blockNode->getChildNode(i);
364 
365         if (isGlobalScope && prevStmtNode && (NewlinePad(*prevStmtNode) || NewlinePad(stmtNode)))
366         {
367             mSink << "\n";
368         }
369         const bool isCase = stmtNode.getAsCaseNode();
370         mIndentLevel -= isCase;
371         emitIndentation();
372         mIndentLevel += isCase;
373         stmtNode.traverse(this);
374         if (RequiresSemicolonTerminator(stmtNode))
375         {
376             mSink << ";";
377         }
378         mSink << "\n";
379 
380         prevStmtNode = &stmtNode;
381     }
382 
383     if (isGlobalScope)
384     {
385         ASSERT(mIndentLevel == 0);
386         --mIndentLevel;
387     }
388     else
389     {
390         emitCloseBrace();
391     }
392 
393     return false;
394 }
395 
visitGlobalQualifierDeclaration(Visit,TIntermGlobalQualifierDeclaration *)396 bool OutputWGSLTraverser::visitGlobalQualifierDeclaration(Visit,
397                                                           TIntermGlobalQualifierDeclaration *)
398 {
399     return false;
400 }
401 
visitDeclaration(Visit,TIntermDeclaration * declNode)402 bool OutputWGSLTraverser::visitDeclaration(Visit, TIntermDeclaration *declNode)
403 {
404     // TODO(anglebug.com/8662): support variable declarations.
405     mSink << "FAKE_DECLARATION";
406     return false;
407 }
408 
visitLoop(Visit,TIntermLoop * loopNode)409 bool OutputWGSLTraverser::visitLoop(Visit, TIntermLoop *loopNode)
410 {
411     // TODO(anglebug.com/8662): emit loops.
412     return false;
413 }
414 
visitBranch(Visit,TIntermBranch * branchNode)415 bool OutputWGSLTraverser::visitBranch(Visit, TIntermBranch *branchNode)
416 {
417     // TODO(anglebug.com/8662): emit branch instructions.
418     return false;
419 }
420 
visitPreprocessorDirective(TIntermPreprocessorDirective * node)421 void OutputWGSLTraverser::visitPreprocessorDirective(TIntermPreprocessorDirective *node)
422 {
423     // No preprocessor directives expected at this point.
424     UNREACHABLE();
425 }
426 
emitBareTypeName(const TType & type)427 void OutputWGSLTraverser::emitBareTypeName(const TType &type)
428 {
429     const TBasicType basicType = type.getBasicType();
430 
431     switch (basicType)
432     {
433         case TBasicType::EbtVoid:
434         case TBasicType::EbtBool:
435             mSink << type.getBasicString();
436             break;
437         // TODO(anglebug.com/8662): is there double precision (f64) in GLSL? It doesn't really exist
438         // in WGSL (i.e. f64 does not exist but AbstractFloat can handle 64 bits???) Metal does not
439         // have 64 bit double precision types. It's being implemented in WGPU:
440         // https://github.com/gpuweb/gpuweb/issues/2805
441         case TBasicType::EbtFloat:
442             mSink << "f32";
443             break;
444         case TBasicType::EbtInt:
445             mSink << "i32";
446             break;
447         case TBasicType::EbtUInt:
448             mSink << "u32";
449             break;
450 
451         case TBasicType::EbtStruct:
452             emitNameOf(*type.getStruct());
453             break;
454 
455         case TBasicType::EbtInterfaceBlock:
456             emitNameOf(*type.getInterfaceBlock());
457             break;
458 
459         default:
460             if (IsSampler(basicType))
461             {
462                 //  TODO(anglebug.com/8662): possibly emit both a sampler and a texture2d. WGSL has
463                 //  sampler variables for the sampler configuration, whereas GLSL has sampler2d and
464                 //  other sampler* variables for an actual texture.
465                 mSink << "texture2d<";
466                 switch (type.getBasicType())
467                 {
468                     case EbtSampler2D:
469                         mSink << "f32";
470                         break;
471                     case EbtISampler2D:
472                         mSink << "i32";
473                         break;
474                     case EbtUSampler2D:
475                         mSink << "u32";
476                         break;
477                     default:
478                         // TODO(anglebug.com/8662): are any of the other sampler types necessary to
479                         // translate?
480                         UNIMPLEMENTED();
481                         break;
482                 }
483                 if (type.getMemoryQualifier().readonly || type.getMemoryQualifier().writeonly)
484                 {
485                     // TODO(anglebug.com/8662): implement memory qualifiers.
486                     UNIMPLEMENTED();
487                 }
488                 mSink << ">";
489             }
490             else if (IsImage(basicType))
491             {
492                 // TODO(anglebug.com/8662): does texture2d also correspond to GLSL's image type?
493                 mSink << "texture2d<";
494                 switch (type.getBasicType())
495                 {
496                     case EbtImage2D:
497                         mSink << "f32";
498                         break;
499                     case EbtIImage2D:
500                         mSink << "i32";
501                         break;
502                     case EbtUImage2D:
503                         mSink << "u32";
504                         break;
505                     default:
506                         // TODO(anglebug.com/8662): are any of the other image types necessary to
507                         // translate?
508                         UNIMPLEMENTED();
509                         break;
510                 }
511                 if (type.getMemoryQualifier().readonly || type.getMemoryQualifier().writeonly)
512                 {
513                     // TODO(anglebug.com/8662): implement memory qualifiers.
514                     UNREACHABLE();
515                 }
516                 mSink << ">";
517             }
518             else
519             {
520                 UNREACHABLE();
521             }
522             break;
523     }
524 }
525 
emitType(const TType & type)526 void OutputWGSLTraverser::emitType(const TType &type)
527 {
528     // TODO(anglebug.com/8662): support types with dimensions.
529     ASSERT(!type.isVector() && !type.isMatrix() && !type.isArray());
530 
531     // This type has no dimensions and is equivalent to its bare type.
532     emitBareTypeName(type);
533 }
534 
535 }  // namespace
536 
TranslatorWGSL(sh::GLenum type,ShShaderSpec spec,ShShaderOutput output)537 TranslatorWGSL::TranslatorWGSL(sh::GLenum type, ShShaderSpec spec, ShShaderOutput output)
538     : TCompiler(type, spec, output)
539 {}
540 
translate(TIntermBlock * root,const ShCompileOptions & compileOptions,PerformanceDiagnostics * perfDiagnostics)541 bool TranslatorWGSL::translate(TIntermBlock *root,
542                                const ShCompileOptions &compileOptions,
543                                PerformanceDiagnostics *perfDiagnostics)
544 {
545     // TODO(anglebug.com/8662): until the translator is ready to translate most basic shaders, emit
546     // the code commented out.
547     TInfoSinkBase &sink = getInfoSink().obj;
548     sink << "/*\n";
549     OutputWGSLTraverser traverser(this);
550     root->traverse(&traverser);
551     sink << "*/\n";
552 
553     std::cout << getInfoSink().obj.str();
554 
555     // TODO(anglebug.com/8662): delete this.
556     if (getShaderType() == GL_VERTEX_SHADER)
557     {
558         constexpr const char *kVertexShader = R"(@vertex
559 fn main(@builtin(vertex_index) vertex_index : u32) -> @builtin(position) vec4f
560 {
561     const pos = array(
562         vec2( 0.0,  0.5),
563         vec2(-0.5, -0.5),
564         vec2( 0.5, -0.5)
565     );
566 
567     return vec4f(pos[vertex_index % 3], 0, 1);
568 })";
569         sink << kVertexShader;
570     }
571     else if (getShaderType() == GL_FRAGMENT_SHADER)
572     {
573         constexpr const char *kFragmentShader = R"(@fragment
574 fn main() -> @location(0) vec4f
575 {
576     return vec4(1, 0, 0, 1);
577 })";
578         sink << kFragmentShader;
579     }
580     else
581     {
582         UNREACHABLE();
583         return false;
584     }
585 
586     return true;
587 }
588 
shouldFlattenPragmaStdglInvariantAll()589 bool TranslatorWGSL::shouldFlattenPragmaStdglInvariantAll()
590 {
591     // Not neccesary for WGSL transformation.
592     return false;
593 }
594 }  // namespace sh
595