• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <cctype>
8 #include <map>
9 
10 #include "common/system_utils.h"
11 #include "compiler/translator/BaseTypes.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/SymbolTable.h"
14 #include "compiler/translator/TranslatorMetalDirect.h"
15 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
16 #include "compiler/translator/TranslatorMetalDirect/DebugSink.h"
17 #include "compiler/translator/TranslatorMetalDirect/EmitMetal.h"
18 #include "compiler/translator/TranslatorMetalDirect/Layout.h"
19 #include "compiler/translator/TranslatorMetalDirect/Name.h"
20 #include "compiler/translator/TranslatorMetalDirect/ProgramPrelude.h"
21 #include "compiler/translator/TranslatorMetalDirect/RewritePipelines.h"
22 #include "compiler/translator/tree_util/IntermTraverse.h"
23 
24 using namespace sh;
25 
26 ////////////////////////////////////////////////////////////////////////////////
27 
28 #if defined(ANGLE_ENABLE_ASSERTS)
29 using Sink = DebugSink;
30 #else
31 using Sink = TInfoSinkBase;
32 #endif
33 
34 ////////////////////////////////////////////////////////////////////////////////
35 
36 namespace
37 {
38 
39 struct VarDecl
40 {
VarDecl__anon15dd06740111::VarDecl41     explicit VarDecl(const TVariable &var) : mVariable(&var), mIsField(false) {}
VarDecl__anon15dd06740111::VarDecl42     explicit VarDecl(const TField &field) : mField(&field), mIsField(true) {}
43 
variable__anon15dd06740111::VarDecl44     ANGLE_INLINE const TVariable &variable() const
45     {
46         ASSERT(isVariable());
47         return *mVariable;
48     }
49 
field__anon15dd06740111::VarDecl50     ANGLE_INLINE const TField &field() const
51     {
52         ASSERT(isField());
53         return *mField;
54     }
55 
isVariable__anon15dd06740111::VarDecl56     ANGLE_INLINE bool isVariable() const { return !mIsField; }
57 
isField__anon15dd06740111::VarDecl58     ANGLE_INLINE bool isField() const { return mIsField; }
59 
type__anon15dd06740111::VarDecl60     const TType &type() const { return isField() ? *field().type() : variable().getType(); }
61 
symbolType__anon15dd06740111::VarDecl62     SymbolType symbolType() const
63     {
64         return isField() ? field().symbolType() : variable().symbolType();
65     }
66 
67   private:
68     union
69     {
70         const TVariable *mVariable;
71         const TField *mField;
72     };
73     bool mIsField;
74 };
75 
76 class GenMetalTraverser : public TIntermTraverser
77 {
78   public:
79     ~GenMetalTraverser() override;
80 
81     GenMetalTraverser(const TCompiler &compiler,
82                       Sink &out,
83                       IdGen &idGen,
84                       const PipelineStructs &pipelineStructs,
85                       SymbolEnv &symbolEnv,
86                       TSymbolTable *symbolTable);
87 
88     void visitSymbol(TIntermSymbol *) override;
89     void visitConstantUnion(TIntermConstantUnion *) override;
90     bool visitSwizzle(Visit, TIntermSwizzle *) override;
91     bool visitBinary(Visit, TIntermBinary *) override;
92     bool visitUnary(Visit, TIntermUnary *) override;
93     bool visitTernary(Visit, TIntermTernary *) override;
94     bool visitIfElse(Visit, TIntermIfElse *) override;
95     bool visitSwitch(Visit, TIntermSwitch *) override;
96     bool visitCase(Visit, TIntermCase *) override;
97     void visitFunctionPrototype(TIntermFunctionPrototype *) override;
98     bool visitFunctionDefinition(Visit, TIntermFunctionDefinition *) override;
99     bool visitAggregate(Visit, TIntermAggregate *) override;
100     bool visitBlock(Visit, TIntermBlock *) override;
101     bool visitGlobalQualifierDeclaration(Visit, TIntermGlobalQualifierDeclaration *) override;
102     bool visitDeclaration(Visit, TIntermDeclaration *) override;
103     bool visitLoop(Visit, TIntermLoop *) override;
104     bool visitForLoop(TIntermLoop *);
105     bool visitWhileLoop(TIntermLoop *);
106     bool visitDoWhileLoop(TIntermLoop *);
107     bool visitBranch(Visit, TIntermBranch *) override;
108 
109   private:
110     using FuncToName = std::map<ImmutableString, Name>;
111     static FuncToName BuildFuncToName();
112 
113     struct EmitVariableDeclarationConfig
114     {
115         bool isParameter                = false;
116         bool isMainParameter            = false;
117         bool emitPostQualifier          = false;
118         bool isPacked                   = false;
119         bool disableStructSpecifier     = false;
120         bool isUBO                      = false;
121         const AddressSpace *isPointer   = nullptr;
122         const AddressSpace *isReference = nullptr;
123     };
124 
125     struct EmitTypeConfig
126     {
127         const EmitVariableDeclarationConfig *evdConfig = nullptr;
128     };
129 
130     void emitIndentation();
131     void emitOpeningPointerParen();
132     void emitClosingPointerParen();
133     void emitFunctionSignature(const TFunction &func);
134     void emitFunctionReturn(const TFunction &func);
135     void emitFunctionParameter(const TFunction &func, const TVariable &param);
136 
137     void emitNameOf(const TField &object);
138     void emitNameOf(const TSymbol &object);
139     void emitNameOf(const VarDecl &object);
140 
141     void emitBareTypeName(const TType &type, const EmitTypeConfig &etConfig);
142     void emitType(const TType &type, const EmitTypeConfig &etConfig);
143     void emitPostQualifier(const EmitVariableDeclarationConfig &evdConfig,
144                            const VarDecl &decl,
145                            const TQualifier qualifier);
146 
147     struct FieldAnnotationIndices
148     {
149         size_t attribute = 0;
150         size_t color     = 0;
151     };
152 
153     void emitFieldDeclaration(const TField &field,
154                               const TStructure &parent,
155                               FieldAnnotationIndices &annotationIndices);
156     void emitAttributeDeclaration(const TField &field, FieldAnnotationIndices &annotationIndices);
157     void emitUniformBufferDeclaration(const TField &field,
158                                       FieldAnnotationIndices &annotationIndices);
159     void emitStructDeclaration(const TType &type);
160     void emitOrdinaryVariableDeclaration(const VarDecl &decl,
161                                          const EmitVariableDeclarationConfig &evdConfig);
162     void emitVariableDeclaration(const VarDecl &decl,
163                                  const EmitVariableDeclarationConfig &evdConfig);
164 
165     void emitOpenBrace();
166     void emitCloseBrace();
167 
168     void groupedTraverse(TIntermNode &node);
169 
170     const TField &getDirectField(const TFieldListCollection &fieldsNode,
171                                  const TConstantUnion &index);
172     const TField &getDirectField(const TIntermTyped &fieldsNode, TIntermTyped &indexNode);
173 
174     const TConstantUnion *emitConstantUnionArray(const TConstantUnion *const constUnion,
175                                                  const size_t size);
176 
177     const TConstantUnion *emitConstantUnion(const TType &type, const TConstantUnion *constUnion);
178 
179     void emitSingleConstant(const TConstantUnion *const constUnion);
180 
181   private:
182     Sink &mOut;
183     const TCompiler &mCompiler;
184     const PipelineStructs &mPipelineStructs;
185     SymbolEnv &mSymbolEnv;
186     IdGen &mIdGen;
187     int mIndentLevel                  = -1;
188     int mLastIndentationPos           = -1;
189     int mOpenPointerParenCount        = 0;
190     bool mParentIsSwitch              = false;
191     bool isTraversingVertexMain       = false;
192     bool mTemporarilyDisableSemicolon = false;
193     std::unordered_map<const TSymbol *, Name> mRenamedSymbols;
194     const FuncToName mFuncToName          = BuildFuncToName();
195     size_t mMainTextureIndex              = 0;
196     size_t mMainSamplerIndex              = 0;
197     size_t mMainUniformBufferIndex        = 0;
198     size_t mDriverUniformsBindingIndex    = 0;
199     size_t mUBOArgumentBufferBindingIndex = 0;
200 };
201 }  // anonymous namespace
202 
~GenMetalTraverser()203 GenMetalTraverser::~GenMetalTraverser()
204 {
205     ASSERT(mIndentLevel == -1);
206     ASSERT(!mParentIsSwitch);
207     ASSERT(mOpenPointerParenCount == 0);
208 }
209 
GenMetalTraverser(const TCompiler & compiler,Sink & out,IdGen & idGen,const PipelineStructs & pipelineStructs,SymbolEnv & symbolEnv,TSymbolTable * symbolTable)210 GenMetalTraverser::GenMetalTraverser(const TCompiler &compiler,
211                                      Sink &out,
212                                      IdGen &idGen,
213                                      const PipelineStructs &pipelineStructs,
214                                      SymbolEnv &symbolEnv,
215                                      TSymbolTable *symbolTable)
216     : TIntermTraverser(true, false, false),
217       mOut(out),
218       mCompiler(compiler),
219       mPipelineStructs(pipelineStructs),
220       mSymbolEnv(symbolEnv),
221       mIdGen(idGen),
222       mMainUniformBufferIndex(symbolTable->getDefaultUniformsBindingIndex()),
223       mDriverUniformsBindingIndex(symbolTable->getDriverUniformsBindingIndex()),
224       mUBOArgumentBufferBindingIndex(symbolTable->getUBOArgumentBufferBindingIndex())
225 {}
226 
emitIndentation()227 void GenMetalTraverser::emitIndentation()
228 {
229     ASSERT(mIndentLevel >= 0);
230 
231     if (mLastIndentationPos == mOut.size())
232     {
233         return;  // Line is already indented.
234     }
235 
236     for (int i = 0; i < mIndentLevel; ++i)
237     {
238         mOut << "  ";
239     }
240 
241     mLastIndentationPos = mOut.size();
242 }
243 
emitOpeningPointerParen()244 void GenMetalTraverser::emitOpeningPointerParen()
245 {
246     mOut << "(*";
247     mOpenPointerParenCount++;
248 }
249 
emitClosingPointerParen()250 void GenMetalTraverser::emitClosingPointerParen()
251 {
252     if (mOpenPointerParenCount > 0)
253     {
254         mOut << ")";
255         mOpenPointerParenCount--;
256     }
257 }
258 
GetOperatorString(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1,const TType * argType2)259 static const char *GetOperatorString(TOperator op,
260                                      const TType &resultType,
261                                      const TType *argType0,
262                                      const TType *argType1,
263                                      const TType *argType2)
264 {
265     switch (op)
266     {
267         case TOperator::EOpComma:
268             return ",";
269         case TOperator::EOpAssign:
270             return "=";
271         case TOperator::EOpInitialize:
272             return "=";
273         case TOperator::EOpAddAssign:
274             return "+=";
275         case TOperator::EOpSubAssign:
276             return "-=";
277         case TOperator::EOpMulAssign:
278             return "*=";
279         case TOperator::EOpDivAssign:
280             return "/=";
281         case TOperator::EOpIModAssign:
282             return "%=";
283         case TOperator::EOpBitShiftLeftAssign:
284             return "<<=";  // TODO: Check logical vs arithmetic shifting.
285         case TOperator::EOpBitShiftRightAssign:
286             return ">>=";  // TODO: Check logical vs arithmetic shifting.
287         case TOperator::EOpBitwiseAndAssign:
288             return "&=";
289         case TOperator::EOpBitwiseXorAssign:
290             return "^=";
291         case TOperator::EOpBitwiseOrAssign:
292             return "|=";
293         case TOperator::EOpAdd:
294             return "+";
295         case TOperator::EOpSub:
296             return "-";
297         case TOperator::EOpMul:
298             return "*";
299         case TOperator::EOpDiv:
300             return "/";
301         case TOperator::EOpIMod:
302             return "%";
303         case TOperator::EOpBitShiftLeft:
304             return "<<";  // TODO: Check logical vs arithmetic shifting.
305         case TOperator::EOpBitShiftRight:
306             return ">>";  // TODO: Check logical vs arithmetic shifting.
307         case TOperator::EOpBitwiseAnd:
308             return "&";
309         case TOperator::EOpBitwiseXor:
310             return "^";
311         case TOperator::EOpBitwiseOr:
312             return "|";
313         case TOperator::EOpLessThan:
314             return "<";
315         case TOperator::EOpGreaterThan:
316             return ">";
317         case TOperator::EOpLessThanEqual:
318             return "<=";
319         case TOperator::EOpGreaterThanEqual:
320             return ">=";
321         case TOperator::EOpLessThanComponentWise:
322             return "<";
323         case TOperator::EOpLessThanEqualComponentWise:
324             return "<=";
325         case TOperator::EOpGreaterThanEqualComponentWise:
326             return ">=";
327         case TOperator::EOpGreaterThanComponentWise:
328             return ">";
329         case TOperator::EOpLogicalOr:
330             return "||";
331         case TOperator::EOpLogicalXor:
332             return "!=/*xor*/";  // XXX: This might need to be handled differently for some obtuse
333                                  // use case.
334         case TOperator::EOpLogicalAnd:
335             return "&&";
336         case TOperator::EOpNegative:
337             return "-";
338         case TOperator::EOpPositive:
339             if (argType0->isMatrix())
340             {
341                 return "";
342             }
343             return "+";
344         case TOperator::EOpLogicalNot:
345             return "!";
346         case TOperator::EOpNotComponentWise:
347             return "!";
348         case TOperator::EOpBitwiseNot:
349             return "~";
350         case TOperator::EOpPostIncrement:
351             return "++";
352         case TOperator::EOpPostDecrement:
353             return "--";
354         case TOperator::EOpPreIncrement:
355             return "++";
356         case TOperator::EOpPreDecrement:
357             return "--";
358         case TOperator::EOpVectorTimesScalarAssign:
359             return "*=";
360         case TOperator::EOpVectorTimesMatrixAssign:
361             return "*=";
362         case TOperator::EOpMatrixTimesScalarAssign:
363             return "*=";
364         case TOperator::EOpMatrixTimesMatrixAssign:
365             return "*=";
366         case TOperator::EOpVectorTimesScalar:
367             return "*";
368         case TOperator::EOpVectorTimesMatrix:
369             return "*";
370         case TOperator::EOpMatrixTimesVector:
371             return "*";
372         case TOperator::EOpMatrixTimesScalar:
373             return "*";
374         case TOperator::EOpMatrixTimesMatrix:
375             return "*";
376         case TOperator::EOpEqualComponentWise:
377             return "==";
378         case TOperator::EOpNotEqualComponentWise:
379             return "!=";
380 
381         case TOperator::EOpEqual:
382             if ((argType0->getStruct() && argType1->getStruct()) &&
383                 (argType0->isArray() && argType1->isArray()))
384             {
385                 return "ANGLE_equalStructArray";
386             }
387 
388             if ((argType0->isVector() && argType1->isVector()) ||
389                 (argType0->getStruct() && argType1->getStruct()) ||
390                 (argType0->isArray() && argType1->isArray()) ||
391                 (argType0->isMatrix() && argType1->isMatrix()))
392 
393             {
394                 return "ANGLE_equal";
395             }
396 
397             return "==";
398 
399         case TOperator::EOpNotEqual:
400             if ((argType0->getStruct() && argType1->getStruct()) &&
401                 (argType0->isArray() && argType1->isArray()))
402             {
403                 return "ANGLE_notEqualStructArray";
404             }
405 
406             if ((argType0->isVector() && argType1->isVector()) ||
407                 (argType0->isArray() && argType1->isArray()) ||
408                 (argType0->isMatrix() && argType1->isMatrix()))
409             {
410                 return "ANGLE_notEqual";
411             }
412             else if (argType0->getStruct() && argType1->getStruct())
413             {
414                 return "ANGLE_notEqualStruct";
415             }
416             return "!=";
417 
418         case TOperator::EOpKill:
419             UNIMPLEMENTED();
420             return "kill";
421         case TOperator::EOpReturn:
422             return "return";
423         case TOperator::EOpBreak:
424             return "break";
425         case TOperator::EOpContinue:
426             return "continue";
427 
428         case TOperator::EOpRadians:
429             return "ANGLE_radians";
430         case TOperator::EOpDegrees:
431             return "ANGLE_degrees";
432         case TOperator::EOpAtan:
433             return "ANGLE_atan";
434         case TOperator::EOpMod:
435             return "ANGLE_mod";  // differs from metal::mod
436         case TOperator::EOpRefract:
437             return "ANGLE_refract";
438         case TOperator::EOpDistance:
439             return "ANGLE_distance";
440         case TOperator::EOpLength:
441             return "ANGLE_length";
442         case TOperator::EOpDot:
443             return "ANGLE_dot";
444         case TOperator::EOpNormalize:
445             return "ANGLE_normalize";
446         case TOperator::EOpFaceforward:
447             return "ANGLE_faceforward";
448         case TOperator::EOpReflect:
449             return "ANGLE_reflect";
450         case TOperator::EOpMatrixCompMult:
451             return "ANGLE_componentWiseMultiply";
452         case TOperator::EOpOuterProduct:
453             return "ANGLE_outerProduct";
454         case TOperator::EOpSign:
455             return "ANGLE_sign";
456 
457         case TOperator::EOpAbs:
458             return "metal::abs";
459         case TOperator::EOpAll:
460             return "metal::all";
461         case TOperator::EOpAny:
462             return "metal::any";
463         case TOperator::EOpSin:
464             return "metal::sin";
465         case TOperator::EOpCos:
466             return "metal::cos";
467         case TOperator::EOpTan:
468             return "metal::tan";
469         case TOperator::EOpAsin:
470             return "metal::asin";
471         case TOperator::EOpAcos:
472             return "metal::acos";
473         case TOperator::EOpSinh:
474             return "metal::sinh";
475         case TOperator::EOpCosh:
476             return "metal::cosh";
477         case TOperator::EOpTanh:
478             return "metal::tanh";
479         case TOperator::EOpAsinh:
480             return "metal::asinh";
481         case TOperator::EOpAcosh:
482             return "metal::acosh";
483         case TOperator::EOpAtanh:
484             return "metal::atanh";
485         case TOperator::EOpFma:
486             return "metal::fma";
487         case TOperator::EOpPow:
488             return "metal::pow";
489         case TOperator::EOpExp:
490             return "metal::exp";
491         case TOperator::EOpExp2:
492             return "metal::exp2";
493         case TOperator::EOpLog:
494             return "metal::log";
495         case TOperator::EOpLog2:
496             return "metal::log2";
497         case TOperator::EOpSqrt:
498             return "metal::sqrt";
499         case TOperator::EOpFloor:
500             return "metal::floor";
501         case TOperator::EOpTrunc:
502             return "metal::trunc";
503         case TOperator::EOpCeil:
504             return "metal::ceil";
505         case TOperator::EOpFract:
506             return "metal::fract";
507         case TOperator::EOpMin:
508             return "metal::min";
509         case TOperator::EOpMax:
510             return "metal::max";
511         case TOperator::EOpRound:
512             return "metal::round";
513         case TOperator::EOpRoundEven:
514             return "metal::rint";
515         case TOperator::EOpClamp:
516             return "metal::clamp";  // TODO fast vs precise namespace
517         case TOperator::EOpMix:
518             if (argType2 && argType2->getBasicType() == EbtBool)
519                 return "ANGLE_mix_bool";
520             return "metal::mix";
521         case TOperator::EOpStep:
522             return "metal::step";
523         case TOperator::EOpSmoothstep:
524             return "metal::smoothstep";
525         case TOperator::EOpModf:
526             return "metal::modf";
527         case TOperator::EOpIsnan:
528             return "metal::isnan";
529         case TOperator::EOpIsinf:
530             return "metal::isinf";
531         case TOperator::EOpLdexp:
532             return "metal::ldexp";
533         case TOperator::EOpFrexp:
534             return "metal::frexp";
535         case TOperator::EOpInversesqrt:
536             return "metal::rsqrt";
537         case TOperator::EOpCross:
538             return "metal::cross";
539         case TOperator::EOpDFdx:
540             return "metal::dfdx";
541         case TOperator::EOpDFdy:
542             return "metal::dfdy";
543         case TOperator::EOpFwidth:
544             return "metal::fwidth";
545         case TOperator::EOpTranspose:
546             return "metal::transpose";
547         case TOperator::EOpDeterminant:
548             return "metal::determinant";
549 
550         case TOperator::EOpInverse:
551             return "ANGLE_inverse";
552 
553         case TOperator::EOpFloatBitsToInt:
554         case TOperator::EOpFloatBitsToUint:
555         case TOperator::EOpIntBitsToFloat:
556         case TOperator::EOpUintBitsToFloat:
557         {
558 #define RETURN_AS_TYPE(post)                     \
559     do                                           \
560         switch (resultType.getBasicType())       \
561         {                                        \
562             case TBasicType::EbtInt:             \
563                 return "as_type<int" post ">";   \
564             case TBasicType::EbtUInt:            \
565                 return "as_type<uint" post ">";  \
566             case TBasicType::EbtFloat:           \
567                 return "as_type<float" post ">"; \
568             default:                             \
569                 UNIMPLEMENTED();                 \
570                 return "TOperator_TODO";         \
571         }                                        \
572     while (false)
573 
574             if (resultType.isScalar())
575             {
576                 RETURN_AS_TYPE("");
577             }
578             else if (resultType.isVector())
579             {
580                 switch (resultType.getNominalSize())
581                 {
582                     case 2:
583                         RETURN_AS_TYPE("2");
584                     case 3:
585                         RETURN_AS_TYPE("3");
586                     case 4:
587                         RETURN_AS_TYPE("4");
588                     default:
589                         UNREACHABLE();
590                         return nullptr;
591                 }
592             }
593             else
594             {
595                 UNIMPLEMENTED();
596                 return "TOperator_TODO";
597             }
598 
599 #undef RETURN_AS_TYPE
600         }
601 
602         case TOperator::EOpPackUnorm2x16:
603             return "metal::pack_float_to_unorm2x16";
604         case TOperator::EOpPackSnorm2x16:
605             return "metal::pack_float_to_snorm2x16";
606 
607         case TOperator::EOpPackUnorm4x8:
608             return "metal::pack_float_to_unorm4x8";
609         case TOperator::EOpPackSnorm4x8:
610             return "metal::pack_float_to_snorm4x8";
611 
612         case TOperator::EOpUnpackUnorm2x16:
613             return "metal::unpack_unorm2x16_to_float";
614         case TOperator::EOpUnpackSnorm2x16:
615             return "metal::unpack_snorm2x16_to_float";
616 
617         case TOperator::EOpUnpackUnorm4x8:
618             return "metal::unpack_unorm4x8_to_float";
619         case TOperator::EOpUnpackSnorm4x8:
620             return "metal::unpack_snorm4x8_to_float";
621 
622         case TOperator::EOpPackHalf2x16:
623             return "ANGLE_pack_half_2x16";
624         case TOperator::EOpUnpackHalf2x16:
625             return "ANGLE_unpack_half_2x16";
626 
627         case TOperator::EOpBitfieldExtract:
628         case TOperator::EOpBitfieldInsert:
629         case TOperator::EOpBitfieldReverse:
630         case TOperator::EOpBitCount:
631         case TOperator::EOpFindLSB:
632         case TOperator::EOpFindMSB:
633         case TOperator::EOpUaddCarry:
634         case TOperator::EOpUsubBorrow:
635         case TOperator::EOpUmulExtended:
636         case TOperator::EOpImulExtended:
637         case TOperator::EOpBarrier:
638         case TOperator::EOpMemoryBarrier:
639         case TOperator::EOpMemoryBarrierAtomicCounter:
640         case TOperator::EOpMemoryBarrierBuffer:
641         case TOperator::EOpMemoryBarrierImage:
642         case TOperator::EOpMemoryBarrierShared:
643         case TOperator::EOpGroupMemoryBarrier:
644         case TOperator::EOpAtomicAdd:
645         case TOperator::EOpAtomicMin:
646         case TOperator::EOpAtomicMax:
647         case TOperator::EOpAtomicAnd:
648         case TOperator::EOpAtomicOr:
649         case TOperator::EOpAtomicXor:
650         case TOperator::EOpAtomicExchange:
651         case TOperator::EOpAtomicCompSwap:
652         case TOperator::EOpEmitVertex:
653         case TOperator::EOpEndPrimitive:
654         case TOperator::EOpFtransform:
655         case TOperator::EOpPackDouble2x32:
656         case TOperator::EOpUnpackDouble2x32:
657         case TOperator::EOpArrayLength:
658             UNIMPLEMENTED();
659             return "TOperator_TODO";
660 
661         case TOperator::EOpNull:
662         case TOperator::EOpConstruct:
663         case TOperator::EOpCallFunctionInAST:
664         case TOperator::EOpCallInternalRawFunction:
665         case TOperator::EOpIndexDirect:
666         case TOperator::EOpIndexIndirect:
667         case TOperator::EOpIndexDirectStruct:
668         case TOperator::EOpIndexDirectInterfaceBlock:
669             UNREACHABLE();
670             return nullptr;
671         default:
672             // Any other built-in function.
673             return nullptr;
674     }
675 }
676 
IsSymbolicOperator(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1)677 static bool IsSymbolicOperator(TOperator op,
678                                const TType &resultType,
679                                const TType *argType0,
680                                const TType *argType1)
681 {
682     const char *operatorString = GetOperatorString(op, resultType, argType0, argType1, nullptr);
683     if (operatorString == nullptr)
684     {
685         return false;
686     }
687     return !std::isalnum(operatorString[0]);
688 }
689 
AsSpecificBinaryNode(TIntermNode & node,TOperator op)690 static TIntermBinary *AsSpecificBinaryNode(TIntermNode &node, TOperator op)
691 {
692     TIntermBinary *binaryNode = node.getAsBinaryNode();
693     if (binaryNode)
694     {
695         return binaryNode->getOp() == op ? binaryNode : nullptr;
696     }
697     return nullptr;
698 }
699 
Parenthesize(TIntermNode & node)700 static bool Parenthesize(TIntermNode &node)
701 {
702     if (node.getAsSymbolNode())
703     {
704         return false;
705     }
706     if (node.getAsConstantUnion())
707     {
708         return false;
709     }
710     if (node.getAsAggregate())
711     {
712         return false;
713     }
714     if (node.getAsSwizzleNode())
715     {
716         return false;
717     }
718 
719     if (TIntermUnary *unaryNode = node.getAsUnaryNode())
720     {
721         // TODO: Use a precedence and associativity rules instead of this ad-hoc impl.
722         const TType &resultType = unaryNode->getType();
723         const TType &argType    = unaryNode->getOperand()->getType();
724         return IsSymbolicOperator(unaryNode->getOp(), resultType, &argType, nullptr);
725     }
726 
727     if (TIntermBinary *binaryNode = node.getAsBinaryNode())
728     {
729         // TODO: Use a precedence and associativity rules instead of this ad-hoc impl.
730         const TOperator op = binaryNode->getOp();
731         switch (op)
732         {
733             case TOperator::EOpIndexDirectStruct:
734             case TOperator::EOpIndexDirectInterfaceBlock:
735             case TOperator::EOpIndexDirect:
736             case TOperator::EOpIndexIndirect:
737                 return Parenthesize(*binaryNode->getLeft());
738 
739             case TOperator::EOpAssign:
740             case TOperator::EOpInitialize:
741                 return AsSpecificBinaryNode(*binaryNode->getRight(), TOperator::EOpComma);
742 
743             default:
744             {
745                 const TType &resultType = binaryNode->getType();
746                 const TType &leftType   = binaryNode->getLeft()->getType();
747                 const TType &rightType  = binaryNode->getRight()->getType();
748                 return IsSymbolicOperator(binaryNode->getOp(), resultType, &leftType, &rightType);
749             }
750         }
751     }
752 
753     return true;
754 }
755 
groupedTraverse(TIntermNode & node)756 void GenMetalTraverser::groupedTraverse(TIntermNode &node)
757 {
758     const bool emitParens = Parenthesize(node);
759 
760     if (emitParens)
761     {
762         mOut << "(";
763     }
764 
765     node.traverse(this);
766 
767     if (emitParens)
768     {
769         mOut << ")";
770     }
771 }
772 
emitPostQualifier(const EmitVariableDeclarationConfig & evdConfig,const VarDecl & decl,const TQualifier qualifier)773 void GenMetalTraverser::emitPostQualifier(const EmitVariableDeclarationConfig &evdConfig,
774                                           const VarDecl &decl,
775                                           const TQualifier qualifier)
776 {
777     bool isInvariant = false;
778     switch (qualifier)
779     {
780         case TQualifier::EvqPosition:
781             isInvariant = decl.type().isInvariant();
782             ANGLE_FALLTHROUGH;
783         case TQualifier::EvqFragCoord:
784             mOut << " [[position]]";
785             break;
786 
787         case TQualifier::EvqPointSize:
788             mOut << " [[point_size]]";
789             break;
790 
791         case TQualifier::EvqVertexID:
792             if (evdConfig.isMainParameter)
793             {
794                 mOut << " [[vertex_id]]";
795             }
796             break;
797 
798         case TQualifier::EvqPointCoord:
799             if (evdConfig.isMainParameter)
800             {
801                 mOut << " [[point_coord]]";
802             }
803             break;
804 
805         case TQualifier::EvqFrontFacing:
806             if (evdConfig.isMainParameter)
807             {
808                 mOut << " [[front_facing]]";
809             }
810             break;
811 
812         default:
813             break;
814     }
815 
816     if (isInvariant)
817     {
818         mOut << " [[invariant]]";
819 
820         TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
821         reflection->hasInvariance             = true;
822     }
823 }
824 
EmitName(Sink & out,const Name & name)825 static void EmitName(Sink &out, const Name &name)
826 {
827 #if defined(ANGLE_ENABLE_ASSERTS)
828     DebugSink::EscapedSink escapedOut(out.escape());
829 #else
830     TInfoSinkBase &escapedOut = out;
831 #endif
832     name.emit(escapedOut);
833 }
834 
emitNameOf(const TField & object)835 void GenMetalTraverser::emitNameOf(const TField &object)
836 {
837     EmitName(mOut, Name(object));
838 }
839 
emitNameOf(const TSymbol & object)840 void GenMetalTraverser::emitNameOf(const TSymbol &object)
841 {
842     auto it = mRenamedSymbols.find(&object);
843     if (it == mRenamedSymbols.end())
844     {
845         EmitName(mOut, Name(object));
846     }
847     else
848     {
849         EmitName(mOut, it->second);
850     }
851 }
852 
emitNameOf(const VarDecl & object)853 void GenMetalTraverser::emitNameOf(const VarDecl &object)
854 {
855     if (object.isField())
856     {
857         emitNameOf(object.field());
858     }
859     else
860     {
861         emitNameOf(object.variable());
862     }
863 }
864 
emitBareTypeName(const TType & type,const EmitTypeConfig & etConfig)865 void GenMetalTraverser::emitBareTypeName(const TType &type, const EmitTypeConfig &etConfig)
866 {
867     const TBasicType basicType = type.getBasicType();
868 
869     switch (basicType)
870     {
871         case TBasicType::EbtVoid:
872         case TBasicType::EbtBool:
873         case TBasicType::EbtFloat:
874         case TBasicType::EbtInt:
875         case TBasicType::EbtUInt:
876         {
877             mOut << type.getBasicString();
878         }
879         break;
880 
881         case TBasicType::EbtStruct:
882         {
883             const TStructure &structure = *type.getStruct();
884             emitNameOf(structure);
885         }
886         break;
887 
888         case TBasicType::EbtInterfaceBlock:
889         {
890             const TInterfaceBlock &interfaceBlock = *type.getInterfaceBlock();
891             emitNameOf(interfaceBlock);
892         }
893         break;
894 
895         default:
896         {
897             if (IsSampler(basicType))
898             {
899                 if (etConfig.evdConfig && etConfig.evdConfig->isMainParameter)
900                 {
901                     EmitName(mOut, GetTextureTypeName(basicType));
902                 }
903                 else
904                 {
905                     const TStructure &env = mSymbolEnv.getTextureEnv(basicType);
906                     emitNameOf(env);
907                 }
908             }
909             else
910             {
911                 UNIMPLEMENTED();
912             }
913         }
914     }
915 }
916 
emitType(const TType & type,const EmitTypeConfig & etConfig)917 void GenMetalTraverser::emitType(const TType &type, const EmitTypeConfig &etConfig)
918 {
919     const bool isUBO = etConfig.evdConfig ? etConfig.evdConfig->isUBO : false;
920     if (etConfig.evdConfig)
921     {
922         const auto &evdConfig = *etConfig.evdConfig;
923         if (isUBO)
924         {
925             if (type.isArray())
926             {
927                 mOut << "ANGLE_tensor<";
928             }
929         }
930         if (evdConfig.isPointer)
931         {
932             mOut << toString(*evdConfig.isPointer);
933             mOut << " ";
934         }
935         else if (evdConfig.isReference)
936         {
937             mOut << toString(*evdConfig.isReference);
938             mOut << " ";
939         }
940     }
941 
942     if (!isUBO)
943     {
944         if (type.isArray())
945         {
946             mOut << "ANGLE_tensor<";
947         }
948     }
949 
950     if (type.isVector() || type.isMatrix())
951     {
952         mOut << "metal::";
953     }
954 
955     if (etConfig.evdConfig && etConfig.evdConfig->isPacked)
956     {
957         mOut << "packed_";
958     }
959 
960     emitBareTypeName(type, etConfig);
961 
962     if (type.isVector())
963     {
964         mOut << type.getNominalSize();
965     }
966     else if (type.isMatrix())
967     {
968         mOut << type.getCols() << "x" << type.getRows();
969     }
970 
971     if (!isUBO)
972     {
973         if (type.isArray())
974         {
975             for (auto size : type.getArraySizes())
976             {
977                 mOut << ", " << size;
978             }
979             mOut << ">";
980         }
981     }
982 
983     if (etConfig.evdConfig)
984     {
985         const auto &evdConfig = *etConfig.evdConfig;
986         if (evdConfig.isPointer)
987         {
988             mOut << " *";
989         }
990         else if (evdConfig.isReference)
991         {
992             mOut << " &";
993         }
994         if (isUBO)
995         {
996             if (type.isArray())
997             {
998                 for (auto size : type.getArraySizes())
999                 {
1000                     mOut << ", " << size;
1001                 }
1002                 mOut << ">";
1003             }
1004         }
1005     }
1006 }
1007 
emitFieldDeclaration(const TField & field,const TStructure & parent,FieldAnnotationIndices & annotationIndices)1008 void GenMetalTraverser::emitFieldDeclaration(const TField &field,
1009                                              const TStructure &parent,
1010                                              FieldAnnotationIndices &annotationIndices)
1011 {
1012     const TType &type      = *field.type();
1013     const TBasicType basic = type.getBasicType();
1014 
1015     EmitVariableDeclarationConfig evdConfig;
1016     evdConfig.emitPostQualifier      = true;
1017     evdConfig.disableStructSpecifier = true;
1018     evdConfig.isPacked               = mSymbolEnv.isPacked(field);
1019     evdConfig.isUBO                  = mSymbolEnv.isUBO(field);
1020     evdConfig.isPointer              = mSymbolEnv.isPointer(field);
1021     evdConfig.isReference            = mSymbolEnv.isReference(field);
1022     emitVariableDeclaration(VarDecl(field), evdConfig);
1023 
1024     const TQualifier qual = type.getQualifier();
1025     switch (qual)
1026     {
1027         case TQualifier::EvqFlatIn:
1028             if (mPipelineStructs.fragmentIn.external == &parent)
1029             {
1030                 mOut << " [[flat]]";
1031                 TranslatorMetalReflection *reflection =
1032                     mtl::getTranslatorMetalReflection(&mCompiler);
1033                 reflection->hasFlatInput = true;
1034             }
1035             break;
1036 
1037         case TQualifier::EvqFragmentOut:
1038         case TQualifier::EvqFragData:
1039             if (mPipelineStructs.fragmentOut.external == &parent)
1040             {
1041                 if ((type.isVector() &&
1042                      (basic == TBasicType::EbtInt || basic == TBasicType::EbtUInt ||
1043                       basic == TBasicType::EbtFloat)) ||
1044                     type.getQualifier() == EvqFragData)
1045                 {
1046                     // The OpenGL ES 3.0 spec says locations must be specified
1047                     // unless there is only a single output, in which case the
1048                     // location is 0. So, when we get to this point the shader
1049                     // will have been rejected if locations are not specified
1050                     // and there is more than one output.
1051                     const TLayoutQualifier &layoutQualifier = type.getLayoutQualifier();
1052                     size_t index = layoutQualifier.locationsSpecified ? layoutQualifier.location
1053                                                                       : annotationIndices.color++;
1054                     mOut << " [[color(" << index << ")]]";
1055                 }
1056             }
1057             break;
1058 
1059         case TQualifier::EvqFragDepth:
1060             mOut << " [[depth(any)]]";
1061             break;
1062 
1063         case TQualifier::EvqSampleMask:
1064             mOut << " [[sample_mask, function_constant(" << sh::mtl::kCoverageMaskEnabledConstName
1065                  << ")]]";
1066             break;
1067 
1068         default:
1069             break;
1070     }
1071 }
1072 
BuildExternalAttributeIndexMap(const TCompiler & compiler,const PipelineScoped<TStructure> & structure)1073 static std::map<Name, size_t> BuildExternalAttributeIndexMap(
1074     const TCompiler &compiler,
1075     const PipelineScoped<TStructure> &structure)
1076 {
1077     ASSERT(structure.isTotallyFull());
1078 
1079     const auto &shaderVars     = compiler.getAttributes();
1080     const size_t shaderVarSize = shaderVars.size();
1081     size_t shaderVarIndex      = 0;
1082 
1083     const auto &externalFields = structure.external->fields();
1084     const size_t externalSize  = externalFields.size();
1085     size_t externalIndex       = 0;
1086 
1087     const auto &internalFields = structure.internal->fields();
1088     const size_t internalSize  = internalFields.size();
1089     size_t internalIndex       = 0;
1090 
1091     // Internal fields are never split. External fields are sometimes split.
1092     ASSERT(externalSize >= internalSize);
1093 
1094     // Structures do not contain any inactive fields.
1095     ASSERT(shaderVarSize >= internalSize);
1096 
1097     std::map<Name, size_t> externalNameToAttributeIndex;
1098     size_t attributeIndex = 0;
1099 
1100     while (internalIndex < internalSize)
1101     {
1102         const TField &internalField = *internalFields[internalIndex];
1103         const Name internalName     = Name(internalField);
1104         const TType &internalType   = *internalField.type();
1105         while (internalName.rawName() != shaderVars[shaderVarIndex].name &&
1106                internalName.rawName() != shaderVars[shaderVarIndex].mappedName)
1107         {
1108             // This case represents an inactive field.
1109 
1110             ++shaderVarIndex;
1111             ASSERT(shaderVarIndex < shaderVarSize);
1112 
1113             ++attributeIndex;  // TODO: Might need to increment more if shader var type is a matrix.
1114         }
1115 
1116         const size_t cols = internalType.isMatrix() ? internalType.getCols() : 1;
1117 
1118         for (size_t c = 0; c < cols; ++c)
1119         {
1120             const TField &externalField = *externalFields[externalIndex];
1121             const Name externalName     = Name(externalField);
1122             ASSERT(!externalField.type()->isMatrix());
1123 
1124             externalNameToAttributeIndex[externalName] = attributeIndex;
1125 
1126             ++externalIndex;
1127             ++attributeIndex;
1128         }
1129 
1130         ++shaderVarIndex;
1131         ++internalIndex;
1132     }
1133 
1134     ASSERT(shaderVarIndex <= shaderVarSize);
1135     ASSERT(externalIndex <= externalSize);  // less than if padding was introduced
1136     ASSERT(internalIndex == internalSize);
1137 
1138     return externalNameToAttributeIndex;
1139 }
1140 
emitAttributeDeclaration(const TField & field,FieldAnnotationIndices & annotationIndices)1141 void GenMetalTraverser::emitAttributeDeclaration(const TField &field,
1142                                                  FieldAnnotationIndices &annotationIndices)
1143 {
1144     EmitVariableDeclarationConfig evdConfig;
1145     evdConfig.disableStructSpecifier = true;
1146     emitVariableDeclaration(VarDecl(field), evdConfig);
1147     mOut << sh::kUnassignedAttributeString;
1148 }
1149 
emitUniformBufferDeclaration(const TField & field,FieldAnnotationIndices & annotationIndices)1150 void GenMetalTraverser::emitUniformBufferDeclaration(const TField &field,
1151                                                      FieldAnnotationIndices &annotationIndices)
1152 {
1153     EmitVariableDeclarationConfig evdConfig;
1154     evdConfig.disableStructSpecifier = true;
1155     evdConfig.isUBO                  = mSymbolEnv.isUBO(field);
1156     evdConfig.isPointer              = mSymbolEnv.isPointer(field);
1157     emitVariableDeclaration(VarDecl(field), evdConfig);
1158     mOut << "[[id(" << annotationIndices.attribute << ")]]";
1159 
1160     const TType &type   = *field.type();
1161     const int arraySize = type.isArray() ? type.getArraySizeProduct() : 1;
1162 
1163     TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
1164     ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1165     const TStructure *structure    = type.getStruct();
1166     const std::string originalName = reflection->getOriginalName(structure->uniqueId().get());
1167     reflection->addUniformBufferBinding(
1168         originalName,
1169         {.bindIndex = annotationIndices.attribute, .arraySize = static_cast<size_t>(arraySize)});
1170 
1171     annotationIndices.attribute += arraySize;
1172 }
1173 
emitStructDeclaration(const TType & type)1174 void GenMetalTraverser::emitStructDeclaration(const TType &type)
1175 {
1176     ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1177     ASSERT(type.isStructSpecifier());
1178 
1179     mOut << "struct ";
1180     emitBareTypeName(type, {});
1181 
1182     mOut << "\n";
1183     emitOpenBrace();
1184 
1185     const TStructure &structure = *type.getStruct();
1186     std::map<Name, size_t> fieldToAttributeIndex;
1187     const bool hasAttributeIndices      = mPipelineStructs.vertexIn.external == &structure;
1188     const bool hasUniformBufferIndicies = mPipelineStructs.uniformBuffers.external == &structure;
1189     const bool reclaimUnusedAttributeIndices = mCompiler.getShaderVersion() < 300;
1190 
1191     if (hasAttributeIndices)
1192     {
1193         fieldToAttributeIndex =
1194             BuildExternalAttributeIndexMap(mCompiler, mPipelineStructs.vertexIn);
1195     }
1196 
1197     FieldAnnotationIndices annotationIndices;
1198 
1199     for (const TField *field : structure.fields())
1200     {
1201         emitIndentation();
1202         if (hasAttributeIndices)
1203         {
1204             const auto it = fieldToAttributeIndex.find(Name(*field));
1205             if (it == fieldToAttributeIndex.end())
1206             {
1207                 ASSERT(field->symbolType() == SymbolType::AngleInternal);
1208                 ASSERT(field->name().beginsWith("_"));
1209                 ASSERT(angle::EndsWith(field->name().data(), "_pad"));
1210                 emitFieldDeclaration(*field, structure, annotationIndices);
1211             }
1212             else
1213             {
1214                 ASSERT(field->symbolType() != SymbolType::AngleInternal ||
1215                        !field->name().beginsWith("_") ||
1216                        !angle::EndsWith(field->name().data(), "_pad"));
1217                 if (!reclaimUnusedAttributeIndices)
1218                 {
1219                     annotationIndices.attribute = it->second;
1220                 }
1221                 emitAttributeDeclaration(*field, annotationIndices);
1222             }
1223         }
1224         else if (hasUniformBufferIndicies)
1225         {
1226             emitUniformBufferDeclaration(*field, annotationIndices);
1227         }
1228         else
1229         {
1230             emitFieldDeclaration(*field, structure, annotationIndices);
1231         }
1232         mOut << ";\n";
1233     }
1234 
1235     if (!mPipelineStructs.matches(structure, true, true))
1236     {
1237         MetalLayoutOfConfig layoutConfig;
1238         layoutConfig.treatSamplersAsTextureEnv = true;
1239         Layout layout                          = MetalLayoutOf(type, layoutConfig);
1240         size_t pad = (kDefaultStructAlignmentSize - layout.sizeOf) % kDefaultStructAlignmentSize;
1241         if (pad != 0)
1242         {
1243             emitIndentation();
1244             mOut << "char ";
1245             EmitName(mOut, mIdGen.createNewName("pad"));
1246             mOut << "[" << pad << "];\n";
1247         }
1248     }
1249 
1250     emitCloseBrace();
1251 }
1252 
emitOrdinaryVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1253 void GenMetalTraverser::emitOrdinaryVariableDeclaration(
1254     const VarDecl &decl,
1255     const EmitVariableDeclarationConfig &evdConfig)
1256 {
1257     EmitTypeConfig etConfig;
1258     etConfig.evdConfig = &evdConfig;
1259 
1260     const TType &type = decl.type();
1261     emitType(type, etConfig);
1262     if (decl.symbolType() != SymbolType::Empty)
1263     {
1264         mOut << " ";
1265         emitNameOf(decl);
1266     }
1267 }
1268 
emitVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1269 void GenMetalTraverser::emitVariableDeclaration(const VarDecl &decl,
1270                                                 const EmitVariableDeclarationConfig &evdConfig)
1271 {
1272     const SymbolType symbolType = decl.symbolType();
1273     const TType &type           = decl.type();
1274     const TBasicType basicType  = type.getBasicType();
1275 
1276     switch (basicType)
1277     {
1278         case TBasicType::EbtStruct:
1279         {
1280             if (type.isStructSpecifier() && !evdConfig.disableStructSpecifier)
1281             {
1282                 ASSERT(!evdConfig.isParameter);
1283                 emitStructDeclaration(type);
1284                 if (symbolType != SymbolType::Empty)
1285                 {
1286                     mOut << " ";
1287                     emitNameOf(decl);
1288                 }
1289             }
1290             else
1291             {
1292                 emitOrdinaryVariableDeclaration(decl, evdConfig);
1293             }
1294         }
1295         break;
1296 
1297         default:
1298         {
1299             ASSERT(symbolType != SymbolType::Empty || evdConfig.isParameter);
1300             emitOrdinaryVariableDeclaration(decl, evdConfig);
1301         }
1302     }
1303 
1304     if (evdConfig.emitPostQualifier)
1305     {
1306         emitPostQualifier(evdConfig, decl, type.getQualifier());
1307     }
1308 }
1309 
visitSymbol(TIntermSymbol * symbolNode)1310 void GenMetalTraverser::visitSymbol(TIntermSymbol *symbolNode)
1311 {
1312     const TVariable &var = symbolNode->variable();
1313     const TType &type    = var.getType();
1314     ASSERT(var.symbolType() != SymbolType::Empty);
1315 
1316     if (type.getBasicType() == TBasicType::EbtVoid)
1317     {
1318         mOut << "/*";
1319         emitNameOf(var);
1320         mOut << "*/";
1321     }
1322     else
1323     {
1324         emitNameOf(var);
1325     }
1326 }
1327 
emitSingleConstant(const TConstantUnion * const constUnion)1328 void GenMetalTraverser::emitSingleConstant(const TConstantUnion *const constUnion)
1329 {
1330     switch (constUnion->getType())
1331     {
1332         case TBasicType::EbtBool:
1333         {
1334             mOut << (constUnion->getBConst() ? "true" : "false");
1335         }
1336         break;
1337 
1338         case TBasicType::EbtFloat:
1339         {
1340             float value = constUnion->getFConst();
1341             if (std::isnan(value))
1342             {
1343                 mOut << "NAN";
1344             }
1345             else if (std::isinf(value))
1346             {
1347                 if (value < 0)
1348                 {
1349                     mOut << "-";
1350                 }
1351                 mOut << "INFINITY";
1352             }
1353             else
1354             {
1355                 mOut << value << "f";
1356             }
1357         }
1358         break;
1359 
1360         case TBasicType::EbtInt:
1361         {
1362             mOut << constUnion->getIConst();
1363         }
1364         break;
1365 
1366         case TBasicType::EbtUInt:
1367         {
1368             mOut << constUnion->getUConst() << "u";
1369         }
1370         break;
1371 
1372         default:
1373         {
1374             UNIMPLEMENTED();
1375         }
1376     }
1377 }
1378 
// Emits `size` consecutive constants starting at `constUnion`, separated by ", ".
// Returns the pointer one past the last constant emitted.
const TConstantUnion *GenMetalTraverser::emitConstantUnionArray(
    const TConstantUnion *const constUnion,
    const size_t size)
{
    const TConstantUnion *cursor = constUnion;
    for (size_t emitted = 0; emitted < size; ++emitted)
    {
        if (emitted != 0)
        {
            mOut << ", ";
        }
        emitSingleConstant(cursor);
        ++cursor;
    }
    return cursor;
}
1395 
// Emits the constant value(s) of `type` starting at `constUnionBegin`.
// Structs are written as "<type>{field, field, ...}" with each field emitted recursively.
// Non-struct values with more than one component are written as "<type>(c0, c1, ...)";
// a single component is written bare.
// Returns the pointer one past the last constant consumed.
const TConstantUnion *GenMetalTraverser::emitConstantUnion(const TType &type,
                                                           const TConstantUnion *constUnionBegin)
{
    const TStructure *structure = type.getStruct();

    if (structure == nullptr)
    {
        const size_t componentCount = type.getObjectSize();
        if (componentCount <= 1)
        {
            return emitConstantUnionArray(constUnionBegin, componentCount);
        }

        emitType(type, EmitTypeConfig{nullptr});
        mOut << "(";
        const TConstantUnion *next = emitConstantUnionArray(constUnionBegin, componentCount);
        mOut << ")";
        return next;
    }

    emitType(type, EmitTypeConfig{nullptr});
    mOut << "{";

    const TConstantUnion *cursor = constUnionBegin;
    const TFieldList &fields     = structure->fields();
    for (size_t i = 0; i < fields.size(); ++i)
    {
        if (i != 0)
        {
            mOut << ", ";
        }
        cursor = emitConstantUnion(*fields[i]->type(), cursor);
    }

    mOut << "}";
    return cursor;
}
1436 
visitConstantUnion(TIntermConstantUnion * constValueNode)1437 void GenMetalTraverser::visitConstantUnion(TIntermConstantUnion *constValueNode)
1438 {
1439     emitConstantUnion(constValueNode->getType(), constValueNode->getConstantValue());
1440 }
1441 
visitSwizzle(Visit,TIntermSwizzle * swizzleNode)1442 bool GenMetalTraverser::visitSwizzle(Visit, TIntermSwizzle *swizzleNode)
1443 {
1444     groupedTraverse(*swizzleNode->getOperand());
1445     mOut << ".";
1446 
1447     {
1448 #if defined(ANGLE_ENABLE_ASSERTS)
1449         DebugSink::EscapedSink escapedOut(mOut.escape());
1450         TInfoSinkBase &out = escapedOut.get();
1451 #else
1452         TInfoSinkBase &out        = mOut;
1453 #endif
1454         swizzleNode->writeOffsetsAsXYZW(&out);
1455     }
1456 
1457     return false;
1458 }
1459 
getDirectField(const TFieldListCollection & fieldListCollection,const TConstantUnion & index)1460 const TField &GenMetalTraverser::getDirectField(const TFieldListCollection &fieldListCollection,
1461                                                 const TConstantUnion &index)
1462 {
1463     ASSERT(index.getType() == TBasicType::EbtInt);
1464 
1465     const TFieldList &fieldList = fieldListCollection.fields();
1466     const int indexVal          = index.getIConst();
1467     const TField &field         = *fieldList[indexVal];
1468 
1469     return field;
1470 }
1471 
getDirectField(const TIntermTyped & fieldsNode,TIntermTyped & indexNode)1472 const TField &GenMetalTraverser::getDirectField(const TIntermTyped &fieldsNode,
1473                                                 TIntermTyped &indexNode)
1474 {
1475     const TType &fieldsType = fieldsNode.getType();
1476 
1477     const TFieldListCollection *fieldListCollection = fieldsType.getStruct();
1478     if (fieldListCollection == nullptr)
1479     {
1480         fieldListCollection = fieldsType.getInterfaceBlock();
1481     }
1482     ASSERT(fieldListCollection);
1483 
1484     const TIntermConstantUnion *indexNode_ = indexNode.getAsConstantUnion();
1485     ASSERT(indexNode_);
1486     const TConstantUnion &index = *indexNode_->getConstantValue();
1487 
1488     return getDirectField(*fieldListCollection, index);
1489 }
1490 
visitBinary(Visit,TIntermBinary * binaryNode)1491 bool GenMetalTraverser::visitBinary(Visit, TIntermBinary *binaryNode)
1492 {
1493     const TOperator op      = binaryNode->getOp();
1494     TIntermTyped &leftNode  = *binaryNode->getLeft();
1495     TIntermTyped &rightNode = *binaryNode->getRight();
1496 
1497     switch (op)
1498     {
1499         case TOperator::EOpIndexDirectStruct:
1500         case TOperator::EOpIndexDirectInterfaceBlock:
1501         {
1502             const TField &field = getDirectField(leftNode, rightNode);
1503             if (mSymbolEnv.isPointer(field) && mSymbolEnv.isUBO(field))
1504             {
1505                 emitOpeningPointerParen();
1506             }
1507             groupedTraverse(leftNode);
1508             if (!mSymbolEnv.isPointer(field))
1509             {
1510                 emitClosingPointerParen();
1511             }
1512             mOut << ".";
1513             emitNameOf(field);
1514         }
1515         break;
1516 
1517         case TOperator::EOpIndexDirect:
1518         case TOperator::EOpIndexIndirect:
1519         {
1520             TType leftType = leftNode.getType();
1521             groupedTraverse(leftNode);
1522             mOut << "[";
1523             {
1524                 mOut << "ANGLE_int_clamp(";
1525                 groupedTraverse(rightNode);
1526                 mOut << ", 0, ";
1527                 if (leftType.isUnsizedArray())
1528                 {
1529                     groupedTraverse(leftNode);
1530                     mOut << ".size()";
1531                 }
1532                 else
1533                 {
1534                     int maxSize;
1535                     if (leftType.isArray())
1536                     {
1537                         maxSize = static_cast<int>(leftType.getOutermostArraySize()) - 1;
1538                     }
1539                     else
1540                     {
1541                         maxSize = leftType.getNominalSize() - 1;
1542                     }
1543                     mOut << maxSize;
1544                 }
1545                 mOut << ")";
1546             }
1547             mOut << "]";
1548         }
1549         break;
1550 
1551         default:
1552         {
1553             const TType &resultType = binaryNode->getType();
1554             const TType &leftType   = leftNode.getType();
1555             const TType &rightType  = rightNode.getType();
1556 
1557             if (IsSymbolicOperator(op, resultType, &leftType, &rightType))
1558             {
1559                 groupedTraverse(leftNode);
1560                 if (op != TOperator::EOpComma)
1561                 {
1562                     mOut << " ";
1563                 }
1564                 else
1565                 {
1566                     emitClosingPointerParen();
1567                 }
1568                 mOut << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << " ";
1569                 groupedTraverse(rightNode);
1570             }
1571             else
1572             {
1573                 emitClosingPointerParen();
1574                 mOut << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << "(";
1575                 leftNode.traverse(this);
1576                 mOut << ", ";
1577                 rightNode.traverse(this);
1578                 mOut << ")";
1579             }
1580         }
1581     }
1582 
1583     return false;
1584 }
1585 
IsPostfix(TOperator op)1586 static bool IsPostfix(TOperator op)
1587 {
1588     switch (op)
1589     {
1590         case TOperator::EOpPostIncrement:
1591         case TOperator::EOpPostDecrement:
1592             return true;
1593 
1594         default:
1595             return false;
1596     }
1597 }
1598 
visitUnary(Visit,TIntermUnary * unaryNode)1599 bool GenMetalTraverser::visitUnary(Visit, TIntermUnary *unaryNode)
1600 {
1601     const TOperator op      = unaryNode->getOp();
1602     const TType &resultType = unaryNode->getType();
1603 
1604     TIntermTyped &arg    = *unaryNode->getOperand();
1605     const TType &argType = arg.getType();
1606 
1607     const char *name = GetOperatorString(op, resultType, &argType, nullptr, nullptr);
1608 
1609     if (IsSymbolicOperator(op, resultType, &argType, nullptr))
1610     {
1611         const bool postfix = IsPostfix(op);
1612         if (!postfix)
1613         {
1614             mOut << name;
1615         }
1616         groupedTraverse(arg);
1617         if (postfix)
1618         {
1619             mOut << name;
1620         }
1621     }
1622     else
1623     {
1624         mOut << name << "(";
1625         arg.traverse(this);
1626         mOut << ")";
1627     }
1628 
1629     return false;
1630 }
1631 
visitTernary(Visit,TIntermTernary * conditionalNode)1632 bool GenMetalTraverser::visitTernary(Visit, TIntermTernary *conditionalNode)
1633 {
1634     groupedTraverse(*conditionalNode->getCondition());
1635     mOut << " ? ";
1636     groupedTraverse(*conditionalNode->getTrueExpression());
1637     mOut << " : ";
1638     groupedTraverse(*conditionalNode->getFalseExpression());
1639 
1640     return false;
1641 }
1642 
visitIfElse(Visit,TIntermIfElse * ifThenElseNode)1643 bool GenMetalTraverser::visitIfElse(Visit, TIntermIfElse *ifThenElseNode)
1644 {
1645     TIntermTyped &condNode = *ifThenElseNode->getCondition();
1646     TIntermBlock *thenNode = ifThenElseNode->getTrueBlock();
1647     TIntermBlock *elseNode = ifThenElseNode->getFalseBlock();
1648 
1649     emitIndentation();
1650     mOut << "if (";
1651     condNode.traverse(this);
1652     mOut << ")";
1653 
1654     if (thenNode)
1655     {
1656         mOut << "\n";
1657         thenNode->traverse(this);
1658     }
1659     else
1660     {
1661         mOut << " {}";
1662     }
1663 
1664     if (elseNode)
1665     {
1666         mOut << "\n";
1667         emitIndentation();
1668         mOut << "else\n";
1669         elseNode->traverse(this);
1670     }
1671     else
1672     {
1673         // Always emit "else" even when empty block to avoid nested if-stmt issues.
1674         mOut << " else {}";
1675     }
1676 
1677     return false;
1678 }
1679 
visitSwitch(Visit,TIntermSwitch * switchNode)1680 bool GenMetalTraverser::visitSwitch(Visit, TIntermSwitch *switchNode)
1681 {
1682     emitIndentation();
1683     mOut << "switch (";
1684     switchNode->getInit()->traverse(this);
1685     mOut << ")\n";
1686 
1687     ASSERT(!mParentIsSwitch);
1688     mParentIsSwitch = true;
1689     switchNode->getStatementList()->traverse(this);
1690     mParentIsSwitch = false;
1691 
1692     return false;
1693 }
1694 
visitCase(Visit,TIntermCase * caseNode)1695 bool GenMetalTraverser::visitCase(Visit, TIntermCase *caseNode)
1696 {
1697     emitIndentation();
1698 
1699     if (caseNode->hasCondition())
1700     {
1701         TIntermTyped *condExpr = caseNode->getCondition();
1702         mOut << "case ";
1703         condExpr->traverse(this);
1704         mOut << ":";
1705     }
1706     else
1707     {
1708         mOut << "default:\n";
1709     }
1710 
1711     return false;
1712 }
1713 
emitFunctionSignature(const TFunction & func)1714 void GenMetalTraverser::emitFunctionSignature(const TFunction &func)
1715 {
1716     const bool isMain = func.isMain();
1717 
1718     emitFunctionReturn(func);
1719 
1720     mOut << " ";
1721     emitNameOf(func);
1722     if (isMain)
1723     {
1724         mOut << "0";
1725     }
1726     mOut << "(";
1727 
1728     bool emitComma          = false;
1729     const size_t paramCount = func.getParamCount();
1730     for (size_t i = 0; i < paramCount; ++i)
1731     {
1732         if (emitComma)
1733         {
1734             mOut << ", ";
1735         }
1736         emitComma = true;
1737 
1738         const TVariable &param = *func.getParam(i);
1739         emitFunctionParameter(func, param);
1740     }
1741 
1742     if (isTraversingVertexMain)
1743     {
1744         mOut << " @@XFB-Bindings@@ ";
1745     }
1746 
1747     mOut << ")";
1748 }
1749 
emitFunctionReturn(const TFunction & func)1750 void GenMetalTraverser::emitFunctionReturn(const TFunction &func)
1751 {
1752     const bool isMain       = func.isMain();
1753     bool isVertexMain       = false;
1754     const TType &returnType = func.getReturnType();
1755     if (isMain)
1756     {
1757         const TStructure *structure = returnType.getStruct();
1758         ASSERT(structure != nullptr);
1759         if (mPipelineStructs.fragmentOut.matches(*structure))
1760         {
1761             mOut << "fragment ";
1762         }
1763         else if (mPipelineStructs.vertexOut.matches(*structure))
1764         {
1765             mOut << "vertex __VERTEX_OUT(";
1766             isVertexMain = true;
1767         }
1768         else
1769         {
1770             UNIMPLEMENTED();
1771         }
1772     }
1773     emitType(returnType, EmitTypeConfig());
1774     if (isVertexMain)
1775         mOut << ") ";
1776 }
1777 
emitFunctionParameter(const TFunction & func,const TVariable & param)1778 void GenMetalTraverser::emitFunctionParameter(const TFunction &func, const TVariable &param)
1779 {
1780     const bool isMain = func.isMain();
1781 
1782     const TType &type           = param.getType();
1783     const TStructure *structure = type.getStruct();
1784 
1785     EmitVariableDeclarationConfig evdConfig;
1786     evdConfig.isParameter       = true;
1787     evdConfig.isMainParameter   = isMain;
1788     evdConfig.emitPostQualifier = isMain;
1789     evdConfig.isUBO             = mSymbolEnv.isUBO(param);
1790     evdConfig.isPointer         = mSymbolEnv.isPointer(param);
1791     evdConfig.isReference       = mSymbolEnv.isReference(param);
1792     emitVariableDeclaration(VarDecl(param), evdConfig);
1793 
1794     if (isMain)
1795     {
1796         TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
1797         if (structure)
1798         {
1799             if (mPipelineStructs.fragmentIn.matches(*structure) ||
1800                 mPipelineStructs.vertexIn.matches(*structure))
1801             {
1802                 mOut << " [[stage_in]]";
1803             }
1804             else if (mPipelineStructs.angleUniforms.matches(*structure))
1805             {
1806                 mOut << " [[buffer(" << mDriverUniformsBindingIndex << ")]]";
1807             }
1808             else if (mPipelineStructs.uniformBuffers.matches(*structure))
1809             {
1810                 mOut << " [[buffer(" << mUBOArgumentBufferBindingIndex << ")]]";
1811                 reflection->hasUBOs = true;
1812             }
1813             else if (mPipelineStructs.userUniforms.matches(*structure))
1814             {
1815                 mOut << " [[buffer(" << mMainUniformBufferIndex << ")]]";
1816                 reflection->addUserUniformBufferBinding(param.name().data(),
1817                                                         mMainUniformBufferIndex);
1818                 mMainUniformBufferIndex += type.getArraySizeProduct();
1819             }
1820             else if (structure->name() == "metal::sampler")
1821             {
1822                 mOut << " [[sampler(" << (mMainSamplerIndex) << ")]]";
1823                 const std::string originalName =
1824                     reflection->getOriginalName(param.uniqueId().get());
1825                 reflection->addSamplerBinding(originalName, mMainSamplerIndex);
1826                 mMainSamplerIndex += type.getArraySizeProduct();
1827             }
1828         }
1829         else if (IsSampler(type.getBasicType()))
1830         {
1831             mOut << " [[texture(" << (mMainTextureIndex) << ")]]";
1832             const std::string originalName = reflection->getOriginalName(param.uniqueId().get());
1833             reflection->addTextureBinding(originalName, mMainSamplerIndex);
1834             mMainTextureIndex += type.getArraySizeProduct();
1835         }
1836         else if (Name(param) == Pipeline{Pipeline::Type::InstanceId, nullptr}.getStructInstanceName(
1837                                     Pipeline::Variant::Modified))
1838         {
1839             mOut << " [[instance_id]]";
1840         }
1841     }
1842 }
1843 
visitFunctionPrototype(TIntermFunctionPrototype * funcProtoNode)1844 void GenMetalTraverser::visitFunctionPrototype(TIntermFunctionPrototype *funcProtoNode)
1845 {
1846     const TFunction &func = *funcProtoNode->getFunction();
1847 
1848     emitIndentation();
1849     emitFunctionSignature(func);
1850 }
1851 
visitFunctionDefinition(Visit,TIntermFunctionDefinition * funcDefNode)1852 bool GenMetalTraverser::visitFunctionDefinition(Visit, TIntermFunctionDefinition *funcDefNode)
1853 {
1854     const TFunction &func = *funcDefNode->getFunction();
1855     TIntermBlock &body    = *funcDefNode->getBody();
1856     if (func.isMain())
1857     {
1858         const TType &returnType     = func.getReturnType();
1859         const TStructure *structure = returnType.getStruct();
1860         isTraversingVertexMain      = (mPipelineStructs.vertexOut.matches(*structure));
1861     }
1862     emitIndentation();
1863     emitFunctionSignature(func);
1864     mOut << "\n";
1865     body.traverse(this);
1866     if (isTraversingVertexMain)
1867     {
1868         isTraversingVertexMain = false;
1869     }
1870     return false;
1871 }
1872 
BuildFuncToName()1873 GenMetalTraverser::FuncToName GenMetalTraverser::BuildFuncToName()
1874 {
1875     FuncToName map;
1876 
1877     auto putAngle = [&](const char *nameStr) {
1878         const ImmutableString name(nameStr);
1879         ASSERT(map.find(name) == map.end());
1880         map[name] = Name(nameStr, SymbolType::AngleInternal);
1881     };
1882 
1883     putAngle("texelFetch");
1884     putAngle("texelFetchOffset");
1885     putAngle("texture");
1886     putAngle("texture1D");
1887     putAngle("texture1DLod");
1888     putAngle("texture1DProjLod");
1889     putAngle("texture2D");
1890     putAngle("texture2DLod");
1891     putAngle("texture2DProj");
1892     putAngle("texture2DRect");
1893     putAngle("texture2DProjLod");
1894     putAngle("texture2DRectProj");
1895     putAngle("texture3D");
1896     putAngle("texture3DLod");
1897     putAngle("texture3DProjLod");
1898     putAngle("textureCube");
1899     putAngle("textureCubeLod");
1900     putAngle("textureCubeProjLod");
1901     putAngle("textureGrad");
1902     putAngle("textureGradOffset");
1903     putAngle("textureLod");
1904     putAngle("textureLodOffset");
1905     putAngle("textureOffset");
1906     putAngle("textureProj");
1907     putAngle("textureProjGrad");
1908     putAngle("textureProjGradOffset");
1909     putAngle("textureProjLod");
1910     putAngle("textureProjLodOffset");
1911     putAngle("textureProjOffset");
1912     putAngle("textureSize");
1913 
1914     return map;
1915 }
1916 
visitAggregate(Visit,TIntermAggregate * aggregateNode)1917 bool GenMetalTraverser::visitAggregate(Visit, TIntermAggregate *aggregateNode)
1918 {
1919     const TIntermSequence &args = *aggregateNode->getSequence();
1920 
1921     auto emitArgList = [&](const char *open, const char *close) {
1922         mOut << open;
1923 
1924         bool emitComma = false;
1925         for (TIntermNode *arg : args)
1926         {
1927             if (emitComma)
1928             {
1929                 emitClosingPointerParen();
1930                 mOut << ", ";
1931             }
1932             emitComma = true;
1933             arg->traverse(this);
1934         }
1935 
1936         mOut << close;
1937     };
1938 
1939     const TType &retType = aggregateNode->getType();
1940 
1941     if (aggregateNode->isConstructor())
1942     {
1943         const bool isStandalone = getParentNode()->getAsBlock();
1944         if (isStandalone)
1945         {
1946             // Prevent constructor from being interpreted as a declaration by wrapping in parens.
1947             // This can happen if given something like:
1948             //      int(symbol); // <- This will be treated like `int symbol;`... don't want that.
1949             // So instead emit:
1950             //      (int(symbol));
1951             mOut << "(";
1952         }
1953 
1954         const EmitTypeConfig etConfig;
1955 
1956         if (retType.isArray())
1957         {
1958             emitType(retType, etConfig);
1959             emitArgList("{", "}");
1960         }
1961         else if (retType.getStruct())
1962         {
1963             emitType(retType, etConfig);
1964             emitArgList("{", "}");
1965         }
1966         else
1967         {
1968             emitType(retType, etConfig);
1969             emitArgList("(", ")");
1970         }
1971 
1972         if (isStandalone)
1973         {
1974             mOut << ")";
1975         }
1976 
1977         return false;
1978     }
1979     else
1980     {
1981         const TOperator op = aggregateNode->getOp();
1982         if (op == EOpAtan)
1983         {
1984             TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
1985             reflection->hasAtan                   = true;
1986         }
1987         switch (op)
1988         {
1989             case TOperator::EOpCallFunctionInAST:
1990             case TOperator::EOpCallInternalRawFunction:
1991             {
1992                 const TFunction &func = *aggregateNode->getFunction();
1993                 emitNameOf(func);
1994                 //'@' symbol in name specifices a macro substitution marker.
1995                 if (!func.name().contains("@"))
1996                 {
1997                     emitArgList("(", ")");
1998                 }
1999                 else
2000                 {
2001                     mTemporarilyDisableSemicolon =
2002                         true;  // Disable semicolon for macro substitution.
2003                 }
2004                 return false;
2005             }
2006 
2007             default:
2008             {
2009                 auto getArgType = [&](size_t index) -> const TType * {
2010                     if (index < args.size())
2011                     {
2012                         TIntermTyped *arg = args[index]->getAsTyped();
2013                         ASSERT(arg);
2014                         return &arg->getType();
2015                     }
2016                     return nullptr;
2017                 };
2018 
2019                 ASSERT(!args.empty());
2020                 const TType *argType0 = getArgType(0);
2021                 const TType *argType1 = getArgType(1);
2022                 const TType *argType2 = getArgType(2);
2023                 ASSERT(argType0);
2024 
2025                 const char *opName = GetOperatorString(op, retType, argType0, argType1, argType2);
2026 
2027                 if (IsSymbolicOperator(op, retType, argType0, argType1))
2028                 {
2029                     switch (args.size())
2030                     {
2031                         case 1:
2032                         {
2033                             TIntermNode &operandNode = *aggregateNode->getChildNode(0);
2034                             if (IsPostfix(op))
2035                             {
2036                                 mOut << opName;
2037                                 groupedTraverse(operandNode);
2038                             }
2039                             else
2040                             {
2041                                 groupedTraverse(operandNode);
2042                                 mOut << opName;
2043                             }
2044                             return false;
2045                         }
2046 
2047                         case 2:
2048                         {
2049                             TIntermNode &leftNode  = *aggregateNode->getChildNode(0);
2050                             TIntermNode &rightNode = *aggregateNode->getChildNode(1);
2051                             groupedTraverse(leftNode);
2052                             mOut << " " << opName << " ";
2053                             groupedTraverse(rightNode);
2054                             return false;
2055                         }
2056 
2057                         default:
2058                             UNREACHABLE();
2059                             return false;
2060                     }
2061                 }
2062                 else if (opName == nullptr)
2063                 {
2064                     const TFunction &func = *aggregateNode->getFunction();
2065                     auto it               = mFuncToName.find(func.name());
2066                     ASSERT(it != mFuncToName.end());
2067                     EmitName(mOut, it->second);
2068                     emitArgList("(", ")");
2069                     return false;
2070                 }
2071                 else
2072                 {
2073                     mOut << opName;
2074                     emitArgList("(", ")");
2075                     return false;
2076                 }
2077             }
2078         }
2079     }
2080 }
2081 
emitOpenBrace()2082 void GenMetalTraverser::emitOpenBrace()
2083 {
2084     ASSERT(mIndentLevel >= 0);
2085 
2086     emitIndentation();
2087     mOut << "{\n";
2088     ++mIndentLevel;
2089 }
2090 
emitCloseBrace()2091 void GenMetalTraverser::emitCloseBrace()
2092 {
2093     ASSERT(mIndentLevel >= 1);
2094 
2095     --mIndentLevel;
2096     emitIndentation();
2097     mOut << "}";
2098 }
2099 
RequiresSemicolonTerminator(TIntermNode & node)2100 static bool RequiresSemicolonTerminator(TIntermNode &node)
2101 {
2102     if (node.getAsBlock())
2103     {
2104         return false;
2105     }
2106     if (node.getAsLoopNode())
2107     {
2108         return false;
2109     }
2110     if (node.getAsSwitchNode())
2111     {
2112         return false;
2113     }
2114     if (node.getAsIfElseNode())
2115     {
2116         return false;
2117     }
2118     if (node.getAsFunctionDefinition())
2119     {
2120         return false;
2121     }
2122     if (node.getAsCaseNode())
2123     {
2124         return false;
2125     }
2126 
2127     return true;
2128 }
2129 
NewlinePad(TIntermNode & node)2130 static bool NewlinePad(TIntermNode &node)
2131 {
2132     if (node.getAsFunctionDefinition())
2133     {
2134         return true;
2135     }
2136     if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
2137     {
2138         ASSERT(declNode->getChildCount() == 1);
2139         TIntermNode &childNode = *declNode->getChildNode(0);
2140         if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
2141         {
2142             const TVariable &var = symbolNode->variable();
2143             return var.getType().isStructSpecifier();
2144         }
2145         return false;
2146     }
2147     return false;
2148 }
2149 
visitBlock(Visit,TIntermBlock * blockNode)2150 bool GenMetalTraverser::visitBlock(Visit, TIntermBlock *blockNode)
2151 {
2152     ASSERT(mIndentLevel >= -1);
2153     const bool isGlobalScope  = mIndentLevel == -1;
2154     const bool parentIsSwitch = mParentIsSwitch;
2155     mParentIsSwitch           = false;
2156 
2157     if (isGlobalScope)
2158     {
2159         ++mIndentLevel;
2160     }
2161     else
2162     {
2163         emitOpenBrace();
2164         if (parentIsSwitch)
2165         {
2166             ++mIndentLevel;
2167         }
2168     }
2169 
2170     TIntermNode *prevStmtNode = nullptr;
2171 
2172     const size_t stmtCount = blockNode->getChildCount();
2173     for (size_t i = 0; i < stmtCount; ++i)
2174     {
2175         TIntermNode &stmtNode = *blockNode->getChildNode(i);
2176 
2177         if (isGlobalScope && prevStmtNode && (NewlinePad(*prevStmtNode) || NewlinePad(stmtNode)))
2178         {
2179             mOut << "\n";
2180         }
2181         const bool isCase = stmtNode.getAsCaseNode();
2182         mIndentLevel -= isCase;
2183         emitIndentation();
2184         mIndentLevel += isCase;
2185         stmtNode.traverse(this);
2186         if (RequiresSemicolonTerminator(stmtNode) && !mTemporarilyDisableSemicolon)
2187         {
2188             mOut << ";";
2189         }
2190         mTemporarilyDisableSemicolon = false;
2191         mOut << "\n";
2192 
2193         prevStmtNode = &stmtNode;
2194     }
2195 
2196     if (isGlobalScope)
2197     {
2198         ASSERT(mIndentLevel == 0);
2199         --mIndentLevel;
2200     }
2201     else
2202     {
2203         if (parentIsSwitch)
2204         {
2205             ASSERT(mIndentLevel >= 1);
2206             --mIndentLevel;
2207         }
2208         emitCloseBrace();
2209         mParentIsSwitch = parentIsSwitch;
2210     }
2211 
2212     return false;
2213 }
2214 
visitGlobalQualifierDeclaration(Visit,TIntermGlobalQualifierDeclaration *)2215 bool GenMetalTraverser::visitGlobalQualifierDeclaration(Visit, TIntermGlobalQualifierDeclaration *)
2216 {
2217     return false;
2218 }
2219 
visitDeclaration(Visit,TIntermDeclaration * declNode)2220 bool GenMetalTraverser::visitDeclaration(Visit, TIntermDeclaration *declNode)
2221 {
2222     ASSERT(declNode->getChildCount() == 1);
2223     TIntermNode &node = *declNode->getChildNode(0);
2224 
2225     EmitVariableDeclarationConfig evdConfig;
2226 
2227     if (TIntermSymbol *symbolNode = node.getAsSymbolNode())
2228     {
2229         const TVariable &var = symbolNode->variable();
2230         emitVariableDeclaration(VarDecl(var), evdConfig);
2231     }
2232     else if (TIntermBinary *initNode = node.getAsBinaryNode())
2233     {
2234         ASSERT(initNode->getOp() == TOperator::EOpInitialize);
2235         TIntermSymbol *leftSymbolNode = initNode->getLeft()->getAsSymbolNode();
2236         TIntermTyped *valueNode       = initNode->getRight()->getAsTyped();
2237         ASSERT(leftSymbolNode && valueNode);
2238 
2239         if (getRootNode() == getParentBlock())
2240         {
2241             // DeferGlobalInitializers should have turned non-const global initializers into
2242             // deferred initializers. Note that variables marked as EvqGlobal can be treated as
2243             // EvqConst in some ANGLE code but not actually have their qualifier actually changed to
2244             // EvqConst. Thus just assume all EvqGlobal are actually EvqConst for all code run after
2245             // DeferGlobalInitializers.
2246             mOut << "constant ";
2247         }
2248 
2249         const TVariable &var = leftSymbolNode->variable();
2250         const Name varName(var);
2251 
2252         if (ExpressionContainsName(varName, *valueNode))
2253         {
2254             mRenamedSymbols[&var] = mIdGen.createNewName(varName);
2255         }
2256 
2257         emitVariableDeclaration(VarDecl(var), evdConfig);
2258         mOut << " = ";
2259         groupedTraverse(*valueNode);
2260     }
2261     else
2262     {
2263         UNREACHABLE();
2264     }
2265 
2266     return false;
2267 }
2268 
visitLoop(Visit,TIntermLoop * loopNode)2269 bool GenMetalTraverser::visitLoop(Visit, TIntermLoop *loopNode)
2270 {
2271     const TLoopType loopType = loopNode->getType();
2272 
2273     switch (loopType)
2274     {
2275         case TLoopType::ELoopFor:
2276             return visitForLoop(loopNode);
2277         case TLoopType::ELoopWhile:
2278             return visitWhileLoop(loopNode);
2279         case TLoopType::ELoopDoWhile:
2280             return visitDoWhileLoop(loopNode);
2281     }
2282 }
2283 
visitForLoop(TIntermLoop * loopNode)2284 bool GenMetalTraverser::visitForLoop(TIntermLoop *loopNode)
2285 {
2286     ASSERT(loopNode->getType() == TLoopType::ELoopFor);
2287 
2288     TIntermNode *initNode  = loopNode->getInit();
2289     TIntermTyped *condNode = loopNode->getCondition();
2290     TIntermTyped *exprNode = loopNode->getExpression();
2291     TIntermBlock *bodyNode = loopNode->getBody();
2292     ASSERT(bodyNode);
2293 
2294     mOut << "for (";
2295 
2296     if (initNode)
2297     {
2298         initNode->traverse(this);
2299     }
2300     else
2301     {
2302         mOut << " ";
2303     }
2304 
2305     mOut << "; ";
2306 
2307     if (condNode)
2308     {
2309         condNode->traverse(this);
2310     }
2311 
2312     mOut << "; ";
2313 
2314     if (exprNode)
2315     {
2316         exprNode->traverse(this);
2317     }
2318 
2319     mOut << ")\n";
2320 
2321     bodyNode->traverse(this);
2322 
2323     return false;
2324 }
2325 
visitWhileLoop(TIntermLoop * loopNode)2326 bool GenMetalTraverser::visitWhileLoop(TIntermLoop *loopNode)
2327 {
2328     ASSERT(loopNode->getType() == TLoopType::ELoopWhile);
2329 
2330     TIntermNode *initNode  = loopNode->getInit();
2331     TIntermTyped *condNode = loopNode->getCondition();
2332     TIntermTyped *exprNode = loopNode->getExpression();
2333     TIntermBlock *bodyNode = loopNode->getBody();
2334     ASSERT(condNode && bodyNode);
2335     ASSERT(!initNode && !exprNode);
2336 
2337     emitIndentation();
2338     mOut << "while (";
2339     condNode->traverse(this);
2340     mOut << ")\n";
2341     bodyNode->traverse(this);
2342 
2343     return false;
2344 }
2345 
visitDoWhileLoop(TIntermLoop * loopNode)2346 bool GenMetalTraverser::visitDoWhileLoop(TIntermLoop *loopNode)
2347 {
2348     ASSERT(loopNode->getType() == TLoopType::ELoopDoWhile);
2349 
2350     TIntermNode *initNode  = loopNode->getInit();
2351     TIntermTyped *condNode = loopNode->getCondition();
2352     TIntermTyped *exprNode = loopNode->getExpression();
2353     TIntermBlock *bodyNode = loopNode->getBody();
2354     ASSERT(condNode && bodyNode);
2355     ASSERT(!initNode && !exprNode);
2356 
2357     emitIndentation();
2358     mOut << "do\n";
2359     bodyNode->traverse(this);
2360     mOut << "\n";
2361     emitIndentation();
2362     mOut << "while (";
2363     condNode->traverse(this);
2364     mOut << ");";
2365 
2366     return false;
2367 }
2368 
visitBranch(Visit,TIntermBranch * branchNode)2369 bool GenMetalTraverser::visitBranch(Visit, TIntermBranch *branchNode)
2370 {
2371     const TOperator flowOp = branchNode->getFlowOp();
2372     TIntermTyped *exprNode = branchNode->getExpression();
2373 
2374     emitIndentation();
2375 
2376     switch (flowOp)
2377     {
2378         case TOperator::EOpKill:
2379         {
2380             ASSERT(exprNode == nullptr);
2381             mOut << "metal::discard_fragment()";
2382         }
2383         break;
2384 
2385         case TOperator::EOpReturn:
2386         {
2387             if (isTraversingVertexMain)
2388             {
2389                 mOut << "#if TRANSFORM_FEEDBACK_ENABLED\n";
2390                 emitIndentation();
2391                 mOut << "return;\n";
2392                 emitIndentation();
2393                 mOut << "#else\n";
2394                 emitIndentation();
2395             }
2396             mOut << "return";
2397             if (exprNode)
2398             {
2399                 mOut << " ";
2400                 exprNode->traverse(this);
2401                 mOut << ";";
2402             }
2403             if (isTraversingVertexMain)
2404             {
2405                 mOut << "\n";
2406                 emitIndentation();
2407                 mOut << "#endif\n";
2408                 mTemporarilyDisableSemicolon = true;
2409             }
2410         }
2411         break;
2412 
2413         case TOperator::EOpBreak:
2414         {
2415             ASSERT(exprNode == nullptr);
2416             mOut << "break";
2417         }
2418         break;
2419 
2420         case TOperator::EOpContinue:
2421         {
2422             ASSERT(exprNode == nullptr);
2423             mOut << "continue";
2424         }
2425         break;
2426 
2427         default:
2428         {
2429             UNREACHABLE();
2430         }
2431     }
2432 
2433     return false;
2434 }
2435 
2436 static size_t emitMetalCallCount = 0;
2437 
EmitMetal(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const PipelineStructs & pipelineStructs,SymbolEnv & symbolEnv,const ProgramPreludeConfig & ppc,TSymbolTable * symbolTable)2438 bool sh::EmitMetal(TCompiler &compiler,
2439                    TIntermBlock &root,
2440                    IdGen &idGen,
2441                    const PipelineStructs &pipelineStructs,
2442                    SymbolEnv &symbolEnv,
2443                    const ProgramPreludeConfig &ppc,
2444                    TSymbolTable *symbolTable)
2445 {
2446     TInfoSinkBase &out = compiler.getInfoSink().obj;
2447 
2448     {
2449         ++emitMetalCallCount;
2450         std::string filenameProto = angle::GetEnvironmentVar("GMD_FIXED_EMIT");
2451         if (!filenameProto.empty())
2452         {
2453             if (filenameProto != "/dev/null")
2454             {
2455                 auto tryOpen = [&](char const *ext) {
2456                     auto filename = filenameProto;
2457                     filename += std::to_string(emitMetalCallCount);
2458                     filename += ".";
2459                     filename += ext;
2460                     return fopen(filename.c_str(), "rb");
2461                 };
2462                 FILE *file = tryOpen("metal");
2463                 if (!file)
2464                 {
2465                     file = tryOpen("cpp");
2466                 }
2467                 ASSERT(file);
2468 
2469                 fseek(file, 0, SEEK_END);
2470                 size_t fileSize = ftell(file);
2471                 fseek(file, 0, SEEK_SET);
2472 
2473                 std::vector<char> buff;
2474                 buff.resize(fileSize + 1);
2475                 fread(buff.data(), fileSize, 1, file);
2476                 buff.back() = '\0';
2477 
2478                 fclose(file);
2479 
2480                 out << buff.data();
2481             }
2482 
2483             return true;
2484         }
2485     }
2486 
2487     out << "\n\n";
2488 
2489     if (!EmitProgramPrelude(root, out, ppc))
2490     {
2491         return false;
2492     }
2493 
2494     {
2495 #if defined(ANGLE_ENABLE_ASSERTS)
2496         DebugSink outWrapper(out, angle::GetBoolEnvironmentVar("GMD_STDOUT"));
2497         outWrapper.watch(angle::GetEnvironmentVar("GMD_WATCH_STRING"));
2498 #else
2499         TInfoSinkBase &outWrapper = out;
2500 #endif
2501         GenMetalTraverser gen(compiler, outWrapper, idGen, pipelineStructs, symbolEnv, symbolTable);
2502         root.traverse(&gen);
2503     }
2504 
2505     out << "\n";
2506 
2507     return true;
2508 }
2509