• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
#include <cctype>
#include <map>
#include <unordered_map>

#include "common/system_utils.h"
#include "compiler/translator/ImmutableStringBuilder.h"
#include "compiler/translator/SymbolTable.h"
#include "compiler/translator/TranslatorMetalDirect.h"
#include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
#include "compiler/translator/TranslatorMetalDirect/DebugSink.h"
#include "compiler/translator/TranslatorMetalDirect/EmitMetal.h"
#include "compiler/translator/TranslatorMetalDirect/Layout.h"
#include "compiler/translator/TranslatorMetalDirect/Name.h"
#include "compiler/translator/TranslatorMetalDirect/ProgramPrelude.h"
#include "compiler/translator/TranslatorMetalDirect/RewritePipelines.h"
#include "compiler/translator/tree_util/IntermTraverse.h"
22 
23 using namespace sh;
24 
25 ////////////////////////////////////////////////////////////////////////////////
26 
27 #if defined(ANGLE_ENABLE_ASSERTS)
28 using Sink = DebugSink;
29 #else
30 using Sink = TInfoSinkBase;
31 #endif
32 
33 ////////////////////////////////////////////////////////////////////////////////
34 
35 namespace
36 {
37 
38 struct VarDecl
39 {
VarDecl__anon16f098020111::VarDecl40     explicit VarDecl(const TVariable &var) : mVariable(&var), mIsField(false) {}
VarDecl__anon16f098020111::VarDecl41     explicit VarDecl(const TField &field) : mField(&field), mIsField(true) {}
42 
variable__anon16f098020111::VarDecl43     ANGLE_INLINE const TVariable &variable() const
44     {
45         ASSERT(isVariable());
46         return *mVariable;
47     }
48 
field__anon16f098020111::VarDecl49     ANGLE_INLINE const TField &field() const
50     {
51         ASSERT(isField());
52         return *mField;
53     }
54 
isVariable__anon16f098020111::VarDecl55     ANGLE_INLINE bool isVariable() const { return !mIsField; }
56 
isField__anon16f098020111::VarDecl57     ANGLE_INLINE bool isField() const { return mIsField; }
58 
type__anon16f098020111::VarDecl59     const TType &type() const { return isField() ? *field().type() : variable().getType(); }
60 
symbolType__anon16f098020111::VarDecl61     SymbolType symbolType() const
62     {
63         return isField() ? field().symbolType() : variable().symbolType();
64     }
65 
66   private:
67     union
68     {
69         const TVariable *mVariable;
70         const TField *mField;
71     };
72     bool mIsField;
73 };
74 
75 class GenMetalTraverser : public TIntermTraverser
76 {
77   public:
78     ~GenMetalTraverser() override;
79 
80     GenMetalTraverser(const TCompiler &compiler,
81                       Sink &out,
82                       IdGen &idGen,
83                       const PipelineStructs &pipelineStructs,
84                       const Invariants &invariants,
85                       SymbolEnv &symbolEnv,
86                       TSymbolTable *symbolTable);
87 
88     void visitSymbol(TIntermSymbol *) override;
89     void visitConstantUnion(TIntermConstantUnion *) override;
90     bool visitSwizzle(Visit, TIntermSwizzle *) override;
91     bool visitBinary(Visit, TIntermBinary *) override;
92     bool visitUnary(Visit, TIntermUnary *) override;
93     bool visitTernary(Visit, TIntermTernary *) override;
94     bool visitIfElse(Visit, TIntermIfElse *) override;
95     bool visitSwitch(Visit, TIntermSwitch *) override;
96     bool visitCase(Visit, TIntermCase *) override;
97     void visitFunctionPrototype(TIntermFunctionPrototype *) override;
98     bool visitFunctionDefinition(Visit, TIntermFunctionDefinition *) override;
99     bool visitAggregate(Visit, TIntermAggregate *) override;
100     bool visitBlock(Visit, TIntermBlock *) override;
101     bool visitGlobalQualifierDeclaration(Visit, TIntermGlobalQualifierDeclaration *) override;
102     bool visitDeclaration(Visit, TIntermDeclaration *) override;
103     bool visitLoop(Visit, TIntermLoop *) override;
104     bool visitForLoop(TIntermLoop *);
105     bool visitWhileLoop(TIntermLoop *);
106     bool visitDoWhileLoop(TIntermLoop *);
107     bool visitBranch(Visit, TIntermBranch *) override;
108 
109   private:
110     using FuncToName = std::map<ImmutableString, Name>;
111     static FuncToName BuildFuncToName();
112 
113     struct EmitVariableDeclarationConfig
114     {
115         bool isParameter                = false;
116         bool isMainParameter            = false;
117         bool emitPostQualifier          = false;
118         bool isPacked                   = false;
119         bool disableStructSpecifier     = false;
120         bool isUBO                      = false;
121         const AddressSpace *isPointer   = nullptr;
122         const AddressSpace *isReference = nullptr;
123     };
124 
125     struct EmitTypeConfig
126     {
127         const EmitVariableDeclarationConfig *evdConfig = nullptr;
128     };
129 
130     void emitIndentation();
131     void emitOpeningPointerParen();
132     void emitClosingPointerParen();
133     void emitFunctionSignature(const TFunction &func);
134     void emitFunctionReturn(const TFunction &func);
135     void emitFunctionParameter(const TFunction &func, const TVariable &param);
136 
137     void emitNameOf(const TField &object);
138     void emitNameOf(const TSymbol &object);
139     void emitNameOf(const VarDecl &object);
140 
141     void emitBareTypeName(const TType &type, const EmitTypeConfig &etConfig);
142     void emitType(const TType &type, const EmitTypeConfig &etConfig);
143     void emitPostQualifier(const EmitVariableDeclarationConfig &evdConfig,
144                            const VarDecl &decl,
145                            const TQualifier qualifier);
146 
147     struct FieldAnnotationIndices
148     {
149         size_t attribute = 0;
150         size_t color     = 0;
151     };
152 
153     void emitFieldDeclaration(const TField &field,
154                               const TStructure &parent,
155                               FieldAnnotationIndices &annotationIndices);
156     void emitAttributeDeclaration(const TField &field, FieldAnnotationIndices &annotationIndices);
157     void emitUniformBufferDeclaration(const TField &field,
158                                       FieldAnnotationIndices &annotationIndices);
159     void emitStructDeclaration(const TType &type);
160     void emitOrdinaryVariableDeclaration(const VarDecl &decl,
161                                          const EmitVariableDeclarationConfig &evdConfig);
162     void emitVariableDeclaration(const VarDecl &decl,
163                                  const EmitVariableDeclarationConfig &evdConfig);
164 
165     void emitOpenBrace();
166     void emitCloseBrace();
167 
168     void groupedTraverse(TIntermNode &node);
169 
170     const TField &getDirectField(const TFieldListCollection &fieldsNode,
171                                  const TConstantUnion &index);
172     const TField &getDirectField(const TIntermTyped &fieldsNode, TIntermTyped &indexNode);
173 
174     const TConstantUnion *emitConstantUnionArray(const TConstantUnion *const constUnion,
175                                                  const size_t size);
176 
177     const TConstantUnion *emitConstantUnion(const TType &type, const TConstantUnion *constUnion);
178 
179     void emitSingleConstant(const TConstantUnion *const constUnion);
180 
181   private:
182     Sink &mOut;
183     const TCompiler &mCompiler;
184     const PipelineStructs &mPipelineStructs;
185     const Invariants &mInvariants;
186     SymbolEnv &mSymbolEnv;
187     IdGen &mIdGen;
188     int mIndentLevel                  = -1;
189     int mLastIndentationPos           = -1;
190     int mOpenPointerParenCount        = 0;
191     bool mParentIsSwitch              = false;
192     bool isTraversingVertexMain       = false;
193     bool mTemporarilyDisableSemicolon = false;
194     std::unordered_map<const TSymbol *, Name> mRenamedSymbols;
195     const FuncToName mFuncToName          = BuildFuncToName();
196     size_t mMainTextureIndex              = 0;
197     size_t mMainSamplerIndex              = 0;
198     size_t mMainUniformBufferIndex        = 0;
199     size_t mDriverUniformsBindingIndex    = 0;
200     size_t mUBOArgumentBufferBindingIndex = 0;
201 };
202 
203 }  // anonymous namespace
204 
~GenMetalTraverser()205 GenMetalTraverser::~GenMetalTraverser()
206 {
207     ASSERT(mIndentLevel == -1);
208     ASSERT(!mParentIsSwitch);
209     ASSERT(mOpenPointerParenCount == 0);
210 }
211 
GenMetalTraverser(const TCompiler & compiler,Sink & out,IdGen & idGen,const PipelineStructs & pipelineStructs,const Invariants & invariants,SymbolEnv & symbolEnv,TSymbolTable * symbolTable)212 GenMetalTraverser::GenMetalTraverser(const TCompiler &compiler,
213                                      Sink &out,
214                                      IdGen &idGen,
215                                      const PipelineStructs &pipelineStructs,
216                                      const Invariants &invariants,
217                                      SymbolEnv &symbolEnv,
218                                      TSymbolTable *symbolTable)
219     : TIntermTraverser(true, false, false),
220       mOut(out),
221       mCompiler(compiler),
222       mPipelineStructs(pipelineStructs),
223       mInvariants(invariants),
224       mSymbolEnv(symbolEnv),
225       mIdGen(idGen),
226       mMainUniformBufferIndex(symbolTable->getDefaultUniformsBindingIndex()),
227       mDriverUniformsBindingIndex(symbolTable->getDriverUniformsBindingIndex()),
228       mUBOArgumentBufferBindingIndex(symbolTable->getUBOArgumentBufferBindingIndex())
229 {}
230 
emitIndentation()231 void GenMetalTraverser::emitIndentation()
232 {
233     ASSERT(mIndentLevel >= 0);
234 
235     if (mLastIndentationPos == mOut.size())
236     {
237         return;  // Line is already indented.
238     }
239 
240     for (int i = 0; i < mIndentLevel; ++i)
241     {
242         mOut << "  ";
243     }
244 
245     mLastIndentationPos = mOut.size();
246 }
247 
emitOpeningPointerParen()248 void GenMetalTraverser::emitOpeningPointerParen()
249 {
250     mOut << "(*";
251     mOpenPointerParenCount++;
252 }
253 
emitClosingPointerParen()254 void GenMetalTraverser::emitClosingPointerParen()
255 {
256     if (mOpenPointerParenCount > 0)
257     {
258         mOut << ")";
259         mOpenPointerParenCount--;
260     }
261 }
262 
GetOperatorString(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1=nullptr)263 static const char *GetOperatorString(TOperator op,
264                                      const TType &resultType,
265                                      const TType *argType0,
266                                      const TType *argType1 = nullptr)
267 {
268     switch (op)
269     {
270         case TOperator::EOpComma:
271             return ",";
272         case TOperator::EOpAssign:
273             return "=";
274         case TOperator::EOpInitialize:
275             return "=";
276         case TOperator::EOpAddAssign:
277             return "+=";
278         case TOperator::EOpSubAssign:
279             return "-=";
280         case TOperator::EOpMulAssign:
281             return "*=";
282         case TOperator::EOpDivAssign:
283             return "/=";
284         case TOperator::EOpIModAssign:
285             return "%=";
286         case TOperator::EOpBitShiftLeftAssign:
287             return "<<=";  // TODO: Check logical vs arithmetic shifting.
288         case TOperator::EOpBitShiftRightAssign:
289             return ">>=";  // TODO: Check logical vs arithmetic shifting.
290         case TOperator::EOpBitwiseAndAssign:
291             return "&=";
292         case TOperator::EOpBitwiseXorAssign:
293             return "^=";
294         case TOperator::EOpBitwiseOrAssign:
295             return "|=";
296         case TOperator::EOpAdd:
297             return "+";
298         case TOperator::EOpSub:
299             return "-";
300         case TOperator::EOpMul:
301             return "*";
302         case TOperator::EOpDiv:
303             return "/";
304         case TOperator::EOpIMod:
305             return "%";
306         case TOperator::EOpBitShiftLeft:
307             return "<<";  // TODO: Check logical vs arithmetic shifting.
308         case TOperator::EOpBitShiftRight:
309             return ">>";  // TODO: Check logical vs arithmetic shifting.
310         case TOperator::EOpBitwiseAnd:
311             return "&";
312         case TOperator::EOpBitwiseXor:
313             return "^";
314         case TOperator::EOpBitwiseOr:
315             return "|";
316         case TOperator::EOpLessThan:
317             return "<";
318         case TOperator::EOpGreaterThan:
319             return ">";
320         case TOperator::EOpLessThanEqual:
321             return "<=";
322         case TOperator::EOpGreaterThanEqual:
323             return ">=";
324         case TOperator::EOpLessThanComponentWise:
325             return "<";
326         case TOperator::EOpLessThanEqualComponentWise:
327             return "<=";
328         case TOperator::EOpGreaterThanEqualComponentWise:
329             return ">=";
330         case TOperator::EOpGreaterThanComponentWise:
331             return ">";
332         case TOperator::EOpLogicalOr:
333             return "||";
334         case TOperator::EOpLogicalXor:
335             return "!=/*xor*/";  // XXX: This might need to be handled differently for some obtuse
336                                  // use case.
337         case TOperator::EOpLogicalAnd:
338             return "&&";
339         case TOperator::EOpNegative:
340             return "-";
341         case TOperator::EOpPositive:
342             if (argType0->isMatrix())
343             {
344                 return "";
345             }
346             return "+";
347         case TOperator::EOpLogicalNot:
348             return "!";
349         case TOperator::EOpNotComponentWise:
350             return "!";
351         case TOperator::EOpBitwiseNot:
352             return "~";
353         case TOperator::EOpPostIncrement:
354             return "++";
355         case TOperator::EOpPostDecrement:
356             return "--";
357         case TOperator::EOpPreIncrement:
358             return "++";
359         case TOperator::EOpPreDecrement:
360             return "--";
361         case TOperator::EOpVectorTimesScalarAssign:
362             return "*=";
363         case TOperator::EOpVectorTimesMatrixAssign:
364             return "*=";
365         case TOperator::EOpMatrixTimesScalarAssign:
366             return "*=";
367         case TOperator::EOpMatrixTimesMatrixAssign:
368             return "*=";
369         case TOperator::EOpVectorTimesScalar:
370             return "*";
371         case TOperator::EOpVectorTimesMatrix:
372             return "*";
373         case TOperator::EOpMatrixTimesVector:
374             return "*";
375         case TOperator::EOpMatrixTimesScalar:
376             return "*";
377         case TOperator::EOpMatrixTimesMatrix:
378             return "*";
379         case TOperator::EOpEqualComponentWise:
380             return "==";
381         case TOperator::EOpNotEqualComponentWise:
382             return "!=";
383 
384         case TOperator::EOpEqual:
385             if ((argType0->getStruct() && argType1->getStruct()) &&
386                 (argType0->isArray() && argType1->isArray()))
387             {
388                 return "ANGLE_equalStructArray";
389             }
390 
391             if ((argType0->isVector() && argType1->isVector()) ||
392                 (argType0->getStruct() && argType1->getStruct()) ||
393                 (argType0->isArray() && argType1->isArray()) ||
394                 (argType0->isMatrix() && argType1->isMatrix()))
395 
396             {
397                 return "ANGLE_equal";
398             }
399 
400             return "==";
401 
402         case TOperator::EOpNotEqual:
403             if ((argType0->getStruct() && argType1->getStruct()) &&
404                 (argType0->isArray() && argType1->isArray()))
405             {
406                 return "ANGLE_notEqualStructArray";
407             }
408 
409             if ((argType0->isVector() && argType1->isVector()) ||
410                 (argType0->isArray() && argType1->isArray()) ||
411                 (argType0->isMatrix() && argType1->isMatrix()))
412             {
413                 return "ANGLE_notEqual";
414             }
415             else if (argType0->getStruct() && argType1->getStruct())
416             {
417                 return "ANGLE_notEqualStruct";
418             }
419             return "!=";
420 
421         case TOperator::EOpKill:
422             UNIMPLEMENTED();
423             return "kill";
424         case TOperator::EOpReturn:
425             return "return";
426         case TOperator::EOpBreak:
427             return "break";
428         case TOperator::EOpContinue:
429             return "continue";
430 
431         case TOperator::EOpRadians:
432             return "ANGLE_radians";
433         case TOperator::EOpDegrees:
434             return "ANGLE_degrees";
435         case TOperator::EOpAtan:
436             return "ANGLE_atan";
437         case TOperator::EOpMod:
438             return "ANGLE_mod";  // differs from metal::mod
439         case TOperator::EOpRefract:
440             return "ANGLE_refract";
441         case TOperator::EOpDistance:
442             return "ANGLE_distance";
443         case TOperator::EOpLength:
444             return "ANGLE_length";
445         case TOperator::EOpDot:
446             return "ANGLE_dot";
447         case TOperator::EOpNormalize:
448             return "ANGLE_normalize";
449         case TOperator::EOpFaceforward:
450             return "ANGLE_faceforward";
451         case TOperator::EOpReflect:
452             return "ANGLE_reflect";
453         case TOperator::EOpMatrixCompMult:
454             return "ANGLE_componentWiseMultiply";
455         case TOperator::EOpOuterProduct:
456             return "ANGLE_outerProduct";
457         case TOperator::EOpSign:
458             return "ANGLE_sign";
459 
460         case TOperator::EOpAbs:
461             return "metal::abs";
462         case TOperator::EOpAll:
463             return "metal::all";
464         case TOperator::EOpAny:
465             return "metal::any";
466         case TOperator::EOpSin:
467             return "metal::sin";
468         case TOperator::EOpCos:
469             return "metal::cos";
470         case TOperator::EOpTan:
471             return "metal::tan";
472         case TOperator::EOpAsin:
473             return "metal::asin";
474         case TOperator::EOpAcos:
475             return "metal::acos";
476         case TOperator::EOpSinh:
477             return "metal::sinh";
478         case TOperator::EOpCosh:
479             return "metal::cosh";
480         case TOperator::EOpTanh:
481             return "metal::tanh";
482         case TOperator::EOpAsinh:
483             return "metal::asinh";
484         case TOperator::EOpAcosh:
485             return "metal::acosh";
486         case TOperator::EOpAtanh:
487             return "metal::atanh";
488         case TOperator::EOpFma:
489             return "metal::fma";
490         case TOperator::EOpPow:
491             return "metal::pow";
492         case TOperator::EOpExp:
493             return "metal::exp";
494         case TOperator::EOpExp2:
495             return "metal::exp2";
496         case TOperator::EOpLog:
497             return "metal::log";
498         case TOperator::EOpLog2:
499             return "metal::log2";
500         case TOperator::EOpSqrt:
501             return "metal::sqrt";
502         case TOperator::EOpFloor:
503             return "metal::floor";
504         case TOperator::EOpTrunc:
505             return "metal::trunc";
506         case TOperator::EOpCeil:
507             return "metal::ceil";
508         case TOperator::EOpFract:
509             return "metal::fract";
510         case TOperator::EOpMin:
511             return "metal::min";
512         case TOperator::EOpMax:
513             return "metal::max";
514         case TOperator::EOpRound:
515             return "metal::round";
516         case TOperator::EOpRoundEven:
517             return "metal::rint";
518         case TOperator::EOpClamp:
519             return "metal::clamp";  // TODO fast vs precise namespace
520         case TOperator::EOpMix:
521             return "metal::mix";
522         case TOperator::EOpStep:
523             return "metal::step";
524         case TOperator::EOpSmoothstep:
525             return "metal::smoothstep";
526         case TOperator::EOpModf:
527             return "metal::modf";
528         case TOperator::EOpIsnan:
529             return "metal::isnan";
530         case TOperator::EOpIsinf:
531             return "metal::isinf";
532         case TOperator::EOpLdexp:
533             return "metal::ldexp";
534         case TOperator::EOpFrexp:
535             return "metal::frexp";
536         case TOperator::EOpInversesqrt:
537             return "metal::rsqrt";
538         case TOperator::EOpCross:
539             return "metal::cross";
540         case TOperator::EOpDFdx:
541             return "metal::dfdx";
542         case TOperator::EOpDFdy:
543             return "metal::dfdy";
544         case TOperator::EOpFwidth:
545             return "metal::fwidth";
546         case TOperator::EOpTranspose:
547             return "metal::transpose";
548         case TOperator::EOpDeterminant:
549             return "metal::determinant";
550 
551         case TOperator::EOpInverse:
552             return "ANGLE_inverse";
553 
554         case TOperator::EOpFloatBitsToInt:
555         case TOperator::EOpFloatBitsToUint:
556         case TOperator::EOpIntBitsToFloat:
557         case TOperator::EOpUintBitsToFloat:
558         {
559 #define RETURN_AS_TYPE(post)                     \
560     do                                           \
561         switch (resultType.getBasicType())       \
562         {                                        \
563             case TBasicType::EbtInt:             \
564                 return "as_type<int" post ">";   \
565             case TBasicType::EbtUInt:            \
566                 return "as_type<uint" post ">";  \
567             case TBasicType::EbtFloat:           \
568                 return "as_type<float" post ">"; \
569             default:                             \
570                 UNIMPLEMENTED();                 \
571                 return "TOperator_TODO";         \
572         }                                        \
573     while (false)
574 
575             if (resultType.isScalar())
576             {
577                 RETURN_AS_TYPE("");
578             }
579             else if (resultType.isVector())
580             {
581                 switch (resultType.getNominalSize())
582                 {
583                     case 2:
584                         RETURN_AS_TYPE("2");
585                     case 3:
586                         RETURN_AS_TYPE("3");
587                     case 4:
588                         RETURN_AS_TYPE("4");
589                     default:
590                         UNREACHABLE();
591                         return nullptr;
592                 }
593             }
594             else
595             {
596                 UNIMPLEMENTED();
597                 return "TOperator_TODO";
598             }
599 
600 #undef RETURN_AS_TYPE
601         }
602 
603         case TOperator::EOpPackUnorm2x16:
604             return "metal::pack_float_to_unorm2x16";
605         case TOperator::EOpPackSnorm2x16:
606             return "metal::pack_float_to_snorm2x16";
607 
608         case TOperator::EOpPackUnorm4x8:
609             return "metal::pack_float_to_unorm4x8";
610         case TOperator::EOpPackSnorm4x8:
611             return "metal::pack_float_to_snorm4x8";
612 
613         case TOperator::EOpUnpackUnorm2x16:
614             return "metal::unpack_unorm2x16_to_float";
615         case TOperator::EOpUnpackSnorm2x16:
616             return "metal::unpack_snorm2x16_to_float";
617 
618         case TOperator::EOpUnpackUnorm4x8:
619             return "metal::unpack_unorm4x8_to_float";
620         case TOperator::EOpUnpackSnorm4x8:
621             return "metal::unpack_snorm4x8_to_float";
622 
623         case TOperator::EOpPackHalf2x16:
624             return "ANGLE_pack_half_2x16";
625         case TOperator::EOpUnpackHalf2x16:
626             return "ANGLE_unpack_half_2x16";
627 
628         case TOperator::EOpBitfieldExtract:
629         case TOperator::EOpBitfieldInsert:
630         case TOperator::EOpBitfieldReverse:
631         case TOperator::EOpBitCount:
632         case TOperator::EOpFindLSB:
633         case TOperator::EOpFindMSB:
634         case TOperator::EOpUaddCarry:
635         case TOperator::EOpUsubBorrow:
636         case TOperator::EOpUmulExtended:
637         case TOperator::EOpImulExtended:
638         case TOperator::EOpBarrier:
639         case TOperator::EOpMemoryBarrier:
640         case TOperator::EOpMemoryBarrierAtomicCounter:
641         case TOperator::EOpMemoryBarrierBuffer:
642         case TOperator::EOpMemoryBarrierImage:
643         case TOperator::EOpMemoryBarrierShared:
644         case TOperator::EOpGroupMemoryBarrier:
645         case TOperator::EOpAtomicAdd:
646         case TOperator::EOpAtomicMin:
647         case TOperator::EOpAtomicMax:
648         case TOperator::EOpAtomicAnd:
649         case TOperator::EOpAtomicOr:
650         case TOperator::EOpAtomicXor:
651         case TOperator::EOpAtomicExchange:
652         case TOperator::EOpAtomicCompSwap:
653         case TOperator::EOpEmitVertex:
654         case TOperator::EOpEndPrimitive:
655         case TOperator::EOpFtransform:
656         case TOperator::EOpPackDouble2x32:
657         case TOperator::EOpUnpackDouble2x32:
658         case TOperator::EOpArrayLength:
659             UNIMPLEMENTED();
660             return "TOperator_TODO";
661 
662         case TOperator::EOpNull:
663         case TOperator::EOpConstruct:
664         case TOperator::EOpCallFunctionInAST:
665         case TOperator::EOpCallInternalRawFunction:
666         case TOperator::EOpIndexDirect:
667         case TOperator::EOpIndexIndirect:
668         case TOperator::EOpIndexDirectStruct:
669         case TOperator::EOpIndexDirectInterfaceBlock:
670             UNREACHABLE();
671             return nullptr;
672         default:
673             // Any other built-in function.
674             return nullptr;
675     }
676 }
677 
IsSymbolicOperator(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1=nullptr)678 static bool IsSymbolicOperator(TOperator op,
679                                const TType &resultType,
680                                const TType *argType0,
681                                const TType *argType1 = nullptr)
682 {
683     const char *operatorString = GetOperatorString(op, resultType, argType0, argType1);
684     if (operatorString == nullptr)
685     {
686         return false;
687     }
688     return !std::isalnum(operatorString[0]);
689 }
690 
AsSpecificBinaryNode(TIntermNode & node,TOperator op)691 static TIntermBinary *AsSpecificBinaryNode(TIntermNode &node, TOperator op)
692 {
693     TIntermBinary *binaryNode = node.getAsBinaryNode();
694     if (binaryNode)
695     {
696         return binaryNode->getOp() == op ? binaryNode : nullptr;
697     }
698     return nullptr;
699 }
700 
Parenthesize(TIntermNode & node)701 static bool Parenthesize(TIntermNode &node)
702 {
703     if (node.getAsSymbolNode())
704     {
705         return false;
706     }
707     if (node.getAsConstantUnion())
708     {
709         return false;
710     }
711     if (node.getAsAggregate())
712     {
713         return false;
714     }
715     if (node.getAsSwizzleNode())
716     {
717         return false;
718     }
719 
720     if (TIntermUnary *unaryNode = node.getAsUnaryNode())
721     {
722         // TODO: Use a precedence and associativity rules instead of this ad-hoc impl.
723         const TType &resultType = unaryNode->getType();
724         const TType &argType    = unaryNode->getOperand()->getType();
725         return IsSymbolicOperator(unaryNode->getOp(), resultType, &argType);
726     }
727 
728     if (TIntermBinary *binaryNode = node.getAsBinaryNode())
729     {
730         // TODO: Use a precedence and associativity rules instead of this ad-hoc impl.
731         const TOperator op = binaryNode->getOp();
732         switch (op)
733         {
734             case TOperator::EOpIndexDirectStruct:
735             case TOperator::EOpIndexDirectInterfaceBlock:
736             case TOperator::EOpIndexDirect:
737             case TOperator::EOpIndexIndirect:
738                 return Parenthesize(*binaryNode->getLeft());
739 
740             case TOperator::EOpAssign:
741             case TOperator::EOpInitialize:
742                 return AsSpecificBinaryNode(*binaryNode->getRight(), TOperator::EOpComma);
743 
744             default:
745             {
746                 const TType &resultType = binaryNode->getType();
747                 const TType &leftType   = binaryNode->getLeft()->getType();
748                 const TType &rightType  = binaryNode->getRight()->getType();
749                 return IsSymbolicOperator(binaryNode->getOp(), resultType, &leftType, &rightType);
750             }
751         }
752     }
753 
754     return true;
755 }
756 
groupedTraverse(TIntermNode & node)757 void GenMetalTraverser::groupedTraverse(TIntermNode &node)
758 {
759     const bool emitParens = Parenthesize(node);
760 
761     if (emitParens)
762     {
763         mOut << "(";
764     }
765 
766     node.traverse(this);
767 
768     if (emitParens)
769     {
770         mOut << ")";
771     }
772 }
773 
emitPostQualifier(const EmitVariableDeclarationConfig & evdConfig,const VarDecl & decl,const TQualifier qualifier)774 void GenMetalTraverser::emitPostQualifier(const EmitVariableDeclarationConfig &evdConfig,
775                                           const VarDecl &decl,
776                                           const TQualifier qualifier)
777 {
778     switch (qualifier)
779     {
780         case TQualifier::EvqPosition:
781         case TQualifier::EvqFragCoord:
782             mOut << " [[position]]";
783             break;
784 
785         case TQualifier::EvqPointSize:
786             mOut << " [[point_size]]";
787             break;
788 
789         case TQualifier::EvqVertexID:
790             if (evdConfig.isMainParameter)
791             {
792                 mOut << " [[vertex_id]]";
793             }
794             break;
795 
796         case TQualifier::EvqPointCoord:
797             if (evdConfig.isMainParameter)
798             {
799                 mOut << " [[point_coord]]";
800             }
801             break;
802 
803         case TQualifier::EvqFrontFacing:
804             if (evdConfig.isMainParameter)
805             {
806                 mOut << " [[front_facing]]";
807             }
808             break;
809 
810         default:
811             break;
812     }
813 
814     const bool isInvariant =
815         (decl.isField() ? mInvariants.contains(decl.field())
816                         : mInvariants.contains(decl.variable())) &&
817         (qualifier == TQualifier::EvqPosition || qualifier == TQualifier::EvqFragCoord);
818 
819     if (isInvariant)
820     {
821         mOut << " [[invariant]]";
822     }
823 }
824 
EmitName(Sink & out,const Name & name)825 static void EmitName(Sink &out, const Name &name)
826 {
827 #if defined(ANGLE_ENABLE_ASSERTS)
828     DebugSink::EscapedSink escapedOut(out.escape());
829 #else
830     TInfoSinkBase &escapedOut = out;
831 #endif
832     name.emit(escapedOut);
833 }
834 
emitNameOf(const TField & object)835 void GenMetalTraverser::emitNameOf(const TField &object)
836 {
837     EmitName(mOut, Name(object));
838 }
839 
emitNameOf(const TSymbol & object)840 void GenMetalTraverser::emitNameOf(const TSymbol &object)
841 {
842     auto it = mRenamedSymbols.find(&object);
843     if (it == mRenamedSymbols.end())
844     {
845         EmitName(mOut, Name(object));
846     }
847     else
848     {
849         EmitName(mOut, it->second);
850     }
851 }
852 
emitNameOf(const VarDecl & object)853 void GenMetalTraverser::emitNameOf(const VarDecl &object)
854 {
855     if (object.isField())
856     {
857         emitNameOf(object.field());
858     }
859     else
860     {
861         emitNameOf(object.variable());
862     }
863 }
864 
emitBareTypeName(const TType & type,const EmitTypeConfig & etConfig)865 void GenMetalTraverser::emitBareTypeName(const TType &type, const EmitTypeConfig &etConfig)
866 {
867     const TBasicType basicType = type.getBasicType();
868 
869     switch (basicType)
870     {
871         case TBasicType::EbtVoid:
872         case TBasicType::EbtBool:
873         case TBasicType::EbtFloat:
874         case TBasicType::EbtInt:
875         case TBasicType::EbtUInt:
876         {
877             mOut << type.getBasicString();
878         }
879         break;
880 
881         case TBasicType::EbtStruct:
882         {
883             const TStructure &structure = *type.getStruct();
884             emitNameOf(structure);
885         }
886         break;
887 
888         case TBasicType::EbtInterfaceBlock:
889         {
890             const TInterfaceBlock &interfaceBlock = *type.getInterfaceBlock();
891             emitNameOf(interfaceBlock);
892         }
893         break;
894 
895         default:
896         {
897             if (IsSampler(basicType))
898             {
899                 if (etConfig.evdConfig && etConfig.evdConfig->isMainParameter)
900                 {
901                     EmitName(mOut, GetTextureTypeName(basicType));
902                 }
903                 else
904                 {
905                     const TStructure &env = mSymbolEnv.getTextureEnv(basicType);
906                     emitNameOf(env);
907                 }
908             }
909             else
910             {
911                 UNIMPLEMENTED();
912             }
913         }
914     }
915 }
916 
emitType(const TType & type,const EmitTypeConfig & etConfig)917 void GenMetalTraverser::emitType(const TType &type, const EmitTypeConfig &etConfig)
918 {
919     const bool isUBO = etConfig.evdConfig ? etConfig.evdConfig->isUBO : false;
920     if (etConfig.evdConfig)
921     {
922         const auto &evdConfig = *etConfig.evdConfig;
923         if (isUBO)
924         {
925             if (type.isArray())
926             {
927                 mOut << "ANGLE_tensor<";
928             }
929         }
930         if (evdConfig.isPointer)
931         {
932             mOut << toString(*evdConfig.isPointer);
933             mOut << " ";
934         }
935         else if (evdConfig.isReference)
936         {
937             mOut << toString(*evdConfig.isReference);
938             mOut << " ";
939         }
940     }
941 
942     if (!isUBO)
943     {
944         if (type.isArray())
945         {
946             mOut << "ANGLE_tensor<";
947         }
948     }
949 
950     if (type.isVector() || type.isMatrix())
951     {
952         mOut << "metal::";
953     }
954 
955     if (etConfig.evdConfig && etConfig.evdConfig->isPacked)
956     {
957         mOut << "packed_";
958     }
959 
960     emitBareTypeName(type, etConfig);
961 
962     if (type.isVector())
963     {
964         mOut << type.getNominalSize();
965     }
966     else if (type.isMatrix())
967     {
968         mOut << type.getCols() << "x" << type.getRows();
969     }
970 
971     if (!isUBO)
972     {
973         if (type.isArray())
974         {
975             for (auto size : type.getArraySizes())
976             {
977                 mOut << ", " << size;
978             }
979             mOut << ">";
980         }
981     }
982 
983     if (etConfig.evdConfig)
984     {
985         const auto &evdConfig = *etConfig.evdConfig;
986         if (evdConfig.isPointer)
987         {
988             mOut << " *";
989         }
990         else if (evdConfig.isReference)
991         {
992             mOut << " &";
993         }
994         if (isUBO)
995         {
996             if (type.isArray())
997             {
998                 for (auto size : type.getArraySizes())
999                 {
1000                     mOut << ", " << size;
1001                 }
1002                 mOut << ">";
1003             }
1004         }
1005     }
1006 }
1007 
emitFieldDeclaration(const TField & field,const TStructure & parent,FieldAnnotationIndices & annotationIndices)1008 void GenMetalTraverser::emitFieldDeclaration(const TField &field,
1009                                              const TStructure &parent,
1010                                              FieldAnnotationIndices &annotationIndices)
1011 {
1012     const TType &type      = *field.type();
1013     const TBasicType basic = type.getBasicType();
1014 
1015     EmitVariableDeclarationConfig evdConfig;
1016     evdConfig.emitPostQualifier      = true;
1017     evdConfig.disableStructSpecifier = true;
1018     evdConfig.isPacked               = mSymbolEnv.isPacked(field);
1019     evdConfig.isUBO                  = mSymbolEnv.isUBO(field);
1020     evdConfig.isPointer              = mSymbolEnv.isPointer(field);
1021     evdConfig.isReference            = mSymbolEnv.isReference(field);
1022     emitVariableDeclaration(VarDecl(field), evdConfig);
1023 
1024     const TQualifier qual = type.getQualifier();
1025     switch (qual)
1026     {
1027         case TQualifier::EvqFlatIn:
1028             if (mPipelineStructs.fragmentIn.external == &parent)
1029             {
1030                 mOut << " [[flat]]";
1031                 TranslatorMetalReflection *reflection =
1032                     ((sh::TranslatorMetalDirect *)&mCompiler)->getTranslatorMetalReflection();
1033                 reflection->hasFlatInput = true;
1034             }
1035             break;
1036 
1037         case TQualifier::EvqFragmentOut:
1038         case TQualifier::EvqFragData:
1039             if (mPipelineStructs.fragmentOut.external == &parent)
1040             {
1041                 if ((type.isVector() &&
1042                      (basic == TBasicType::EbtInt || basic == TBasicType::EbtUInt ||
1043                       basic == TBasicType::EbtFloat)) ||
1044                     type.getQualifier() == EvqFragData)
1045                 {
1046                     // TODO:
1047                     // This is not correct in general and needs a reimplementation.
1048                     // In GLSL 3.0, We can't always assume this is going to be a safe index.
1049                     // See 4.3.8.2
1050                     // https://www.khronos.org/registry/OpenGL/specs/es/3.0/GLSL_ES_Specification_3.00.pdf
1051                     size_t index = annotationIndices.color++;
1052                     mOut << " [[color(" << index << ")]]";
1053                 }
1054             }
1055             break;
1056 
1057         case TQualifier::EvqFragDepth:
1058             mOut << " [[depth(any)]]";
1059             break;
1060 
1061         case TQualifier::EvqSampleMask:
1062             mOut << " [[sample_mask, function_constant(" << sh::mtl::kCoverageMaskEnabledConstName
1063                  << ")]]";
1064             break;
1065 
1066         default:
1067             break;
1068     }
1069 }
1070 
BuildExternalAttributeIndexMap(const TCompiler & compiler,const PipelineScoped<TStructure> & structure)1071 static std::map<Name, size_t> BuildExternalAttributeIndexMap(
1072     const TCompiler &compiler,
1073     const PipelineScoped<TStructure> &structure)
1074 {
1075     ASSERT(structure.isTotallyFull());
1076 
1077     const auto &shaderVars     = compiler.getAttributes();
1078     const size_t shaderVarSize = shaderVars.size();
1079     size_t shaderVarIndex      = 0;
1080 
1081     const auto &externalFields = structure.external->fields();
1082     const size_t externalSize  = externalFields.size();
1083     size_t externalIndex       = 0;
1084 
1085     const auto &internalFields = structure.internal->fields();
1086     const size_t internalSize  = internalFields.size();
1087     size_t internalIndex       = 0;
1088 
1089     // Internal fields are never split. External fields are sometimes split.
1090     ASSERT(externalSize >= internalSize);
1091 
1092     // Structures do not contain any inactive fields.
1093     ASSERT(shaderVarSize >= internalSize);
1094 
1095     std::map<Name, size_t> externalNameToAttributeIndex;
1096     size_t attributeIndex = 0;
1097 
1098     while (internalIndex < internalSize)
1099     {
1100         const TField &internalField = *internalFields[internalIndex];
1101         const Name internalName     = Name(internalField);
1102         const TType &internalType   = *internalField.type();
1103         while (internalName.rawName() != shaderVars[shaderVarIndex].name &&
1104                internalName.rawName() != shaderVars[shaderVarIndex].mappedName)
1105         {
1106             // This case represents an inactive field.
1107 
1108             ++shaderVarIndex;
1109             ASSERT(shaderVarIndex < shaderVarSize);
1110 
1111             ++attributeIndex;  // TODO: Might need to increment more if shader var type is a matrix.
1112         }
1113 
1114         const size_t cols = internalType.isMatrix() ? internalType.getCols() : 1;
1115 
1116         for (size_t c = 0; c < cols; ++c)
1117         {
1118             const TField &externalField = *externalFields[externalIndex];
1119             const Name externalName     = Name(externalField);
1120             ASSERT(!externalField.type()->isMatrix());
1121 
1122             externalNameToAttributeIndex[externalName] = attributeIndex;
1123 
1124             ++externalIndex;
1125             ++attributeIndex;
1126         }
1127 
1128         ++shaderVarIndex;
1129         ++internalIndex;
1130     }
1131 
1132     ASSERT(shaderVarIndex <= shaderVarSize);
1133     ASSERT(externalIndex <= externalSize);  // less than if padding was introduced
1134     ASSERT(internalIndex == internalSize);
1135 
1136     return externalNameToAttributeIndex;
1137 }
1138 
emitAttributeDeclaration(const TField & field,FieldAnnotationIndices & annotationIndices)1139 void GenMetalTraverser::emitAttributeDeclaration(const TField &field,
1140                                                  FieldAnnotationIndices &annotationIndices)
1141 {
1142     EmitVariableDeclarationConfig evdConfig;
1143     evdConfig.disableStructSpecifier = true;
1144     emitVariableDeclaration(VarDecl(field), evdConfig);
1145     mOut << sh::kUnassignedAttributeString;
1146 }
1147 
emitUniformBufferDeclaration(const TField & field,FieldAnnotationIndices & annotationIndices)1148 void GenMetalTraverser::emitUniformBufferDeclaration(const TField &field,
1149                                                      FieldAnnotationIndices &annotationIndices)
1150 {
1151     EmitVariableDeclarationConfig evdConfig;
1152     evdConfig.disableStructSpecifier = true;
1153     evdConfig.isUBO                  = mSymbolEnv.isUBO(field);
1154     evdConfig.isPointer              = mSymbolEnv.isPointer(field);
1155     emitVariableDeclaration(VarDecl(field), evdConfig);
1156     mOut << "[[id(" << annotationIndices.attribute << ")]]";
1157 
1158     const TType &type   = *field.type();
1159     const int arraySize = type.isArray() ? type.getArraySizeProduct() : 1;
1160 
1161     TranslatorMetalReflection *reflection =
1162         ((sh::TranslatorMetalDirect *)&mCompiler)->getTranslatorMetalReflection();
1163     ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1164     const TStructure *structure    = type.getStruct();
1165     const std::string originalName = reflection->getOriginalName(structure->uniqueId().get());
1166     reflection->addUniformBufferBinding(
1167         originalName,
1168         {.bindIndex = annotationIndices.attribute, .arraySize = static_cast<size_t>(arraySize)});
1169 
1170     annotationIndices.attribute += arraySize;
1171 }
1172 
emitStructDeclaration(const TType & type)1173 void GenMetalTraverser::emitStructDeclaration(const TType &type)
1174 {
1175     ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1176     ASSERT(type.isStructSpecifier());
1177 
1178     mOut << "struct ";
1179     emitBareTypeName(type, {});
1180 
1181     mOut << "\n";
1182     emitOpenBrace();
1183 
1184     const TStructure &structure = *type.getStruct();
1185     std::map<Name, size_t> fieldToAttributeIndex;
1186     const bool hasAttributeIndices      = mPipelineStructs.vertexIn.external == &structure;
1187     const bool hasUniformBufferIndicies = mPipelineStructs.uniformBuffers.external == &structure;
1188     const bool reclaimUnusedAttributeIndices = mCompiler.getShaderVersion() < 300;
1189 
1190     if (hasAttributeIndices)
1191     {
1192         fieldToAttributeIndex =
1193             BuildExternalAttributeIndexMap(mCompiler, mPipelineStructs.vertexIn);
1194     }
1195 
1196     FieldAnnotationIndices annotationIndices;
1197 
1198     for (const TField *field : structure.fields())
1199     {
1200         emitIndentation();
1201         if (hasAttributeIndices)
1202         {
1203             const auto it = fieldToAttributeIndex.find(Name(*field));
1204             if (it == fieldToAttributeIndex.end())
1205             {
1206                 ASSERT(field->symbolType() == SymbolType::AngleInternal);
1207                 ASSERT(field->name().beginsWith("_"));
1208                 ASSERT(angle::EndsWith(field->name().data(), "_pad"));
1209                 emitFieldDeclaration(*field, structure, annotationIndices);
1210             }
1211             else
1212             {
1213                 ASSERT(field->symbolType() != SymbolType::AngleInternal ||
1214                        !field->name().beginsWith("_") ||
1215                        !angle::EndsWith(field->name().data(), "_pad"));
1216                 if (!reclaimUnusedAttributeIndices)
1217                 {
1218                     annotationIndices.attribute = it->second;
1219                 }
1220                 emitAttributeDeclaration(*field, annotationIndices);
1221             }
1222         }
1223         else if (hasUniformBufferIndicies)
1224         {
1225             emitUniformBufferDeclaration(*field, annotationIndices);
1226         }
1227         else
1228         {
1229             emitFieldDeclaration(*field, structure, annotationIndices);
1230         }
1231         mOut << ";\n";
1232     }
1233 
1234     if (!mPipelineStructs.matches(structure, true, true))
1235     {
1236         MetalLayoutOfConfig layoutConfig;
1237         layoutConfig.treatSamplersAsTextureEnv = true;
1238         Layout layout                          = MetalLayoutOf(type, layoutConfig);
1239         size_t pad = (kDefaultStructAlignmentSize - layout.sizeOf) % kDefaultStructAlignmentSize;
1240         if (pad != 0)
1241         {
1242             emitIndentation();
1243             mOut << "char ";
1244             EmitName(mOut, mIdGen.createNewName("pad"));
1245             mOut << "[" << pad << "];\n";
1246         }
1247     }
1248 
1249     emitCloseBrace();
1250 }
1251 
emitOrdinaryVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1252 void GenMetalTraverser::emitOrdinaryVariableDeclaration(
1253     const VarDecl &decl,
1254     const EmitVariableDeclarationConfig &evdConfig)
1255 {
1256     EmitTypeConfig etConfig;
1257     etConfig.evdConfig = &evdConfig;
1258 
1259     const TType &type = decl.type();
1260     emitType(type, etConfig);
1261     if (decl.symbolType() != SymbolType::Empty)
1262     {
1263         mOut << " ";
1264         emitNameOf(decl);
1265     }
1266 }
1267 
emitVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1268 void GenMetalTraverser::emitVariableDeclaration(const VarDecl &decl,
1269                                                 const EmitVariableDeclarationConfig &evdConfig)
1270 {
1271     const SymbolType symbolType = decl.symbolType();
1272     const TType &type           = decl.type();
1273     const TBasicType basicType  = type.getBasicType();
1274 
1275     switch (basicType)
1276     {
1277         case TBasicType::EbtStruct:
1278         {
1279             if (type.isStructSpecifier() && !evdConfig.disableStructSpecifier)
1280             {
1281                 ASSERT(!evdConfig.isParameter);
1282                 emitStructDeclaration(type);
1283                 if (symbolType != SymbolType::Empty)
1284                 {
1285                     mOut << " ";
1286                     emitNameOf(decl);
1287                 }
1288             }
1289             else
1290             {
1291                 emitOrdinaryVariableDeclaration(decl, evdConfig);
1292             }
1293         }
1294         break;
1295 
1296         default:
1297         {
1298             ASSERT(symbolType != SymbolType::Empty || evdConfig.isParameter);
1299             emitOrdinaryVariableDeclaration(decl, evdConfig);
1300         }
1301     }
1302 
1303     if (evdConfig.emitPostQualifier)
1304     {
1305         emitPostQualifier(evdConfig, decl, type.getQualifier());
1306     }
1307 }
1308 
visitSymbol(TIntermSymbol * symbolNode)1309 void GenMetalTraverser::visitSymbol(TIntermSymbol *symbolNode)
1310 {
1311     const TVariable &var = symbolNode->variable();
1312     const TType &type    = var.getType();
1313     ASSERT(var.symbolType() != SymbolType::Empty);
1314 
1315     if (type.getBasicType() == TBasicType::EbtVoid)
1316     {
1317         mOut << "/*";
1318         emitNameOf(var);
1319         mOut << "*/";
1320     }
1321     else
1322     {
1323         emitNameOf(var);
1324     }
1325 }
1326 
emitSingleConstant(const TConstantUnion * const constUnion)1327 void GenMetalTraverser::emitSingleConstant(const TConstantUnion *const constUnion)
1328 {
1329     switch (constUnion->getType())
1330     {
1331         case TBasicType::EbtBool:
1332         {
1333             mOut << (constUnion->getBConst() ? "true" : "false");
1334         }
1335         break;
1336 
1337         case TBasicType::EbtFloat:
1338         {
1339             mOut << constUnion->getFConst() << "f";
1340         }
1341         break;
1342 
1343         case TBasicType::EbtInt:
1344         {
1345             mOut << constUnion->getIConst();
1346         }
1347         break;
1348 
1349         case TBasicType::EbtUInt:
1350         {
1351             mOut << constUnion->getUConst() << "u";
1352         }
1353         break;
1354 
1355         default:
1356         {
1357             UNIMPLEMENTED();
1358         }
1359     }
1360 }
1361 
// Emits |size| consecutive scalar constants starting at |constUnion|, separated
// by ", ". Returns a pointer to the element just past the last one emitted.
const TConstantUnion *GenMetalTraverser::emitConstantUnionArray(
    const TConstantUnion *const constUnion,
    const size_t size)
{
    const TConstantUnion *cursor = constUnion;
    for (size_t i = 0; i < size; ++i)
    {
        if (i != 0)
        {
            mOut << ", ";
        }
        emitSingleConstant(cursor);
        ++cursor;
    }
    return cursor;
}
1378 
// Emits a constant of |type| whose values start at |constUnionBegin|, and returns
// a pointer just past the values consumed.
//  - Struct types are written as "<Type>{field0, field1, ...}", recursing per field.
//  - Other types with more than one component are written as "<Type>(v0, v1, ...)".
//  - Single-component values are written bare.
const TConstantUnion *GenMetalTraverser::emitConstantUnion(const TType &type,
                                                           const TConstantUnion *constUnionBegin)
{
    const TConstantUnion *cursor = constUnionBegin;
    const TStructure *structure  = type.getStruct();

    if (structure == nullptr)
    {
        const size_t componentCount = type.getObjectSize();
        const bool needsConstructor = componentCount > 1;
        if (needsConstructor)
        {
            emitType(type, EmitTypeConfig{nullptr});
            mOut << "(";
        }
        cursor = emitConstantUnionArray(cursor, componentCount);
        if (needsConstructor)
        {
            mOut << ")";
        }
        return cursor;
    }

    emitType(type, EmitTypeConfig{nullptr});
    mOut << "{";
    const TFieldList &fields = structure->fields();
    for (size_t i = 0; i < fields.size(); ++i)
    {
        if (i != 0)
        {
            mOut << ", ";
        }
        cursor = emitConstantUnion(*fields[i]->type(), cursor);
    }
    mOut << "}";
    return cursor;
}
1419 
visitConstantUnion(TIntermConstantUnion * constValueNode)1420 void GenMetalTraverser::visitConstantUnion(TIntermConstantUnion *constValueNode)
1421 {
1422     emitConstantUnion(constValueNode->getType(), constValueNode->getConstantValue());
1423 }
1424 
visitSwizzle(Visit,TIntermSwizzle * swizzleNode)1425 bool GenMetalTraverser::visitSwizzle(Visit, TIntermSwizzle *swizzleNode)
1426 {
1427     groupedTraverse(*swizzleNode->getOperand());
1428     mOut << ".";
1429 
1430     {
1431 #if defined(ANGLE_ENABLE_ASSERTS)
1432         DebugSink::EscapedSink escapedOut(mOut.escape());
1433         TInfoSinkBase &out = escapedOut.get();
1434 #else
1435         TInfoSinkBase &out        = mOut;
1436 #endif
1437         swizzleNode->writeOffsetsAsXYZW(&out);
1438     }
1439 
1440     return false;
1441 }
1442 
getDirectField(const TFieldListCollection & fieldListCollection,const TConstantUnion & index)1443 const TField &GenMetalTraverser::getDirectField(const TFieldListCollection &fieldListCollection,
1444                                                 const TConstantUnion &index)
1445 {
1446     ASSERT(index.getType() == TBasicType::EbtInt);
1447 
1448     const TFieldList &fieldList = fieldListCollection.fields();
1449     const int indexVal          = index.getIConst();
1450     const TField &field         = *fieldList[indexVal];
1451 
1452     return field;
1453 }
1454 
getDirectField(const TIntermTyped & fieldsNode,TIntermTyped & indexNode)1455 const TField &GenMetalTraverser::getDirectField(const TIntermTyped &fieldsNode,
1456                                                 TIntermTyped &indexNode)
1457 {
1458     const TType &fieldsType = fieldsNode.getType();
1459 
1460     const TFieldListCollection *fieldListCollection = fieldsType.getStruct();
1461     if (fieldListCollection == nullptr)
1462     {
1463         fieldListCollection = fieldsType.getInterfaceBlock();
1464     }
1465     ASSERT(fieldListCollection);
1466 
1467     const TIntermConstantUnion *indexNode_ = indexNode.getAsConstantUnion();
1468     ASSERT(indexNode_);
1469     const TConstantUnion &index = *indexNode_->getConstantValue();
1470 
1471     return getDirectField(*fieldListCollection, index);
1472 }
1473 
visitBinary(Visit,TIntermBinary * binaryNode)1474 bool GenMetalTraverser::visitBinary(Visit, TIntermBinary *binaryNode)
1475 {
1476     const TOperator op      = binaryNode->getOp();
1477     TIntermTyped &leftNode  = *binaryNode->getLeft();
1478     TIntermTyped &rightNode = *binaryNode->getRight();
1479 
1480     switch (op)
1481     {
1482         case TOperator::EOpIndexDirectStruct:
1483         case TOperator::EOpIndexDirectInterfaceBlock:
1484         {
1485             const TField &field = getDirectField(leftNode, rightNode);
1486             if (mSymbolEnv.isPointer(field) && mSymbolEnv.isUBO(field))
1487             {
1488                 emitOpeningPointerParen();
1489             }
1490             groupedTraverse(leftNode);
1491             if (!mSymbolEnv.isPointer(field))
1492             {
1493                 emitClosingPointerParen();
1494             }
1495             mOut << ".";
1496             emitNameOf(field);
1497         }
1498         break;
1499 
1500         case TOperator::EOpIndexDirect:
1501         case TOperator::EOpIndexIndirect:
1502         {
1503             TType leftType = leftNode.getType();
1504             groupedTraverse(leftNode);
1505             mOut << "[";
1506             {
1507                 mOut << "ANGLE_int_clamp(";
1508                 groupedTraverse(rightNode);
1509                 mOut << ", 0, ";
1510                 if (leftType.isUnsizedArray())
1511                 {
1512                     groupedTraverse(leftNode);
1513                     mOut << ".size()";
1514                 }
1515                 else
1516                 {
1517                     int maxSize;
1518                     if (leftType.isArray())
1519                     {
1520                         maxSize = static_cast<int>(leftType.getOutermostArraySize()) - 1;
1521                     }
1522                     else
1523                     {
1524                         maxSize = leftType.getNominalSize() - 1;
1525                     }
1526                     mOut << maxSize;
1527                 }
1528                 mOut << ")";
1529             }
1530             mOut << "]";
1531         }
1532         break;
1533 
1534         default:
1535         {
1536             const TType &resultType = binaryNode->getType();
1537             const TType &leftType   = leftNode.getType();
1538             const TType &rightType  = rightNode.getType();
1539 
1540             if (IsSymbolicOperator(op, resultType, &leftType, &rightType))
1541             {
1542                 groupedTraverse(leftNode);
1543                 if (op != TOperator::EOpComma)
1544                 {
1545                     mOut << " ";
1546                 }
1547                 else
1548                 {
1549                     emitClosingPointerParen();
1550                 }
1551                 mOut << GetOperatorString(op, resultType, &leftType, &rightType) << " ";
1552                 groupedTraverse(rightNode);
1553             }
1554             else
1555             {
1556                 emitClosingPointerParen();
1557                 mOut << GetOperatorString(op, resultType, &leftType, &rightType) << "(";
1558                 leftNode.traverse(this);
1559                 mOut << ", ";
1560                 rightNode.traverse(this);
1561                 mOut << ")";
1562             }
1563         }
1564     }
1565 
1566     return false;
1567 }
1568 
IsPostfix(TOperator op)1569 static bool IsPostfix(TOperator op)
1570 {
1571     switch (op)
1572     {
1573         case TOperator::EOpPostIncrement:
1574         case TOperator::EOpPostDecrement:
1575             return true;
1576 
1577         default:
1578             return false;
1579     }
1580 }
1581 
visitUnary(Visit,TIntermUnary * unaryNode)1582 bool GenMetalTraverser::visitUnary(Visit, TIntermUnary *unaryNode)
1583 {
1584     const TOperator op      = unaryNode->getOp();
1585     const TType &resultType = unaryNode->getType();
1586 
1587     TIntermTyped &arg    = *unaryNode->getOperand();
1588     const TType &argType = arg.getType();
1589 
1590     const char *name = GetOperatorString(op, resultType, &argType);
1591 
1592     if (IsSymbolicOperator(op, resultType, &argType))
1593     {
1594         const bool postfix = IsPostfix(op);
1595         if (!postfix)
1596         {
1597             mOut << name;
1598         }
1599         groupedTraverse(arg);
1600         if (postfix)
1601         {
1602             mOut << name;
1603         }
1604     }
1605     else
1606     {
1607         mOut << name << "(";
1608         arg.traverse(this);
1609         mOut << ")";
1610     }
1611 
1612     return false;
1613 }
1614 
visitTernary(Visit,TIntermTernary * conditionalNode)1615 bool GenMetalTraverser::visitTernary(Visit, TIntermTernary *conditionalNode)
1616 {
1617     groupedTraverse(*conditionalNode->getCondition());
1618     mOut << " ? ";
1619     groupedTraverse(*conditionalNode->getTrueExpression());
1620     mOut << " : ";
1621     groupedTraverse(*conditionalNode->getFalseExpression());
1622 
1623     return false;
1624 }
1625 
visitIfElse(Visit,TIntermIfElse * ifThenElseNode)1626 bool GenMetalTraverser::visitIfElse(Visit, TIntermIfElse *ifThenElseNode)
1627 {
1628     TIntermTyped &condNode = *ifThenElseNode->getCondition();
1629     TIntermBlock *thenNode = ifThenElseNode->getTrueBlock();
1630     TIntermBlock *elseNode = ifThenElseNode->getFalseBlock();
1631 
1632     emitIndentation();
1633     mOut << "if (";
1634     condNode.traverse(this);
1635     mOut << ")";
1636 
1637     if (thenNode)
1638     {
1639         mOut << "\n";
1640         thenNode->traverse(this);
1641     }
1642     else
1643     {
1644         mOut << " {}";
1645     }
1646 
1647     if (elseNode)
1648     {
1649         mOut << "\n";
1650         emitIndentation();
1651         mOut << "else\n";
1652         elseNode->traverse(this);
1653     }
1654     else
1655     {
1656         // Always emit "else" even when empty block to avoid nested if-stmt issues.
1657         mOut << " else {}";
1658     }
1659 
1660     return false;
1661 }
1662 
visitSwitch(Visit,TIntermSwitch * switchNode)1663 bool GenMetalTraverser::visitSwitch(Visit, TIntermSwitch *switchNode)
1664 {
1665     emitIndentation();
1666     mOut << "switch (";
1667     switchNode->getInit()->traverse(this);
1668     mOut << ")\n";
1669 
1670     ASSERT(!mParentIsSwitch);
1671     mParentIsSwitch = true;
1672     switchNode->getStatementList()->traverse(this);
1673     mParentIsSwitch = false;
1674 
1675     return false;
1676 }
1677 
visitCase(Visit,TIntermCase * caseNode)1678 bool GenMetalTraverser::visitCase(Visit, TIntermCase *caseNode)
1679 {
1680     emitIndentation();
1681 
1682     if (caseNode->hasCondition())
1683     {
1684         TIntermTyped *condExpr = caseNode->getCondition();
1685         mOut << "case ";
1686         condExpr->traverse(this);
1687         mOut << ":";
1688     }
1689     else
1690     {
1691         mOut << "default:\n";
1692     }
1693 
1694     return false;
1695 }
1696 
emitFunctionSignature(const TFunction & func)1697 void GenMetalTraverser::emitFunctionSignature(const TFunction &func)
1698 {
1699     const bool isMain = func.isMain();
1700 
1701     emitFunctionReturn(func);
1702 
1703     mOut << " ";
1704     emitNameOf(func);
1705     if (isMain)
1706     {
1707         mOut << "0";
1708     }
1709     mOut << "(";
1710 
1711     bool emitComma          = false;
1712     const size_t paramCount = func.getParamCount();
1713     for (size_t i = 0; i < paramCount; ++i)
1714     {
1715         if (emitComma)
1716         {
1717             mOut << ", ";
1718         }
1719         emitComma = true;
1720 
1721         const TVariable &param = *func.getParam(i);
1722         emitFunctionParameter(func, param);
1723     }
1724     // TODO(anglebug.com/5505): reimplement transform feedback support.
1725     //    if (isTraversingVertexMain)
1726     //    {
1727     //        mOut << " @@XFB-Bindings@@ ";
1728     //    }
1729 
1730     mOut << ")";
1731 }
1732 
emitFunctionReturn(const TFunction & func)1733 void GenMetalTraverser::emitFunctionReturn(const TFunction &func)
1734 {
1735     const bool isMain       = func.isMain();
1736     bool isVertexMain       = false;
1737     const TType &returnType = func.getReturnType();
1738     if (isMain)
1739     {
1740         const TStructure *structure = returnType.getStruct();
1741         ASSERT(structure != nullptr);
1742         if (mPipelineStructs.fragmentOut.matches(*structure))
1743         {
1744             mOut << "fragment ";
1745         }
1746         else if (mPipelineStructs.vertexOut.matches(*structure))
1747         {
1748             mOut << "vertex __VERTEX_OUT(";
1749             isVertexMain = true;
1750         }
1751         else
1752         {
1753             UNIMPLEMENTED();
1754         }
1755     }
1756     emitType(returnType, EmitTypeConfig());
1757     if (isVertexMain)
1758         mOut << ") ";
1759 }
1760 
emitFunctionParameter(const TFunction & func,const TVariable & param)1761 void GenMetalTraverser::emitFunctionParameter(const TFunction &func, const TVariable &param)
1762 {
1763     const bool isMain = func.isMain();
1764 
1765     const TType &type           = param.getType();
1766     const TStructure *structure = type.getStruct();
1767 
1768     EmitVariableDeclarationConfig evdConfig;
1769     evdConfig.isParameter       = true;
1770     evdConfig.isMainParameter   = isMain;
1771     evdConfig.emitPostQualifier = isMain;
1772     evdConfig.isUBO             = mSymbolEnv.isUBO(param);
1773     evdConfig.isPointer         = mSymbolEnv.isPointer(param);
1774     evdConfig.isReference       = mSymbolEnv.isReference(param);
1775     emitVariableDeclaration(VarDecl(param), evdConfig);
1776 
1777     if (isMain)
1778     {
1779         TranslatorMetalReflection *reflection =
1780             ((sh::TranslatorMetalDirect *)&mCompiler)->getTranslatorMetalReflection();
1781         if (structure)
1782         {
1783             if (mPipelineStructs.fragmentIn.matches(*structure) ||
1784                 mPipelineStructs.vertexIn.matches(*structure))
1785             {
1786                 mOut << " [[stage_in]]";
1787             }
1788             else if (mPipelineStructs.angleUniforms.matches(*structure))
1789             {
1790                 mOut << " [[buffer(" << mDriverUniformsBindingIndex << ")]]";
1791             }
1792             else if (mPipelineStructs.uniformBuffers.matches(*structure))
1793             {
1794                 mOut << " [[buffer(" << mUBOArgumentBufferBindingIndex << ")]]";
1795                 reflection->hasUBOs = true;
1796             }
1797             else if (mPipelineStructs.userUniforms.matches(*structure))
1798             {
1799                 mOut << " [[buffer(" << mMainUniformBufferIndex << ")]]";
1800                 reflection->addUserUniformBufferBinding(param.name().data(),
1801                                                         mMainUniformBufferIndex);
1802                 mMainUniformBufferIndex += type.getArraySizeProduct();
1803             }
1804             else if (structure->name() == "metal::sampler")
1805             {
1806                 mOut << " [[sampler(" << (mMainSamplerIndex) << ")]]";
1807                 const std::string originalName =
1808                     reflection->getOriginalName(param.uniqueId().get());
1809                 reflection->addSamplerBinding(originalName, mMainSamplerIndex);
1810                 mMainSamplerIndex += type.getArraySizeProduct();
1811             }
1812         }
1813         else if (IsSampler(type.getBasicType()))
1814         {
1815             mOut << " [[texture(" << (mMainTextureIndex) << ")]]";
1816             const std::string originalName = reflection->getOriginalName(param.uniqueId().get());
1817             reflection->addTextureBinding(originalName, mMainSamplerIndex);
1818             mMainTextureIndex += type.getArraySizeProduct();
1819         }
1820         else if (Name(param) == Pipeline{Pipeline::Type::InstanceId, nullptr}.getStructInstanceName(
1821                                     Pipeline::Variant::Modified))
1822         {
1823             mOut << " [[instance_id]]";
1824         }
1825     }
1826 }
1827 
visitFunctionPrototype(TIntermFunctionPrototype * funcProtoNode)1828 void GenMetalTraverser::visitFunctionPrototype(TIntermFunctionPrototype *funcProtoNode)
1829 {
1830     const TFunction &func = *funcProtoNode->getFunction();
1831 
1832     emitIndentation();
1833     emitFunctionSignature(func);
1834 }
1835 
visitFunctionDefinition(Visit,TIntermFunctionDefinition * funcDefNode)1836 bool GenMetalTraverser::visitFunctionDefinition(Visit, TIntermFunctionDefinition *funcDefNode)
1837 {
1838     const TFunction &func = *funcDefNode->getFunction();
1839     TIntermBlock &body    = *funcDefNode->getBody();
1840     if (func.isMain())
1841     {
1842         const TType &returnType     = func.getReturnType();
1843         const TStructure *structure = returnType.getStruct();
1844         isTraversingVertexMain      = (mPipelineStructs.vertexOut.matches(*structure));
1845     }
1846     emitIndentation();
1847     emitFunctionSignature(func);
1848     mOut << "\n";
1849     body.traverse(this);
1850     if (isTraversingVertexMain)
1851     {
1852         isTraversingVertexMain = false;
1853     }
1854     return false;
1855 }
1856 
BuildFuncToName()1857 GenMetalTraverser::FuncToName GenMetalTraverser::BuildFuncToName()
1858 {
1859     FuncToName map;
1860 
1861     auto putAngle = [&](const char *nameStr) {
1862         const ImmutableString name(nameStr);
1863         ASSERT(map.find(name) == map.end());
1864         map[name] = Name(nameStr, SymbolType::AngleInternal);
1865     };
1866 
1867     putAngle("texelFetch");
1868     putAngle("texelFetchOffset");
1869     putAngle("texture");
1870     putAngle("texture1D");
1871     putAngle("texture1DLod");
1872     putAngle("texture1DProjLod");
1873     putAngle("texture2D");
1874     putAngle("texture2DLod");
1875     putAngle("texture2DProj");
1876     putAngle("texture2DRect");
1877     putAngle("texture2DProjLod");
1878     putAngle("texture2DRectProj");
1879     putAngle("texture3D");
1880     putAngle("texture3DLod");
1881     putAngle("texture3DProjLod");
1882     putAngle("textureCube");
1883     putAngle("textureCubeLod");
1884     putAngle("textureCubeProjLod");
1885     putAngle("textureGrad");
1886     putAngle("textureGradOffset");
1887     putAngle("textureLod");
1888     putAngle("textureLodOffset");
1889     putAngle("textureOffset");
1890     putAngle("textureProj");
1891     putAngle("textureProjGrad");
1892     putAngle("textureProjGradOffset");
1893     putAngle("textureProjLod");
1894     putAngle("textureProjLodOffset");
1895     putAngle("textureProjOffset");
1896     putAngle("textureSize");
1897 
1898     return map;
1899 }
1900 
visitAggregate(Visit,TIntermAggregate * aggregateNode)1901 bool GenMetalTraverser::visitAggregate(Visit, TIntermAggregate *aggregateNode)
1902 {
1903     const TIntermSequence &args = *aggregateNode->getSequence();
1904 
1905     auto emitArgList = [&](const char *open, const char *close) {
1906         mOut << open;
1907 
1908         bool emitComma = false;
1909         for (TIntermNode *arg : args)
1910         {
1911             if (emitComma)
1912             {
1913                 emitClosingPointerParen();
1914                 mOut << ", ";
1915             }
1916             emitComma = true;
1917             arg->traverse(this);
1918         }
1919 
1920         mOut << close;
1921     };
1922 
1923     const TType &retType = aggregateNode->getType();
1924 
1925     if (aggregateNode->isConstructor())
1926     {
1927         const bool isStandalone = getParentNode()->getAsBlock();
1928         if (isStandalone)
1929         {
1930             // Prevent constructor from being interpreted as a declaration by wrapping in parens.
1931             // This can happen if given something like:
1932             //      int(symbol); // <- This will be treated like `int symbol;`... don't want that.
1933             // So instead emit:
1934             //      (int(symbol));
1935             mOut << "(";
1936         }
1937 
1938         const EmitTypeConfig etConfig;
1939 
1940         if (retType.isArray())
1941         {
1942             emitType(retType, etConfig);
1943             emitArgList("{", "}");
1944         }
1945         else if (retType.getStruct())
1946         {
1947             emitType(retType, etConfig);
1948             emitArgList("{", "}");
1949         }
1950         else
1951         {
1952             emitType(retType, etConfig);
1953             emitArgList("(", ")");
1954         }
1955 
1956         if (isStandalone)
1957         {
1958             mOut << ")";
1959         }
1960 
1961         return false;
1962     }
1963     else
1964     {
1965         const TOperator op = aggregateNode->getOp();
1966         switch (op)
1967         {
1968             case TOperator::EOpCallFunctionInAST:
1969             case TOperator::EOpCallInternalRawFunction:
1970             {
1971                 const TFunction &func = *aggregateNode->getFunction();
1972                 emitNameOf(func);
1973                 //'@' symbol in name specifices a macro substitution marker.
1974                 if (!func.name().contains("@"))
1975                 {
1976                     emitArgList("(", ")");
1977                 }
1978                 else
1979                 {
1980                     mTemporarilyDisableSemicolon =
1981                         true;  // Disable semicolon for macro substitution.
1982                 }
1983                 return false;
1984             }
1985 
1986             default:
1987             {
1988                 auto getArgType = [&](size_t index) -> const TType * {
1989                     if (index < args.size())
1990                     {
1991                         TIntermTyped *arg = args[index]->getAsTyped();
1992                         ASSERT(arg);
1993                         return &arg->getType();
1994                     }
1995                     return nullptr;
1996                 };
1997 
1998                 ASSERT(!args.empty());
1999                 const TType *argType0 = getArgType(0);
2000                 const TType *argType1 = getArgType(1);
2001                 ASSERT(argType0);
2002 
2003                 const char *opName = GetOperatorString(op, retType, argType0, argType1);
2004 
2005                 if (IsSymbolicOperator(op, retType, argType0, argType1))
2006                 {
2007                     switch (args.size())
2008                     {
2009                         case 1:
2010                         {
2011                             TIntermNode &operandNode = *aggregateNode->getChildNode(0);
2012                             if (IsPostfix(op))
2013                             {
2014                                 mOut << opName;
2015                                 groupedTraverse(operandNode);
2016                                 return false;
2017                             }
2018                             else
2019                             {
2020                                 groupedTraverse(operandNode);
2021                                 mOut << opName;
2022                                 return false;
2023                             }
2024                         }
2025                         break;
2026 
2027                         case 2:
2028                         {
2029                             TIntermNode &leftNode  = *aggregateNode->getChildNode(0);
2030                             TIntermNode &rightNode = *aggregateNode->getChildNode(1);
2031                             groupedTraverse(leftNode);
2032                             mOut << " " << opName << " ";
2033                             groupedTraverse(rightNode);
2034                             return false;
2035                         }
2036                         break;
2037 
2038                         default:
2039                             UNREACHABLE();
2040                             return false;
2041                     }
2042                 }
2043                 else if (opName == nullptr)
2044                 {
2045                     const TFunction &func = *aggregateNode->getFunction();
2046                     auto it               = mFuncToName.find(func.name());
2047                     ASSERT(it != mFuncToName.end());
2048                     EmitName(mOut, it->second);
2049                     emitArgList("(", ")");
2050                     return false;
2051                 }
2052                 else
2053                 {
2054                     mOut << opName;
2055                     emitArgList("(", ")");
2056                     return false;
2057                 }
2058             }
2059         }
2060     }
2061 }
2062 
emitOpenBrace()2063 void GenMetalTraverser::emitOpenBrace()
2064 {
2065     ASSERT(mIndentLevel >= 0);
2066 
2067     emitIndentation();
2068     mOut << "{\n";
2069     ++mIndentLevel;
2070 }
2071 
emitCloseBrace()2072 void GenMetalTraverser::emitCloseBrace()
2073 {
2074     ASSERT(mIndentLevel >= 1);
2075 
2076     --mIndentLevel;
2077     emitIndentation();
2078     mOut << "}";
2079 }
2080 
RequiresSemicolonTerminator(TIntermNode & node)2081 static bool RequiresSemicolonTerminator(TIntermNode &node)
2082 {
2083     if (node.getAsBlock())
2084     {
2085         return false;
2086     }
2087     if (node.getAsLoopNode())
2088     {
2089         return false;
2090     }
2091     if (node.getAsSwitchNode())
2092     {
2093         return false;
2094     }
2095     if (node.getAsIfElseNode())
2096     {
2097         return false;
2098     }
2099     if (node.getAsFunctionDefinition())
2100     {
2101         return false;
2102     }
2103     if (node.getAsCaseNode())
2104     {
2105         return false;
2106     }
2107 
2108     return true;
2109 }
2110 
NewlinePad(TIntermNode & node)2111 static bool NewlinePad(TIntermNode &node)
2112 {
2113     if (node.getAsFunctionDefinition())
2114     {
2115         return true;
2116     }
2117     if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
2118     {
2119         ASSERT(declNode->getChildCount() == 1);
2120         TIntermNode &childNode = *declNode->getChildNode(0);
2121         if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
2122         {
2123             const TVariable &var = symbolNode->variable();
2124             return var.getType().isStructSpecifier();
2125         }
2126         return false;
2127     }
2128     return false;
2129 }
2130 
visitBlock(Visit,TIntermBlock * blockNode)2131 bool GenMetalTraverser::visitBlock(Visit, TIntermBlock *blockNode)
2132 {
2133     ASSERT(mIndentLevel >= -1);
2134     const bool isGlobalScope  = mIndentLevel == -1;
2135     const bool parentIsSwitch = mParentIsSwitch;
2136     mParentIsSwitch           = false;
2137 
2138     if (isGlobalScope)
2139     {
2140         ++mIndentLevel;
2141     }
2142     else
2143     {
2144         emitOpenBrace();
2145         if (parentIsSwitch)
2146         {
2147             ++mIndentLevel;
2148         }
2149     }
2150 
2151     TIntermNode *prevStmtNode = nullptr;
2152 
2153     const size_t stmtCount = blockNode->getChildCount();
2154     for (size_t i = 0; i < stmtCount; ++i)
2155     {
2156         TIntermNode &stmtNode = *blockNode->getChildNode(i);
2157 
2158         if (isGlobalScope && prevStmtNode && (NewlinePad(*prevStmtNode) || NewlinePad(stmtNode)))
2159         {
2160             mOut << "\n";
2161         }
2162         const bool isCase = stmtNode.getAsCaseNode();
2163         mIndentLevel -= isCase;
2164         emitIndentation();
2165         mIndentLevel += isCase;
2166         stmtNode.traverse(this);
2167         if (RequiresSemicolonTerminator(stmtNode) && !mTemporarilyDisableSemicolon)
2168         {
2169             mOut << ";";
2170         }
2171         mTemporarilyDisableSemicolon = false;
2172         mOut << "\n";
2173 
2174         prevStmtNode = &stmtNode;
2175     }
2176 
2177     if (isGlobalScope)
2178     {
2179         ASSERT(mIndentLevel == 0);
2180         --mIndentLevel;
2181     }
2182     else
2183     {
2184         if (parentIsSwitch)
2185         {
2186             ASSERT(mIndentLevel >= 1);
2187             --mIndentLevel;
2188         }
2189         emitCloseBrace();
2190         mParentIsSwitch = parentIsSwitch;
2191     }
2192 
2193     return false;
2194 }
2195 
visitGlobalQualifierDeclaration(Visit,TIntermGlobalQualifierDeclaration *)2196 bool GenMetalTraverser::visitGlobalQualifierDeclaration(Visit, TIntermGlobalQualifierDeclaration *)
2197 {
2198     UNREACHABLE();  // RewriteGlobalQualifierDecls should have been called before this.
2199     return false;
2200 }
2201 
visitDeclaration(Visit,TIntermDeclaration * declNode)2202 bool GenMetalTraverser::visitDeclaration(Visit, TIntermDeclaration *declNode)
2203 {
2204     ASSERT(declNode->getChildCount() == 1);
2205     TIntermNode &node = *declNode->getChildNode(0);
2206 
2207     EmitVariableDeclarationConfig evdConfig;
2208 
2209     if (TIntermSymbol *symbolNode = node.getAsSymbolNode())
2210     {
2211         const TVariable &var = symbolNode->variable();
2212         emitVariableDeclaration(VarDecl(var), evdConfig);
2213     }
2214     else if (TIntermBinary *initNode = node.getAsBinaryNode())
2215     {
2216         ASSERT(initNode->getOp() == TOperator::EOpInitialize);
2217         TIntermSymbol *leftSymbolNode = initNode->getLeft()->getAsSymbolNode();
2218         TIntermTyped *valueNode       = initNode->getRight()->getAsTyped();
2219         ASSERT(leftSymbolNode && valueNode);
2220 
2221         if (getRootNode() == getParentBlock())
2222         {
2223             // DeferGlobalInitializers should have turned non-const global initializers into
2224             // deferred initializers. Note that variables marked as EvqGlobal can be treated as
2225             // EvqConst in some ANGLE code but not actually have their qualifier actually changed to
2226             // EvqConst. Thus just assume all EvqGlobal are actually EvqConst for all code run after
2227             // DeferGlobalInitializers.
2228             mOut << "constant ";
2229         }
2230 
2231         const TVariable &var = leftSymbolNode->variable();
2232         const Name varName(var);
2233 
2234         if (ExpressionContainsName(varName, *valueNode))
2235         {
2236             mRenamedSymbols[&var] = mIdGen.createNewName(varName);
2237         }
2238 
2239         emitVariableDeclaration(VarDecl(var), evdConfig);
2240         mOut << " = ";
2241         groupedTraverse(*valueNode);
2242     }
2243     else
2244     {
2245         UNREACHABLE();
2246     }
2247 
2248     return false;
2249 }
2250 
visitLoop(Visit,TIntermLoop * loopNode)2251 bool GenMetalTraverser::visitLoop(Visit, TIntermLoop *loopNode)
2252 {
2253     const TLoopType loopType = loopNode->getType();
2254 
2255     switch (loopType)
2256     {
2257         case TLoopType::ELoopFor:
2258             return visitForLoop(loopNode);
2259         case TLoopType::ELoopWhile:
2260             return visitWhileLoop(loopNode);
2261         case TLoopType::ELoopDoWhile:
2262             return visitDoWhileLoop(loopNode);
2263     }
2264 }
2265 
visitForLoop(TIntermLoop * loopNode)2266 bool GenMetalTraverser::visitForLoop(TIntermLoop *loopNode)
2267 {
2268     ASSERT(loopNode->getType() == TLoopType::ELoopFor);
2269 
2270     TIntermNode *initNode  = loopNode->getInit();
2271     TIntermTyped *condNode = loopNode->getCondition();
2272     TIntermTyped *exprNode = loopNode->getExpression();
2273     TIntermBlock *bodyNode = loopNode->getBody();
2274     ASSERT(bodyNode);
2275 
2276     mOut << "for (";
2277 
2278     if (initNode)
2279     {
2280         initNode->traverse(this);
2281     }
2282     else
2283     {
2284         mOut << " ";
2285     }
2286 
2287     mOut << "; ";
2288 
2289     if (condNode)
2290     {
2291         condNode->traverse(this);
2292     }
2293 
2294     mOut << "; ";
2295 
2296     if (exprNode)
2297     {
2298         exprNode->traverse(this);
2299     }
2300 
2301     mOut << ")\n";
2302 
2303     bodyNode->traverse(this);
2304 
2305     return false;
2306 }
2307 
visitWhileLoop(TIntermLoop * loopNode)2308 bool GenMetalTraverser::visitWhileLoop(TIntermLoop *loopNode)
2309 {
2310     ASSERT(loopNode->getType() == TLoopType::ELoopWhile);
2311 
2312     TIntermNode *initNode  = loopNode->getInit();
2313     TIntermTyped *condNode = loopNode->getCondition();
2314     TIntermTyped *exprNode = loopNode->getExpression();
2315     TIntermBlock *bodyNode = loopNode->getBody();
2316     ASSERT(condNode && bodyNode);
2317     ASSERT(!initNode && !exprNode);
2318 
2319     emitIndentation();
2320     mOut << "while (";
2321     condNode->traverse(this);
2322     mOut << ")\n";
2323     bodyNode->traverse(this);
2324 
2325     return false;
2326 }
2327 
visitDoWhileLoop(TIntermLoop * loopNode)2328 bool GenMetalTraverser::visitDoWhileLoop(TIntermLoop *loopNode)
2329 {
2330     ASSERT(loopNode->getType() == TLoopType::ELoopDoWhile);
2331 
2332     TIntermNode *initNode  = loopNode->getInit();
2333     TIntermTyped *condNode = loopNode->getCondition();
2334     TIntermTyped *exprNode = loopNode->getExpression();
2335     TIntermBlock *bodyNode = loopNode->getBody();
2336     ASSERT(condNode && bodyNode);
2337     ASSERT(!initNode && !exprNode);
2338 
2339     emitIndentation();
2340     mOut << "do\n";
2341     bodyNode->traverse(this);
2342     mOut << "\n";
2343     emitIndentation();
2344     mOut << "while (";
2345     condNode->traverse(this);
2346     mOut << ");";
2347 
2348     return false;
2349 }
2350 
visitBranch(Visit,TIntermBranch * branchNode)2351 bool GenMetalTraverser::visitBranch(Visit, TIntermBranch *branchNode)
2352 {
2353     const TOperator flowOp = branchNode->getFlowOp();
2354     TIntermTyped *exprNode = branchNode->getExpression();
2355 
2356     emitIndentation();
2357 
2358     switch (flowOp)
2359     {
2360         case TOperator::EOpKill:
2361         {
2362             ASSERT(exprNode == nullptr);
2363             mOut << "metal::discard_fragment()";
2364         }
2365         break;
2366 
2367         case TOperator::EOpReturn:
2368         {
2369             if (isTraversingVertexMain)
2370             {
2371                 mOut << "#if TRANSFORM_FEEDBACK_ENABLED\n";
2372                 emitIndentation();
2373                 mOut << "return;\n";
2374                 emitIndentation();
2375                 mOut << "#else\n";
2376                 emitIndentation();
2377             }
2378             mOut << "return";
2379             if (exprNode)
2380             {
2381                 mOut << " ";
2382                 exprNode->traverse(this);
2383                 mOut << ";";
2384             }
2385             if (isTraversingVertexMain)
2386             {
2387                 mOut << "\n";
2388                 emitIndentation();
2389                 mOut << "#endif\n";
2390                 mTemporarilyDisableSemicolon = true;
2391             }
2392         }
2393         break;
2394 
2395         case TOperator::EOpBreak:
2396         {
2397             ASSERT(exprNode == nullptr);
2398             mOut << "break";
2399         }
2400         break;
2401 
2402         case TOperator::EOpContinue:
2403         {
2404             ASSERT(exprNode == nullptr);
2405             mOut << "continue";
2406         }
2407         break;
2408 
2409         default:
2410         {
2411             UNREACHABLE();
2412         }
2413     }
2414 
2415     return false;
2416 }
2417 
2418 static size_t emitMetalCallCount = 0;
2419 
EmitMetal(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const PipelineStructs & pipelineStructs,const Invariants & invariants,SymbolEnv & symbolEnv,const ProgramPreludeConfig & ppc,TSymbolTable * symbolTable)2420 bool sh::EmitMetal(TCompiler &compiler,
2421                    TIntermBlock &root,
2422                    IdGen &idGen,
2423                    const PipelineStructs &pipelineStructs,
2424                    const Invariants &invariants,
2425                    SymbolEnv &symbolEnv,
2426                    const ProgramPreludeConfig &ppc,
2427                    TSymbolTable *symbolTable)
2428 {
2429     TInfoSinkBase &out = compiler.getInfoSink().obj;
2430 
2431     {
2432         ++emitMetalCallCount;
2433         std::string filenameProto = angle::GetEnvironmentVar("GMD_FIXED_EMIT");
2434         if (!filenameProto.empty())
2435         {
2436             if (filenameProto != "/dev/null")
2437             {
2438                 auto tryOpen = [&](char const *ext) {
2439                     auto filename = filenameProto;
2440                     filename += std::to_string(emitMetalCallCount);
2441                     filename += ".";
2442                     filename += ext;
2443                     return fopen(filename.c_str(), "rb");
2444                 };
2445                 FILE *file = tryOpen("metal");
2446                 if (!file)
2447                 {
2448                     file = tryOpen("cpp");
2449                 }
2450                 ASSERT(file);
2451 
2452                 fseek(file, 0, SEEK_END);
2453                 size_t fileSize = ftell(file);
2454                 fseek(file, 0, SEEK_SET);
2455 
2456                 std::vector<char> buff;
2457                 buff.resize(fileSize + 1);
2458                 fread(buff.data(), fileSize, 1, file);
2459                 buff.back() = '\0';
2460 
2461                 fclose(file);
2462 
2463                 out << buff.data();
2464             }
2465 
2466             return true;
2467         }
2468     }
2469 
2470     out << "\n\n";
2471 
2472     if (!EmitProgramPrelude(root, out, ppc))
2473     {
2474         return false;
2475     }
2476 
2477     {
2478 #if defined(ANGLE_ENABLE_ASSERTS)
2479         DebugSink outWrapper(out, angle::GetBoolEnvironmentVar("GMD_STDOUT"));
2480         outWrapper.watch(angle::GetEnvironmentVar("GMD_WATCH_STRING"));
2481 #else
2482         TInfoSinkBase &outWrapper = out;
2483 #endif
2484         GenMetalTraverser gen(compiler, outWrapper, idGen, pipelineStructs, invariants, symbolEnv,
2485                               symbolTable);
2486         root.traverse(&gen);
2487     }
2488 
2489     out << "\n";
2490 
2491     return true;
2492 }
2493