1 //
2 // Copyright 2024 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include "compiler/translator/wgsl/TranslatorWGSL.h"
8
9 #include "GLSLANG/ShaderLang.h"
10 #include "common/log_utils.h"
11 #include "compiler/translator/BaseTypes.h"
12 #include "compiler/translator/Diagnostics.h"
13 #include "compiler/translator/ImmutableString.h"
14 #include "compiler/translator/InfoSink.h"
15 #include "compiler/translator/IntermNode.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17
18 namespace sh
19 {
20 namespace
21 {
22
23 // When emitting a list of statements, this determines whether a semicolon follows the statement.
RequiresSemicolonTerminator(TIntermNode & node)24 bool RequiresSemicolonTerminator(TIntermNode &node)
25 {
26 if (node.getAsBlock())
27 {
28 return false;
29 }
30 if (node.getAsLoopNode())
31 {
32 return false;
33 }
34 if (node.getAsSwitchNode())
35 {
36 return false;
37 }
38 if (node.getAsIfElseNode())
39 {
40 return false;
41 }
42 if (node.getAsFunctionDefinition())
43 {
44 return false;
45 }
46 if (node.getAsCaseNode())
47 {
48 return false;
49 }
50
51 return true;
52 }
53
54 // For pretty formatting of the resulting WGSL text.
NewlinePad(TIntermNode & node)55 bool NewlinePad(TIntermNode &node)
56 {
57 if (node.getAsFunctionDefinition())
58 {
59 return true;
60 }
61 if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
62 {
63 ASSERT(declNode->getChildCount() == 1);
64 TIntermNode &childNode = *declNode->getChildNode(0);
65 if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
66 {
67 const TVariable &var = symbolNode->variable();
68 return var.getType().isStructSpecifier();
69 }
70 return false;
71 }
72 return false;
73 }
74
75 // A traverser that generates WGSL as it walks the AST.
76 class OutputWGSLTraverser : public TIntermTraverser
77 {
78 public:
79 OutputWGSLTraverser(TCompiler *compiler);
80 ~OutputWGSLTraverser() override;
81
82 protected:
83 void visitSymbol(TIntermSymbol *node) override;
84 void visitConstantUnion(TIntermConstantUnion *node) override;
85 bool visitSwizzle(Visit visit, TIntermSwizzle *node) override;
86 bool visitBinary(Visit visit, TIntermBinary *node) override;
87 bool visitUnary(Visit visit, TIntermUnary *node) override;
88 bool visitTernary(Visit visit, TIntermTernary *node) override;
89 bool visitIfElse(Visit visit, TIntermIfElse *node) override;
90 bool visitSwitch(Visit visit, TIntermSwitch *node) override;
91 bool visitCase(Visit visit, TIntermCase *node) override;
92 void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
93 bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
94 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
95 bool visitBlock(Visit visit, TIntermBlock *node) override;
96 bool visitGlobalQualifierDeclaration(Visit visit,
97 TIntermGlobalQualifierDeclaration *node) override;
98 bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
99 bool visitLoop(Visit visit, TIntermLoop *node) override;
100 bool visitBranch(Visit visit, TIntermBranch *node) override;
101 void visitPreprocessorDirective(TIntermPreprocessorDirective *node) override;
102
103 private:
104 void groupedTraverse(TIntermNode &node);
105 template <typename T>
106 void emitNameOf(const T &namedObject);
107 void emitBareTypeName(const TType &type);
108 void emitType(const TType &type);
109 void emitIndentation();
110 void emitOpenBrace();
111 void emitCloseBrace();
112 void emitFunctionSignature(const TFunction &func);
113 void emitFunctionReturn(const TFunction &func);
114 void emitFunctionParameter(const TFunction &func, const TVariable ¶m);
115
116 TInfoSinkBase &mSink;
117
118 int mIndentLevel = -1;
119 int mLastIndentationPos = -1;
120 };
121
OutputWGSLTraverser(TCompiler * compiler)122 OutputWGSLTraverser::OutputWGSLTraverser(TCompiler *compiler)
123 : TIntermTraverser(true, false, false), mSink(compiler->getInfoSink().obj)
124 {}
125
126 OutputWGSLTraverser::~OutputWGSLTraverser() = default;
127
groupedTraverse(TIntermNode & node)128 void OutputWGSLTraverser::groupedTraverse(TIntermNode &node)
129 {
130 // TODO(anglebug.com/8662): to make generated code more readable, do not always
131 // emit parentheses like WGSL is some Lisp dialect.
132 const bool emitParens = true;
133
134 if (emitParens)
135 {
136 mSink << "(";
137 }
138
139 node.traverse(this);
140
141 if (emitParens)
142 {
143 mSink << ")";
144 }
145 }
146
147 // Can be used with TSymbol or TField. Must have a .name() and a .symbolType().
148 template <typename T>
emitNameOf(const T & namedObject)149 void OutputWGSLTraverser::emitNameOf(const T &namedObject)
150 {
151 switch (namedObject.symbolType())
152 {
153 case SymbolType::BuiltIn:
154 {
155 mSink << namedObject.name();
156 }
157 break;
158 case SymbolType::UserDefined:
159 {
160 mSink << kUserDefinedNamePrefix << namedObject.name();
161 }
162 break;
163 case SymbolType::AngleInternal:
164 case SymbolType::Empty:
165 // TODO(anglebug.com/8662): support these if necessary
166 UNREACHABLE();
167 }
168 }
169
emitIndentation()170 void OutputWGSLTraverser::emitIndentation()
171 {
172 ASSERT(mIndentLevel >= 0);
173
174 if (mLastIndentationPos == mSink.size())
175 {
176 return; // Line is already indented.
177 }
178
179 for (int i = 0; i < mIndentLevel; ++i)
180 {
181 mSink << " ";
182 }
183
184 mLastIndentationPos = mSink.size();
185 }
186
emitOpenBrace()187 void OutputWGSLTraverser::emitOpenBrace()
188 {
189 ASSERT(mIndentLevel >= 0);
190
191 emitIndentation();
192 mSink << "{\n";
193 ++mIndentLevel;
194 }
195
emitCloseBrace()196 void OutputWGSLTraverser::emitCloseBrace()
197 {
198 ASSERT(mIndentLevel >= 1);
199
200 --mIndentLevel;
201 emitIndentation();
202 mSink << "}";
203 }
204
visitSymbol(TIntermSymbol * symbolNode)205 void OutputWGSLTraverser::visitSymbol(TIntermSymbol *symbolNode)
206 {
207
208 const TVariable &var = symbolNode->variable();
209 const TType &type = var.getType();
210 ASSERT(var.symbolType() != SymbolType::Empty);
211
212 if (type.getBasicType() == TBasicType::EbtVoid)
213 {
214 UNREACHABLE();
215 }
216 else
217 {
218 emitNameOf(var);
219 }
220 }
221
visitConstantUnion(TIntermConstantUnion * constValueNode)222 void OutputWGSLTraverser::visitConstantUnion(TIntermConstantUnion *constValueNode)
223 {
224 // TODO(anglebug.com/8662): support emitting constants..
225 }
226
visitSwizzle(Visit,TIntermSwizzle * swizzleNode)227 bool OutputWGSLTraverser::visitSwizzle(Visit, TIntermSwizzle *swizzleNode)
228 {
229 // TODO(anglebug.com/8662): support swizzle statements.
230 return false;
231 }
232
visitBinary(Visit,TIntermBinary * binaryNode)233 bool OutputWGSLTraverser::visitBinary(Visit, TIntermBinary *binaryNode)
234 {
235 // TODO(anglebug.com/8662): support binary statements.
236 return false;
237 }
238
visitUnary(Visit,TIntermUnary * unaryNode)239 bool OutputWGSLTraverser::visitUnary(Visit, TIntermUnary *unaryNode)
240 {
241 // TODO(anglebug.com/8662): support unary statements.
242 return false;
243 }
244
visitTernary(Visit,TIntermTernary * conditionalNode)245 bool OutputWGSLTraverser::visitTernary(Visit, TIntermTernary *conditionalNode)
246 {
247 // TODO(anglebug.com/8662): support ternaries.
248 return false;
249 }
250
visitIfElse(Visit,TIntermIfElse * ifThenElseNode)251 bool OutputWGSLTraverser::visitIfElse(Visit, TIntermIfElse *ifThenElseNode)
252 {
253 // TODO(anglebug.com/8662): support basic control flow.
254 return false;
255 }
256
visitSwitch(Visit,TIntermSwitch * switchNode)257 bool OutputWGSLTraverser::visitSwitch(Visit, TIntermSwitch *switchNode)
258 {
259 // TODO(anglebug.com/8662): support switch statements.
260 return false;
261 }
262
visitCase(Visit,TIntermCase * caseNode)263 bool OutputWGSLTraverser::visitCase(Visit, TIntermCase *caseNode)
264 {
265 // TODO(anglebug.com/8662): support switch statements.
266 return false;
267 }
268
emitFunctionReturn(const TFunction & func)269 void OutputWGSLTraverser::emitFunctionReturn(const TFunction &func)
270 {
271 const TType &returnType = func.getReturnType();
272 if (returnType.getBasicType() == EbtVoid)
273 {
274 return;
275 }
276 mSink << " -> ";
277 emitType(returnType);
278 }
279
280 // TODO(anglebug.com/42267100): Function overloads are not supported in WGSL, so function names
281 // should either be emitted mangled or overloaded functions should be renamed in the AST as a
282 // pre-pass. As of Apr 2024, WGSL function overloads are "not coming soon"
283 // (https://github.com/gpuweb/gpuweb/issues/876).
emitFunctionSignature(const TFunction & func)284 void OutputWGSLTraverser::emitFunctionSignature(const TFunction &func)
285 {
286 // TODO(anglebug.com/42267100): main functions should be renamed and labeled with @vertex or
287 // @fragment.
288 mSink << "fn ";
289
290 emitNameOf(func);
291 mSink << "(";
292
293 bool emitComma = false;
294 const size_t paramCount = func.getParamCount();
295 for (size_t i = 0; i < paramCount; ++i)
296 {
297 if (emitComma)
298 {
299 mSink << ", ";
300 }
301 emitComma = true;
302
303 const TVariable ¶m = *func.getParam(i);
304 emitFunctionParameter(func, param);
305 }
306
307 mSink << ")";
308
309 emitFunctionReturn(func);
310 }
311
emitFunctionParameter(const TFunction & func,const TVariable & param)312 void OutputWGSLTraverser::emitFunctionParameter(const TFunction &func, const TVariable ¶m)
313 {
314 // TODO(anglebug.com/8662): actually emit function parameters.
315
316 mSink << "FAKE_FUNCTION_PARAMETER";
317 }
318
visitFunctionPrototype(TIntermFunctionPrototype * funcProtoNode)319 void OutputWGSLTraverser::visitFunctionPrototype(TIntermFunctionPrototype *funcProtoNode)
320 {
321 const TFunction &func = *funcProtoNode->getFunction();
322
323 emitIndentation();
324 emitFunctionSignature(func);
325 }
326
visitFunctionDefinition(Visit,TIntermFunctionDefinition * funcDefNode)327 bool OutputWGSLTraverser::visitFunctionDefinition(Visit, TIntermFunctionDefinition *funcDefNode)
328 {
329 const TFunction &func = *funcDefNode->getFunction();
330 TIntermBlock &body = *funcDefNode->getBody();
331 emitIndentation();
332 emitFunctionSignature(func);
333 mSink << "\n";
334 body.traverse(this);
335 return false;
336 }
337
visitAggregate(Visit,TIntermAggregate * aggregateNode)338 bool OutputWGSLTraverser::visitAggregate(Visit, TIntermAggregate *aggregateNode)
339 {
340 // TODO(anglebug.com/8662): support aggregate statements.
341 return false;
342 }
343
visitBlock(Visit,TIntermBlock * blockNode)344 bool OutputWGSLTraverser::visitBlock(Visit, TIntermBlock *blockNode)
345 {
346 ASSERT(mIndentLevel >= -1);
347 const bool isGlobalScope = mIndentLevel == -1;
348
349 if (isGlobalScope)
350 {
351 ++mIndentLevel;
352 }
353 else
354 {
355 emitOpenBrace();
356 }
357
358 TIntermNode *prevStmtNode = nullptr;
359
360 const size_t stmtCount = blockNode->getChildCount();
361 for (size_t i = 0; i < stmtCount; ++i)
362 {
363 TIntermNode &stmtNode = *blockNode->getChildNode(i);
364
365 if (isGlobalScope && prevStmtNode && (NewlinePad(*prevStmtNode) || NewlinePad(stmtNode)))
366 {
367 mSink << "\n";
368 }
369 const bool isCase = stmtNode.getAsCaseNode();
370 mIndentLevel -= isCase;
371 emitIndentation();
372 mIndentLevel += isCase;
373 stmtNode.traverse(this);
374 if (RequiresSemicolonTerminator(stmtNode))
375 {
376 mSink << ";";
377 }
378 mSink << "\n";
379
380 prevStmtNode = &stmtNode;
381 }
382
383 if (isGlobalScope)
384 {
385 ASSERT(mIndentLevel == 0);
386 --mIndentLevel;
387 }
388 else
389 {
390 emitCloseBrace();
391 }
392
393 return false;
394 }
395
visitGlobalQualifierDeclaration(Visit,TIntermGlobalQualifierDeclaration *)396 bool OutputWGSLTraverser::visitGlobalQualifierDeclaration(Visit,
397 TIntermGlobalQualifierDeclaration *)
398 {
399 return false;
400 }
401
visitDeclaration(Visit,TIntermDeclaration * declNode)402 bool OutputWGSLTraverser::visitDeclaration(Visit, TIntermDeclaration *declNode)
403 {
404 // TODO(anglebug.com/8662): support variable declarations.
405 mSink << "FAKE_DECLARATION";
406 return false;
407 }
408
visitLoop(Visit,TIntermLoop * loopNode)409 bool OutputWGSLTraverser::visitLoop(Visit, TIntermLoop *loopNode)
410 {
411 // TODO(anglebug.com/8662): emit loops.
412 return false;
413 }
414
visitBranch(Visit,TIntermBranch * branchNode)415 bool OutputWGSLTraverser::visitBranch(Visit, TIntermBranch *branchNode)
416 {
417 // TODO(anglebug.com/8662): emit branch instructions.
418 return false;
419 }
420
visitPreprocessorDirective(TIntermPreprocessorDirective * node)421 void OutputWGSLTraverser::visitPreprocessorDirective(TIntermPreprocessorDirective *node)
422 {
423 // No preprocessor directives expected at this point.
424 UNREACHABLE();
425 }
426
emitBareTypeName(const TType & type)427 void OutputWGSLTraverser::emitBareTypeName(const TType &type)
428 {
429 const TBasicType basicType = type.getBasicType();
430
431 switch (basicType)
432 {
433 case TBasicType::EbtVoid:
434 case TBasicType::EbtBool:
435 mSink << type.getBasicString();
436 break;
437 // TODO(anglebug.com/8662): is there double precision (f64) in GLSL? It doesn't really exist
438 // in WGSL (i.e. f64 does not exist but AbstractFloat can handle 64 bits???) Metal does not
439 // have 64 bit double precision types. It's being implemented in WGPU:
440 // https://github.com/gpuweb/gpuweb/issues/2805
441 case TBasicType::EbtFloat:
442 mSink << "f32";
443 break;
444 case TBasicType::EbtInt:
445 mSink << "i32";
446 break;
447 case TBasicType::EbtUInt:
448 mSink << "u32";
449 break;
450
451 case TBasicType::EbtStruct:
452 emitNameOf(*type.getStruct());
453 break;
454
455 case TBasicType::EbtInterfaceBlock:
456 emitNameOf(*type.getInterfaceBlock());
457 break;
458
459 default:
460 if (IsSampler(basicType))
461 {
462 // TODO(anglebug.com/8662): possibly emit both a sampler and a texture2d. WGSL has
463 // sampler variables for the sampler configuration, whereas GLSL has sampler2d and
464 // other sampler* variables for an actual texture.
465 mSink << "texture2d<";
466 switch (type.getBasicType())
467 {
468 case EbtSampler2D:
469 mSink << "f32";
470 break;
471 case EbtISampler2D:
472 mSink << "i32";
473 break;
474 case EbtUSampler2D:
475 mSink << "u32";
476 break;
477 default:
478 // TODO(anglebug.com/8662): are any of the other sampler types necessary to
479 // translate?
480 UNIMPLEMENTED();
481 break;
482 }
483 if (type.getMemoryQualifier().readonly || type.getMemoryQualifier().writeonly)
484 {
485 // TODO(anglebug.com/8662): implement memory qualifiers.
486 UNIMPLEMENTED();
487 }
488 mSink << ">";
489 }
490 else if (IsImage(basicType))
491 {
492 // TODO(anglebug.com/8662): does texture2d also correspond to GLSL's image type?
493 mSink << "texture2d<";
494 switch (type.getBasicType())
495 {
496 case EbtImage2D:
497 mSink << "f32";
498 break;
499 case EbtIImage2D:
500 mSink << "i32";
501 break;
502 case EbtUImage2D:
503 mSink << "u32";
504 break;
505 default:
506 // TODO(anglebug.com/8662): are any of the other image types necessary to
507 // translate?
508 UNIMPLEMENTED();
509 break;
510 }
511 if (type.getMemoryQualifier().readonly || type.getMemoryQualifier().writeonly)
512 {
513 // TODO(anglebug.com/8662): implement memory qualifiers.
514 UNREACHABLE();
515 }
516 mSink << ">";
517 }
518 else
519 {
520 UNREACHABLE();
521 }
522 break;
523 }
524 }
525
emitType(const TType & type)526 void OutputWGSLTraverser::emitType(const TType &type)
527 {
528 // TODO(anglebug.com/8662): support types with dimensions.
529 ASSERT(!type.isVector() && !type.isMatrix() && !type.isArray());
530
531 // This type has no dimensions and is equivalent to its bare type.
532 emitBareTypeName(type);
533 }
534
535 } // namespace
536
TranslatorWGSL(sh::GLenum type,ShShaderSpec spec,ShShaderOutput output)537 TranslatorWGSL::TranslatorWGSL(sh::GLenum type, ShShaderSpec spec, ShShaderOutput output)
538 : TCompiler(type, spec, output)
539 {}
540
translate(TIntermBlock * root,const ShCompileOptions & compileOptions,PerformanceDiagnostics * perfDiagnostics)541 bool TranslatorWGSL::translate(TIntermBlock *root,
542 const ShCompileOptions &compileOptions,
543 PerformanceDiagnostics *perfDiagnostics)
544 {
545 // TODO(anglebug.com/8662): until the translator is ready to translate most basic shaders, emit
546 // the code commented out.
547 TInfoSinkBase &sink = getInfoSink().obj;
548 sink << "/*\n";
549 OutputWGSLTraverser traverser(this);
550 root->traverse(&traverser);
551 sink << "*/\n";
552
553 std::cout << getInfoSink().obj.str();
554
555 // TODO(anglebug.com/8662): delete this.
556 if (getShaderType() == GL_VERTEX_SHADER)
557 {
558 constexpr const char *kVertexShader = R"(@vertex
559 fn main(@builtin(vertex_index) vertex_index : u32) -> @builtin(position) vec4f
560 {
561 const pos = array(
562 vec2( 0.0, 0.5),
563 vec2(-0.5, -0.5),
564 vec2( 0.5, -0.5)
565 );
566
567 return vec4f(pos[vertex_index % 3], 0, 1);
568 })";
569 sink << kVertexShader;
570 }
571 else if (getShaderType() == GL_FRAGMENT_SHADER)
572 {
573 constexpr const char *kFragmentShader = R"(@fragment
574 fn main() -> @location(0) vec4f
575 {
576 return vec4(1, 0, 0, 1);
577 })";
578 sink << kFragmentShader;
579 }
580 else
581 {
582 UNREACHABLE();
583 return false;
584 }
585
586 return true;
587 }
588
shouldFlattenPragmaStdglInvariantAll()589 bool TranslatorWGSL::shouldFlattenPragmaStdglInvariantAll()
590 {
591 // Not neccesary for WGSL transformation.
592 return false;
593 }
594 } // namespace sh
595