• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <cctype>
8 #include <map>
9 
10 #include "common/system_utils.h"
11 #include "compiler/translator/BaseTypes.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/OutputTree.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/msl/AstHelpers.h"
16 #include "compiler/translator/msl/DebugSink.h"
17 #include "compiler/translator/msl/EmitMetal.h"
18 #include "compiler/translator/msl/Layout.h"
19 #include "compiler/translator/msl/Name.h"
20 #include "compiler/translator/msl/ProgramPrelude.h"
21 #include "compiler/translator/msl/RewritePipelines.h"
22 #include "compiler/translator/msl/TranslatorMSL.h"
23 #include "compiler/translator/tree_util/IntermTraverse.h"
24 
25 using namespace sh;
26 
27 ////////////////////////////////////////////////////////////////////////////////
28 
29 #if defined(ANGLE_ENABLE_ASSERTS)
30 using Sink = DebugSink;
31 #else
32 using Sink = TInfoSinkBase;
33 #endif
34 
35 ////////////////////////////////////////////////////////////////////////////////
36 
37 namespace
38 {
39 
40 struct VarDecl
41 {
VarDecl__anonc2fa8c430111::VarDecl42     explicit VarDecl(const TVariable &var) : mVariable(&var), mIsField(false) {}
VarDecl__anonc2fa8c430111::VarDecl43     explicit VarDecl(const TField &field) : mField(&field), mIsField(true) {}
44 
variable__anonc2fa8c430111::VarDecl45     ANGLE_INLINE const TVariable &variable() const
46     {
47         ASSERT(isVariable());
48         return *mVariable;
49     }
50 
field__anonc2fa8c430111::VarDecl51     ANGLE_INLINE const TField &field() const
52     {
53         ASSERT(isField());
54         return *mField;
55     }
56 
isVariable__anonc2fa8c430111::VarDecl57     ANGLE_INLINE bool isVariable() const { return !mIsField; }
58 
isField__anonc2fa8c430111::VarDecl59     ANGLE_INLINE bool isField() const { return mIsField; }
60 
type__anonc2fa8c430111::VarDecl61     const TType &type() const { return isField() ? *field().type() : variable().getType(); }
62 
symbolType__anonc2fa8c430111::VarDecl63     SymbolType symbolType() const
64     {
65         return isField() ? field().symbolType() : variable().symbolType();
66     }
67 
68   private:
69     union
70     {
71         const TVariable *mVariable;
72         const TField *mField;
73     };
74     bool mIsField;
75 };
76 
77 class GenMetalTraverser : public TIntermTraverser
78 {
79   public:
80     ~GenMetalTraverser() override;
81 
82     GenMetalTraverser(const TCompiler &compiler,
83                       Sink &out,
84                       IdGen &idGen,
85                       const PipelineStructs &pipelineStructs,
86                       SymbolEnv &symbolEnv,
87                       const ShCompileOptions &compileOptions);
88 
89     void visitSymbol(TIntermSymbol *) override;
90     void visitConstantUnion(TIntermConstantUnion *) override;
91     bool visitSwizzle(Visit, TIntermSwizzle *) override;
92     bool visitBinary(Visit, TIntermBinary *) override;
93     bool visitUnary(Visit, TIntermUnary *) override;
94     bool visitTernary(Visit, TIntermTernary *) override;
95     bool visitIfElse(Visit, TIntermIfElse *) override;
96     bool visitSwitch(Visit, TIntermSwitch *) override;
97     bool visitCase(Visit, TIntermCase *) override;
98     void visitFunctionPrototype(TIntermFunctionPrototype *) override;
99     bool visitFunctionDefinition(Visit, TIntermFunctionDefinition *) override;
100     bool visitAggregate(Visit, TIntermAggregate *) override;
101     bool visitBlock(Visit, TIntermBlock *) override;
102     bool visitGlobalQualifierDeclaration(Visit, TIntermGlobalQualifierDeclaration *) override;
103     bool visitDeclaration(Visit, TIntermDeclaration *) override;
104     bool visitLoop(Visit, TIntermLoop *) override;
105     bool visitForLoop(TIntermLoop *);
106     bool visitWhileLoop(TIntermLoop *);
107     bool visitDoWhileLoop(TIntermLoop *);
108     bool visitBranch(Visit, TIntermBranch *) override;
109 
110   private:
111     using FuncToName = std::map<ImmutableString, Name>;
112     static FuncToName BuildFuncToName();
113 
114     struct EmitVariableDeclarationConfig
115     {
116         bool isParameter                = false;
117         bool isMainParameter            = false;
118         bool emitPostQualifier          = false;
119         bool isPacked                   = false;
120         bool disableStructSpecifier     = false;
121         bool isUBO                      = false;
122         const AddressSpace *isPointer   = nullptr;
123         const AddressSpace *isReference = nullptr;
124     };
125 
126     struct EmitTypeConfig
127     {
128         const EmitVariableDeclarationConfig *evdConfig = nullptr;
129     };
130 
131     void emitIndentation();
132     void emitOpeningPointerParen();
133     void emitClosingPointerParen();
134     void emitFunctionSignature(const TFunction &func);
135     void emitFunctionReturn(const TFunction &func);
136     void emitFunctionParameter(const TFunction &func, const TVariable &param);
137 
138     void emitNameOf(const TField &object);
139     void emitNameOf(const TSymbol &object);
140     void emitNameOf(const VarDecl &object);
141 
142     void emitBareTypeName(const TType &type, const EmitTypeConfig &etConfig);
143     void emitType(const TType &type, const EmitTypeConfig &etConfig);
144     void emitPostQualifier(const EmitVariableDeclarationConfig &evdConfig,
145                            const VarDecl &decl,
146                            const TQualifier qualifier);
147 
148     void emitLoopBody(TIntermBlock *bodyNode);
149 
150     struct FieldAnnotationIndices
151     {
152         size_t attribute = 0;
153         size_t color     = 0;
154     };
155 
156     void emitFieldDeclaration(const TField &field,
157                               const TStructure &parent,
158                               FieldAnnotationIndices &annotationIndices);
159     void emitAttributeDeclaration(const TField &field, FieldAnnotationIndices &annotationIndices);
160     void emitUniformBufferDeclaration(const TField &field,
161                                       FieldAnnotationIndices &annotationIndices);
162     void emitStructDeclaration(const TType &type);
163     void emitOrdinaryVariableDeclaration(const VarDecl &decl,
164                                          const EmitVariableDeclarationConfig &evdConfig);
165     void emitVariableDeclaration(const VarDecl &decl,
166                                  const EmitVariableDeclarationConfig &evdConfig);
167 
168     void emitOpenBrace();
169     void emitCloseBrace();
170 
171     void groupedTraverse(TIntermNode &node);
172 
173     const TField &getDirectField(const TFieldListCollection &fieldsNode,
174                                  const TConstantUnion &index);
175     const TField &getDirectField(const TIntermTyped &fieldsNode, TIntermTyped &indexNode);
176 
177     const TConstantUnion *emitConstantUnionArray(const TConstantUnion *const constUnion,
178                                                  const size_t size);
179 
180     const TConstantUnion *emitConstantUnion(const TType &type, const TConstantUnion *constUnion);
181 
182     void emitSingleConstant(const TConstantUnion *const constUnion);
183 
184   private:
185     Sink &mOut;
186     const TCompiler &mCompiler;
187     const PipelineStructs &mPipelineStructs;
188     SymbolEnv &mSymbolEnv;
189     IdGen &mIdGen;
190     int mIndentLevel                  = -1;
191     int mLastIndentationPos           = -1;
192     int mOpenPointerParenCount        = 0;
193     bool mParentIsSwitch              = false;
194     bool isTraversingVertexMain       = false;
195     bool mTemporarilyDisableSemicolon = false;
196     std::unordered_map<const TSymbol *, Name> mRenamedSymbols;
197     const FuncToName mFuncToName           = BuildFuncToName();
198     size_t mMainTextureIndex               = 0;
199     size_t mMainSamplerIndex               = 0;
200     size_t mMainUniformBufferIndex         = 0;
201     size_t mDriverUniformsBindingIndex     = 0;
202     size_t mUBOArgumentBufferBindingIndex  = 0;
203     bool mRasterOrderGroupsSupported       = false;
204     bool mInjectAsmStatementIntoLoopBodies = false;
205 };
206 }  // anonymous namespace
207 
~GenMetalTraverser()208 GenMetalTraverser::~GenMetalTraverser()
209 {
210     ASSERT(mIndentLevel == -1);
211     ASSERT(!mParentIsSwitch);
212     ASSERT(mOpenPointerParenCount == 0);
213 }
214 
GenMetalTraverser(const TCompiler & compiler,Sink & out,IdGen & idGen,const PipelineStructs & pipelineStructs,SymbolEnv & symbolEnv,const ShCompileOptions & compileOptions)215 GenMetalTraverser::GenMetalTraverser(const TCompiler &compiler,
216                                      Sink &out,
217                                      IdGen &idGen,
218                                      const PipelineStructs &pipelineStructs,
219                                      SymbolEnv &symbolEnv,
220                                      const ShCompileOptions &compileOptions)
221     : TIntermTraverser(true, false, false),
222       mOut(out),
223       mCompiler(compiler),
224       mPipelineStructs(pipelineStructs),
225       mSymbolEnv(symbolEnv),
226       mIdGen(idGen),
227       mMainUniformBufferIndex(compileOptions.metal.defaultUniformsBindingIndex),
228       mDriverUniformsBindingIndex(compileOptions.metal.driverUniformsBindingIndex),
229       mUBOArgumentBufferBindingIndex(compileOptions.metal.UBOArgumentBufferBindingIndex),
230       mRasterOrderGroupsSupported(compileOptions.pls.fragmentSyncType ==
231                                   ShFragmentSynchronizationType::RasterOrderGroups_Metal),
232       mInjectAsmStatementIntoLoopBodies(compileOptions.metal.injectAsmStatementIntoLoopBodies)
233 {}
234 
emitIndentation()235 void GenMetalTraverser::emitIndentation()
236 {
237     ASSERT(mIndentLevel >= 0);
238 
239     if (mLastIndentationPos == mOut.size())
240     {
241         return;  // Line is already indented.
242     }
243 
244     for (int i = 0; i < mIndentLevel; ++i)
245     {
246         mOut << "  ";
247     }
248 
249     mLastIndentationPos = mOut.size();
250 }
251 
emitOpeningPointerParen()252 void GenMetalTraverser::emitOpeningPointerParen()
253 {
254     mOut << "(*";
255     mOpenPointerParenCount++;
256 }
257 
emitClosingPointerParen()258 void GenMetalTraverser::emitClosingPointerParen()
259 {
260     if (mOpenPointerParenCount > 0)
261     {
262         mOut << ")";
263         mOpenPointerParenCount--;
264     }
265 }
266 
GetOperatorString(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1,const TType * argType2)267 static const char *GetOperatorString(TOperator op,
268                                      const TType &resultType,
269                                      const TType *argType0,
270                                      const TType *argType1,
271                                      const TType *argType2)
272 {
273     switch (op)
274     {
275         case TOperator::EOpComma:
276             return ",";
277         case TOperator::EOpAssign:
278             return "=";
279         case TOperator::EOpInitialize:
280             return "=";
281         case TOperator::EOpAddAssign:
282             return "+=";
283         case TOperator::EOpSubAssign:
284             return "-=";
285         case TOperator::EOpMulAssign:
286             return "*=";
287         case TOperator::EOpDivAssign:
288             return "/=";
289         case TOperator::EOpIModAssign:
290             return "%=";
291         case TOperator::EOpBitShiftLeftAssign:
292             return "<<=";  // TODO: Check logical vs arithmetic shifting.
293         case TOperator::EOpBitShiftRightAssign:
294             return ">>=";  // TODO: Check logical vs arithmetic shifting.
295         case TOperator::EOpBitwiseAndAssign:
296             return "&=";
297         case TOperator::EOpBitwiseXorAssign:
298             return "^=";
299         case TOperator::EOpBitwiseOrAssign:
300             return "|=";
301         case TOperator::EOpAdd:
302             return "+";
303         case TOperator::EOpSub:
304             return "-";
305         case TOperator::EOpMul:
306             return "*";
307         case TOperator::EOpDiv:
308             return "/";
309         case TOperator::EOpIMod:
310             return "%";
311         case TOperator::EOpBitShiftLeft:
312             return "<<";  // TODO: Check logical vs arithmetic shifting.
313         case TOperator::EOpBitShiftRight:
314             return ">>";  // TODO: Check logical vs arithmetic shifting.
315         case TOperator::EOpBitwiseAnd:
316             return "&";
317         case TOperator::EOpBitwiseXor:
318             return "^";
319         case TOperator::EOpBitwiseOr:
320             return "|";
321         case TOperator::EOpLessThan:
322             return "<";
323         case TOperator::EOpGreaterThan:
324             return ">";
325         case TOperator::EOpLessThanEqual:
326             return "<=";
327         case TOperator::EOpGreaterThanEqual:
328             return ">=";
329         case TOperator::EOpLessThanComponentWise:
330             return "<";
331         case TOperator::EOpLessThanEqualComponentWise:
332             return "<=";
333         case TOperator::EOpGreaterThanEqualComponentWise:
334             return ">=";
335         case TOperator::EOpGreaterThanComponentWise:
336             return ">";
337         case TOperator::EOpLogicalOr:
338             return "||";
339         case TOperator::EOpLogicalXor:
340             return "!=/*xor*/";  // XXX: This might need to be handled differently for some obtuse
341                                  // use case.
342         case TOperator::EOpLogicalAnd:
343             return "&&";
344         case TOperator::EOpNegative:
345             return "-";
346         case TOperator::EOpPositive:
347             if (argType0->isMatrix())
348             {
349                 return "";
350             }
351             return "+";
352         case TOperator::EOpLogicalNot:
353             return "!";
354         case TOperator::EOpNotComponentWise:
355             return "!";
356         case TOperator::EOpBitwiseNot:
357             return "~";
358         case TOperator::EOpPostIncrement:
359             return "++";
360         case TOperator::EOpPostDecrement:
361             return "--";
362         case TOperator::EOpPreIncrement:
363             return "++";
364         case TOperator::EOpPreDecrement:
365             return "--";
366         case TOperator::EOpVectorTimesScalarAssign:
367             return "*=";
368         case TOperator::EOpVectorTimesMatrixAssign:
369             return "*=";
370         case TOperator::EOpMatrixTimesScalarAssign:
371             return "*=";
372         case TOperator::EOpMatrixTimesMatrixAssign:
373             return "*=";
374         case TOperator::EOpVectorTimesScalar:
375             return "*";
376         case TOperator::EOpVectorTimesMatrix:
377             return "*";
378         case TOperator::EOpMatrixTimesVector:
379             return "*";
380         case TOperator::EOpMatrixTimesScalar:
381             return "*";
382         case TOperator::EOpMatrixTimesMatrix:
383             return "*";
384         case TOperator::EOpEqualComponentWise:
385             return "==";
386         case TOperator::EOpNotEqualComponentWise:
387             return "!=";
388 
389         case TOperator::EOpEqual:
390             if ((argType0->getStruct() && argType1->getStruct()) &&
391                 (argType0->isArray() && argType1->isArray()))
392             {
393                 return "ANGLE_equalStructArray";
394             }
395 
396             if ((argType0->isVector() && argType1->isVector()) ||
397                 (argType0->getStruct() && argType1->getStruct()) ||
398                 (argType0->isArray() && argType1->isArray()) ||
399                 (argType0->isMatrix() && argType1->isMatrix()))
400 
401             {
402                 return "ANGLE_equal";
403             }
404 
405             return "==";
406 
407         case TOperator::EOpNotEqual:
408             if ((argType0->getStruct() && argType1->getStruct()) &&
409                 (argType0->isArray() && argType1->isArray()))
410             {
411                 return "ANGLE_notEqualStructArray";
412             }
413 
414             if ((argType0->isVector() && argType1->isVector()) ||
415                 (argType0->isArray() && argType1->isArray()) ||
416                 (argType0->isMatrix() && argType1->isMatrix()))
417             {
418                 return "ANGLE_notEqual";
419             }
420             else if (argType0->getStruct() && argType1->getStruct())
421             {
422                 return "ANGLE_notEqualStruct";
423             }
424             return "!=";
425 
426         case TOperator::EOpKill:
427             UNIMPLEMENTED();
428             return "kill";
429         case TOperator::EOpReturn:
430             return "return";
431         case TOperator::EOpBreak:
432             return "break";
433         case TOperator::EOpContinue:
434             return "continue";
435 
436         case TOperator::EOpRadians:
437             return "ANGLE_radians";
438         case TOperator::EOpDegrees:
439             return "ANGLE_degrees";
440         case TOperator::EOpAtan:
441             return argType1 == nullptr ? "metal::atan" : "metal::atan2";
442         case TOperator::EOpMod:
443             return "ANGLE_mod";  // differs from metal::mod
444         case TOperator::EOpRefract:
445             return argType0->isVector() ? "metal::refract" : "ANGLE_refract_scalar";
446         case TOperator::EOpDistance:
447             return argType0->isVector() ? "metal::distance" : "ANGLE_distance_scalar";
448         case TOperator::EOpLength:
449             return argType0->isVector() ? "metal::length" : "metal::abs";
450         case TOperator::EOpDot:
451             return argType0->isVector() ? "metal::dot" : "*";
452         case TOperator::EOpNormalize:
453             return argType0->isVector() ? "metal::fast::normalize" : "metal::sign";
454         case TOperator::EOpFaceforward:
455             return argType0->isVector() ? "metal::faceforward" : "ANGLE_faceforward_scalar";
456         case TOperator::EOpReflect:
457             return argType0->isVector() ? "metal::reflect" : "ANGLE_reflect_scalar";
458         case TOperator::EOpMatrixCompMult:
459             return "ANGLE_componentWiseMultiply";
460         case TOperator::EOpOuterProduct:
461             return "ANGLE_outerProduct";
462         case TOperator::EOpSign:
463             return argType0->getBasicType() == EbtFloat ? "metal::sign" : "ANGLE_sign_int";
464 
465         case TOperator::EOpAbs:
466             return "metal::abs";
467         case TOperator::EOpAll:
468             return "metal::all";
469         case TOperator::EOpAny:
470             return "metal::any";
471         case TOperator::EOpSin:
472             return "metal::sin";
473         case TOperator::EOpCos:
474             return "metal::cos";
475         case TOperator::EOpTan:
476             return "metal::tan";
477         case TOperator::EOpAsin:
478             return "metal::asin";
479         case TOperator::EOpAcos:
480             return "metal::acos";
481         case TOperator::EOpSinh:
482             return "metal::sinh";
483         case TOperator::EOpCosh:
484             return "metal::cosh";
485         case TOperator::EOpTanh:
486             return resultType.getPrecision() == TPrecision::EbpHigh ? "metal::precise::tanh"
487                                                                     : "metal::tanh";
488         case TOperator::EOpAsinh:
489             return "metal::asinh";
490         case TOperator::EOpAcosh:
491             return "metal::acosh";
492         case TOperator::EOpAtanh:
493             return "metal::atanh";
494         case TOperator::EOpFma:
495             return "metal::fma";
496         case TOperator::EOpPow:
497             return "metal::powr";  // GLSL's pow excludes negative x
498         case TOperator::EOpExp:
499             return "metal::exp";
500         case TOperator::EOpExp2:
501             return "metal::exp2";
502         case TOperator::EOpLog:
503             return "metal::log";
504         case TOperator::EOpLog2:
505             return "metal::log2";
506         case TOperator::EOpSqrt:
507             return "metal::sqrt";
508         case TOperator::EOpFloor:
509             return "metal::floor";
510         case TOperator::EOpTrunc:
511             return "metal::trunc";
512         case TOperator::EOpCeil:
513             return "metal::ceil";
514         case TOperator::EOpFract:
515             return "metal::fract";
516         case TOperator::EOpMin:
517             return "metal::min";
518         case TOperator::EOpMax:
519             return "metal::max";
520         case TOperator::EOpRound:
521             return "metal::round";
522         case TOperator::EOpRoundEven:
523             return "metal::rint";
524         case TOperator::EOpClamp:
525             return "metal::clamp";  // TODO fast vs precise namespace
526         case TOperator::EOpSaturate:
527             return "metal::saturate";  // TODO fast vs precise namespace
528         case TOperator::EOpMix:
529             if (!argType1->isScalar() && argType2 && argType2->getBasicType() == EbtBool)
530             {
531                 return "ANGLE_mix_bool";
532             }
533             return "metal::mix";
534         case TOperator::EOpStep:
535             return "metal::step";
536         case TOperator::EOpSmoothstep:
537             return "metal::smoothstep";
538         case TOperator::EOpModf:
539             return "metal::modf";
540         case TOperator::EOpIsnan:
541             return "metal::isnan";
542         case TOperator::EOpIsinf:
543             return "metal::isinf";
544         case TOperator::EOpLdexp:
545             return "metal::ldexp";
546         case TOperator::EOpFrexp:
547             return "metal::frexp";
548         case TOperator::EOpInversesqrt:
549             return "metal::rsqrt";
550         case TOperator::EOpCross:
551             return "metal::cross";
552         case TOperator::EOpDFdx:
553             return "metal::dfdx";
554         case TOperator::EOpDFdy:
555             return "metal::dfdy";
556         case TOperator::EOpFwidth:
557             return "metal::fwidth";
558         case TOperator::EOpTranspose:
559             return "metal::transpose";
560         case TOperator::EOpDeterminant:
561             return "metal::determinant";
562 
563         case TOperator::EOpInverse:
564             return "ANGLE_inverse";
565 
566         case TOperator::EOpInterpolateAtCentroid:
567             return "ANGLE_interpolateAtCentroid";
568         case TOperator::EOpInterpolateAtSample:
569             return "ANGLE_interpolateAtSample";
570         case TOperator::EOpInterpolateAtOffset:
571             return "ANGLE_interpolateAtOffset";
572         case TOperator::EOpInterpolateAtCenter:
573             return "ANGLE_interpolateAtCenter";
574 
575         case TOperator::EOpFloatBitsToInt:
576         case TOperator::EOpFloatBitsToUint:
577         case TOperator::EOpIntBitsToFloat:
578         case TOperator::EOpUintBitsToFloat:
579         {
580 #define RETURN_AS_TYPE_SCALAR()             \
581     do                                      \
582         switch (resultType.getBasicType())  \
583         {                                   \
584             case TBasicType::EbtInt:        \
585                 return "as_type<int>";      \
586             case TBasicType::EbtUInt:       \
587                 return "as_type<uint32_t>"; \
588             case TBasicType::EbtFloat:      \
589                 return "as_type<float>";    \
590             default:                        \
591                 UNIMPLEMENTED();            \
592                 return "TOperator_TODO";    \
593         }                                   \
594     while (false)
595 
596 #define RETURN_AS_TYPE(post)                     \
597     do                                           \
598         switch (resultType.getBasicType())       \
599         {                                        \
600             case TBasicType::EbtInt:             \
601                 return "as_type<int" post ">";   \
602             case TBasicType::EbtUInt:            \
603                 return "as_type<uint" post ">";  \
604             case TBasicType::EbtFloat:           \
605                 return "as_type<float" post ">"; \
606             default:                             \
607                 UNIMPLEMENTED();                 \
608                 return "TOperator_TODO";         \
609         }                                        \
610     while (false)
611 
612             if (resultType.isScalar())
613             {
614                 RETURN_AS_TYPE_SCALAR();
615             }
616             else if (resultType.isVector())
617             {
618                 switch (resultType.getNominalSize())
619                 {
620                     case 2:
621                         RETURN_AS_TYPE("2");
622                     case 3:
623                         RETURN_AS_TYPE("3");
624                     case 4:
625                         RETURN_AS_TYPE("4");
626                     default:
627                         UNREACHABLE();
628                         return nullptr;
629                 }
630             }
631             else
632             {
633                 UNIMPLEMENTED();
634                 return "TOperator_TODO";
635             }
636 
637 #undef RETURN_AS_TYPE
638 #undef RETURN_AS_TYPE_SCALAR
639         }
640 
641         case TOperator::EOpPackUnorm2x16:
642             return "metal::pack_float_to_unorm2x16";
643         case TOperator::EOpPackSnorm2x16:
644             return "metal::pack_float_to_snorm2x16";
645 
646         case TOperator::EOpPackUnorm4x8:
647             return "metal::pack_float_to_unorm4x8";
648         case TOperator::EOpPackSnorm4x8:
649             return "metal::pack_float_to_snorm4x8";
650 
651         case TOperator::EOpUnpackUnorm2x16:
652             return "metal::unpack_unorm2x16_to_float";
653         case TOperator::EOpUnpackSnorm2x16:
654             return "metal::unpack_snorm2x16_to_float";
655 
656         case TOperator::EOpUnpackUnorm4x8:
657             return "metal::unpack_unorm4x8_to_float";
658         case TOperator::EOpUnpackSnorm4x8:
659             return "metal::unpack_snorm4x8_to_float";
660 
661         case TOperator::EOpPackHalf2x16:
662             return "ANGLE_pack_half_2x16";
663         case TOperator::EOpUnpackHalf2x16:
664             return "ANGLE_unpack_half_2x16";
665 
666         case TOperator::EOpNumSamples:
667             return "metal::get_num_samples";
668         case TOperator::EOpSamplePosition:
669             return "metal::get_sample_position";
670 
671         case TOperator::EOpBitfieldExtract:
672         case TOperator::EOpBitfieldInsert:
673         case TOperator::EOpBitfieldReverse:
674         case TOperator::EOpBitCount:
675         case TOperator::EOpFindLSB:
676         case TOperator::EOpFindMSB:
677         case TOperator::EOpUaddCarry:
678         case TOperator::EOpUsubBorrow:
679         case TOperator::EOpUmulExtended:
680         case TOperator::EOpImulExtended:
681         case TOperator::EOpBarrier:
682         case TOperator::EOpMemoryBarrier:
683         case TOperator::EOpMemoryBarrierAtomicCounter:
684         case TOperator::EOpMemoryBarrierBuffer:
685         case TOperator::EOpMemoryBarrierShared:
686         case TOperator::EOpGroupMemoryBarrier:
687         case TOperator::EOpAtomicAdd:
688         case TOperator::EOpAtomicMin:
689         case TOperator::EOpAtomicMax:
690         case TOperator::EOpAtomicAnd:
691         case TOperator::EOpAtomicOr:
692         case TOperator::EOpAtomicXor:
693         case TOperator::EOpAtomicExchange:
694         case TOperator::EOpAtomicCompSwap:
695         case TOperator::EOpEmitVertex:
696         case TOperator::EOpEndPrimitive:
697         case TOperator::EOpFtransform:
698         case TOperator::EOpPackDouble2x32:
699         case TOperator::EOpUnpackDouble2x32:
700         case TOperator::EOpArrayLength:
701             UNIMPLEMENTED();
702             return "TOperator_TODO";
703 
704         case TOperator::EOpNull:
705         case TOperator::EOpConstruct:
706         case TOperator::EOpCallFunctionInAST:
707         case TOperator::EOpCallInternalRawFunction:
708         case TOperator::EOpIndexDirect:
709         case TOperator::EOpIndexIndirect:
710         case TOperator::EOpIndexDirectStruct:
711         case TOperator::EOpIndexDirectInterfaceBlock:
712             UNREACHABLE();
713             return nullptr;
714         default:
715             // Any other built-in function.
716             return nullptr;
717     }
718 }
719 
IsSymbolicOperator(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1)720 static bool IsSymbolicOperator(TOperator op,
721                                const TType &resultType,
722                                const TType *argType0,
723                                const TType *argType1)
724 {
725     const char *operatorString = GetOperatorString(op, resultType, argType0, argType1, nullptr);
726     if (operatorString == nullptr)
727     {
728         return false;
729     }
730     return !std::isalnum(operatorString[0]);
731 }
732 
AsSpecificBinaryNode(TIntermNode & node,TOperator op)733 static TIntermBinary *AsSpecificBinaryNode(TIntermNode &node, TOperator op)
734 {
735     TIntermBinary *binaryNode = node.getAsBinaryNode();
736     if (binaryNode)
737     {
738         return binaryNode->getOp() == op ? binaryNode : nullptr;
739     }
740     return nullptr;
741 }
742 
Parenthesize(TIntermNode & node)743 static bool Parenthesize(TIntermNode &node)
744 {
745     if (node.getAsSymbolNode())
746     {
747         return false;
748     }
749     if (node.getAsConstantUnion())
750     {
751         return false;
752     }
753     if (node.getAsAggregate())
754     {
755         return false;
756     }
757     if (node.getAsSwizzleNode())
758     {
759         return false;
760     }
761 
762     if (TIntermUnary *unaryNode = node.getAsUnaryNode())
763     {
764         // TODO: Use a precedence and associativity rules instead of this ad-hoc impl.
765         const TType &resultType = unaryNode->getType();
766         const TType &argType    = unaryNode->getOperand()->getType();
767         return IsSymbolicOperator(unaryNode->getOp(), resultType, &argType, nullptr);
768     }
769 
770     if (TIntermBinary *binaryNode = node.getAsBinaryNode())
771     {
772         // TODO: Use a precedence and associativity rules instead of this ad-hoc impl.
773         const TOperator op = binaryNode->getOp();
774         switch (op)
775         {
776             case TOperator::EOpIndexDirectStruct:
777             case TOperator::EOpIndexDirectInterfaceBlock:
778             case TOperator::EOpIndexDirect:
779             case TOperator::EOpIndexIndirect:
780                 return Parenthesize(*binaryNode->getLeft());
781 
782             case TOperator::EOpAssign:
783             case TOperator::EOpInitialize:
784                 return AsSpecificBinaryNode(*binaryNode->getRight(), TOperator::EOpComma);
785 
786             default:
787             {
788                 const TType &resultType = binaryNode->getType();
789                 const TType &leftType   = binaryNode->getLeft()->getType();
790                 const TType &rightType  = binaryNode->getRight()->getType();
791                 return IsSymbolicOperator(binaryNode->getOp(), resultType, &leftType, &rightType);
792             }
793         }
794     }
795 
796     return true;
797 }
798 
groupedTraverse(TIntermNode & node)799 void GenMetalTraverser::groupedTraverse(TIntermNode &node)
800 {
801     const bool emitParens = Parenthesize(node);
802 
803     if (emitParens)
804     {
805         mOut << "(";
806     }
807 
808     node.traverse(this);
809 
810     if (emitParens)
811     {
812         mOut << ")";
813     }
814 }
815 
emitPostQualifier(const EmitVariableDeclarationConfig & evdConfig,const VarDecl & decl,const TQualifier qualifier)816 void GenMetalTraverser::emitPostQualifier(const EmitVariableDeclarationConfig &evdConfig,
817                                           const VarDecl &decl,
818                                           const TQualifier qualifier)
819 {
820     bool isInvariant = false;
821     switch (qualifier)
822     {
823         case TQualifier::EvqPosition:
824             isInvariant = decl.type().isInvariant();
825             [[fallthrough]];
826         case TQualifier::EvqFragCoord:
827             mOut << " [[position]]";
828             break;
829 
830         case TQualifier::EvqClipDistance:
831             mOut << " [[clip_distance]] [" << decl.type().getOutermostArraySize() << "]";
832             break;
833 
834         case TQualifier::EvqPointSize:
835             mOut << " [[point_size]]";
836             break;
837 
838         case TQualifier::EvqVertexID:
839             if (evdConfig.isMainParameter)
840             {
841                 mOut << " [[vertex_id]]";
842             }
843             break;
844 
845         case TQualifier::EvqPointCoord:
846             if (evdConfig.isMainParameter)
847             {
848                 mOut << " [[point_coord]]";
849             }
850             break;
851 
852         case TQualifier::EvqFrontFacing:
853             if (evdConfig.isMainParameter)
854             {
855                 mOut << " [[front_facing]]";
856             }
857             break;
858 
859         case TQualifier::EvqSampleID:
860             if (evdConfig.isMainParameter)
861             {
862                 mOut << " [[sample_id]]";
863             }
864             break;
865 
866         case TQualifier::EvqSampleMaskIn:
867             if (evdConfig.isMainParameter)
868             {
869                 mOut << " [[sample_mask]]";
870             }
871             break;
872 
873         default:
874             break;
875     }
876 
877     if (isInvariant)
878     {
879         mOut << " [[invariant]]";
880 
881         TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
882         reflection->hasInvariance             = true;
883     }
884 }
885 
emitLoopBody(TIntermBlock * bodyNode)886 void GenMetalTraverser::emitLoopBody(TIntermBlock *bodyNode)
887 {
888     if (mInjectAsmStatementIntoLoopBodies)
889     {
890         emitOpenBrace();
891 
892         emitIndentation();
893         mOut << "__asm__(\"\");\n";
894     }
895 
896     bodyNode->traverse(this);
897 
898     if (mInjectAsmStatementIntoLoopBodies)
899     {
900         emitCloseBrace();
901     }
902 }
903 
EmitName(Sink & out,const Name & name)904 static void EmitName(Sink &out, const Name &name)
905 {
906 #if defined(ANGLE_ENABLE_ASSERTS)
907     DebugSink::EscapedSink escapedOut(out.escape());
908 #else
909     TInfoSinkBase &escapedOut = out;
910 #endif
911     name.emit(escapedOut);
912 }
913 
emitNameOf(const TField & object)914 void GenMetalTraverser::emitNameOf(const TField &object)
915 {
916     EmitName(mOut, Name(object));
917 }
918 
emitNameOf(const TSymbol & object)919 void GenMetalTraverser::emitNameOf(const TSymbol &object)
920 {
921     auto it = mRenamedSymbols.find(&object);
922     if (it == mRenamedSymbols.end())
923     {
924         EmitName(mOut, Name(object));
925     }
926     else
927     {
928         EmitName(mOut, it->second);
929     }
930 }
931 
emitNameOf(const VarDecl & object)932 void GenMetalTraverser::emitNameOf(const VarDecl &object)
933 {
934     if (object.isField())
935     {
936         emitNameOf(object.field());
937     }
938     else
939     {
940         emitNameOf(object.variable());
941     }
942 }
943 
emitBareTypeName(const TType & type,const EmitTypeConfig & etConfig)944 void GenMetalTraverser::emitBareTypeName(const TType &type, const EmitTypeConfig &etConfig)
945 {
946     const TBasicType basicType = type.getBasicType();
947 
948     switch (basicType)
949     {
950         case TBasicType::EbtVoid:
951         case TBasicType::EbtBool:
952         case TBasicType::EbtFloat:
953         case TBasicType::EbtInt:
954         {
955             mOut << type.getBasicString();
956             break;
957         }
958         case TBasicType::EbtUInt:
959         {
960             if (type.isScalar())
961             {
962                 mOut << "uint32_t";
963             }
964             else
965             {
966                 mOut << type.getBasicString();
967             }
968         }
969         break;
970 
971         case TBasicType::EbtStruct:
972         {
973             const TStructure &structure = *type.getStruct();
974             emitNameOf(structure);
975         }
976         break;
977 
978         case TBasicType::EbtInterfaceBlock:
979         {
980             const TInterfaceBlock &interfaceBlock = *type.getInterfaceBlock();
981             emitNameOf(interfaceBlock);
982         }
983         break;
984 
985         default:
986         {
987             if (IsSampler(basicType))
988             {
989                 if (etConfig.evdConfig && etConfig.evdConfig->isMainParameter)
990                 {
991                     EmitName(mOut, GetTextureTypeName(basicType));
992                 }
993                 else
994                 {
995                     const TStructure &env = mSymbolEnv.getTextureEnv(basicType);
996                     emitNameOf(env);
997                 }
998             }
999             else if (IsImage(basicType))
1000             {
1001                 mOut << "metal::texture2d<";
1002                 switch (type.getBasicType())
1003                 {
1004                     case EbtImage2D:
1005                         mOut << "float";
1006                         break;
1007                     case EbtIImage2D:
1008                         mOut << "int";
1009                         break;
1010                     case EbtUImage2D:
1011                         mOut << "uint";
1012                         break;
1013                     default:
1014                         UNIMPLEMENTED();
1015                         break;
1016                 }
1017                 if (type.getMemoryQualifier().readonly || type.getMemoryQualifier().writeonly)
1018                 {
1019                     UNIMPLEMENTED();
1020                 }
1021                 mOut << ", metal::access::read_write>";
1022             }
1023             else
1024             {
1025                 UNIMPLEMENTED();
1026             }
1027         }
1028     }
1029 }
1030 
emitType(const TType & type,const EmitTypeConfig & etConfig)1031 void GenMetalTraverser::emitType(const TType &type, const EmitTypeConfig &etConfig)
1032 {
1033     const bool isUBO = etConfig.evdConfig ? etConfig.evdConfig->isUBO : false;
1034     if (etConfig.evdConfig)
1035     {
1036         const auto &evdConfig = *etConfig.evdConfig;
1037         if (isUBO)
1038         {
1039             if (type.isArray())
1040             {
1041                 mOut << "ANGLE_tensor<";
1042             }
1043         }
1044         if (evdConfig.isPointer)
1045         {
1046             mOut << toString(*evdConfig.isPointer);
1047             mOut << " ";
1048         }
1049         else if (evdConfig.isReference)
1050         {
1051             mOut << toString(*evdConfig.isReference);
1052             mOut << " ";
1053         }
1054     }
1055 
1056     if (!isUBO)
1057     {
1058         if (type.isArray())
1059         {
1060             mOut << "ANGLE_tensor<";
1061         }
1062     }
1063 
1064     if (type.isInterpolant())
1065     {
1066         mOut << "metal::interpolant<";
1067     }
1068 
1069     if (type.isVector() || type.isMatrix())
1070     {
1071         mOut << "metal::";
1072     }
1073 
1074     if (etConfig.evdConfig && etConfig.evdConfig->isPacked)
1075     {
1076         mOut << "packed_";
1077     }
1078 
1079     emitBareTypeName(type, etConfig);
1080 
1081     if (type.isVector())
1082     {
1083         mOut << static_cast<uint32_t>(type.getNominalSize());
1084     }
1085     else if (type.isMatrix())
1086     {
1087         mOut << static_cast<uint32_t>(type.getCols()) << "x"
1088              << static_cast<uint32_t>(type.getRows());
1089     }
1090 
1091     if (type.isInterpolant())
1092     {
1093         mOut << ", metal::interpolation::";
1094         switch (type.getQualifier())
1095         {
1096             case EvqNoPerspectiveIn:
1097             case EvqNoPerspectiveCentroidIn:
1098             case EvqNoPerspectiveSampleIn:
1099                 mOut << "no_";
1100                 break;
1101             default:
1102                 break;
1103         }
1104         mOut << "perspective>";
1105     }
1106 
1107     if (!isUBO)
1108     {
1109         if (type.isArray())
1110         {
1111             for (auto size : type.getArraySizes())
1112             {
1113                 mOut << ", " << size;
1114             }
1115             mOut << ">";
1116         }
1117     }
1118 
1119     if (etConfig.evdConfig)
1120     {
1121         const auto &evdConfig = *etConfig.evdConfig;
1122         if (evdConfig.isPointer)
1123         {
1124             mOut << " *";
1125         }
1126         else if (evdConfig.isReference)
1127         {
1128             mOut << " &";
1129         }
1130         if (isUBO)
1131         {
1132             if (type.isArray())
1133             {
1134                 for (auto size : type.getArraySizes())
1135                 {
1136                     mOut << ", " << size;
1137                 }
1138                 mOut << ">";
1139             }
1140         }
1141     }
1142 }
1143 
emitFieldDeclaration(const TField & field,const TStructure & parent,FieldAnnotationIndices & annotationIndices)1144 void GenMetalTraverser::emitFieldDeclaration(const TField &field,
1145                                              const TStructure &parent,
1146                                              FieldAnnotationIndices &annotationIndices)
1147 {
1148     const TType &type      = *field.type();
1149     const TBasicType basic = type.getBasicType();
1150 
1151     EmitVariableDeclarationConfig evdConfig;
1152     evdConfig.emitPostQualifier      = true;
1153     evdConfig.disableStructSpecifier = true;
1154     evdConfig.isPacked               = mSymbolEnv.isPacked(field);
1155     evdConfig.isUBO                  = mSymbolEnv.isUBO(field);
1156     evdConfig.isPointer              = mSymbolEnv.isPointer(field);
1157     evdConfig.isReference            = mSymbolEnv.isReference(field);
1158     emitVariableDeclaration(VarDecl(field), evdConfig);
1159 
1160     const TQualifier qual = type.getQualifier();
1161     switch (qual)
1162     {
1163         case TQualifier::EvqFlatIn:
1164             if (mPipelineStructs.fragmentIn.external == &parent)
1165             {
1166                 mOut << " [[flat]]";
1167                 TranslatorMetalReflection *reflection =
1168                     mtl::getTranslatorMetalReflection(&mCompiler);
1169                 reflection->hasFlatInput = true;
1170             }
1171             break;
1172 
1173         case TQualifier::EvqNoPerspectiveIn:
1174             if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1175             {
1176                 mOut << " [[center_no_perspective]]";
1177             }
1178             break;
1179 
1180         case TQualifier::EvqCentroidIn:
1181             if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1182             {
1183                 mOut << " [[centroid_perspective]]";
1184             }
1185             break;
1186 
1187         case TQualifier::EvqSampleIn:
1188             if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1189             {
1190                 mOut << " [[sample_perspective]]";
1191             }
1192             break;
1193 
1194         case TQualifier::EvqNoPerspectiveCentroidIn:
1195             if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1196             {
1197                 mOut << " [[centroid_no_perspective]]";
1198             }
1199             break;
1200 
1201         case TQualifier::EvqNoPerspectiveSampleIn:
1202             if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1203             {
1204                 mOut << " [[sample_no_perspective]]";
1205             }
1206             break;
1207 
1208         case TQualifier::EvqFragColor:
1209             mOut << " [[color(0)]]";
1210             break;
1211 
1212         case TQualifier::EvqSecondaryFragColorEXT:
1213         case TQualifier::EvqSecondaryFragDataEXT:
1214             mOut << " [[color(0), index(1)]]";
1215             break;
1216 
1217         case TQualifier::EvqFragmentOut:
1218         case TQualifier::EvqFragmentInOut:
1219         case TQualifier::EvqFragData:
1220             if (mPipelineStructs.fragmentOut.external == &parent ||
1221                 mPipelineStructs.fragmentOut.externalExtra == &parent)
1222             {
1223                 if ((type.isVector() &&
1224                      (basic == TBasicType::EbtInt || basic == TBasicType::EbtUInt ||
1225                       basic == TBasicType::EbtFloat)) ||
1226                     qual == EvqFragData)
1227                 {
1228                     // The OpenGL ES 3.0 spec says locations must be specified
1229                     // unless there is only a single output, in which case the
1230                     // location is 0. So, when we get to this point the shader
1231                     // will have been rejected if locations are not specified
1232                     // and there is more than one output.
1233                     const TLayoutQualifier &layoutQualifier = type.getLayoutQualifier();
1234                     if (layoutQualifier.locationsSpecified)
1235                     {
1236                         mOut << " [[color(" << layoutQualifier.location << ")";
1237                         ASSERT(layoutQualifier.index >= -1 && layoutQualifier.index <= 1);
1238                         if (layoutQualifier.index == 1)
1239                         {
1240                             mOut << ", index(1)";
1241                         }
1242                     }
1243                     else if (qual == EvqFragData)
1244                     {
1245                         mOut << " [[color(" << annotationIndices.color++ << ")";
1246                     }
1247                     else
1248                     {
1249                         // Either the only output or EXT_blend_func_extended is used;
1250                         // actual assignment will happen in UpdateFragmentShaderOutputs.
1251                         mOut << " [[" << sh::kUnassignedFragmentOutputString;
1252                     }
1253                     if (mRasterOrderGroupsSupported && qual == TQualifier::EvqFragmentInOut)
1254                     {
1255                         // Put fragment inouts in their own raster order group for better
1256                         // parallelism.
1257                         // NOTE: this is not required for the reads to be ordered and coherent.
1258                         // TODO(anglebug.com/7279): Consider making raster order groups a PLS layout
1259                         // qualifier?
1260                         mOut << ", raster_order_group(0)";
1261                     }
1262                     mOut << "]]";
1263                 }
1264             }
1265             break;
1266 
1267         case TQualifier::EvqFragDepth:
1268             mOut << " [[depth(";
1269             switch (type.getLayoutQualifier().depth)
1270             {
1271                 case EdGreater:
1272                     mOut << "greater";
1273                     break;
1274                 case EdLess:
1275                     mOut << "less";
1276                     break;
1277                 default:
1278                     mOut << "any";
1279                     break;
1280             }
1281             mOut << "), function_constant(" << sh::mtl::kDepthWriteEnabledConstName << ")]]";
1282             break;
1283 
1284         case TQualifier::EvqSampleMask:
1285             if (field.symbolType() == SymbolType::AngleInternal)
1286             {
1287                 mOut << " [[sample_mask, function_constant("
1288                      << sh::mtl::kSampleMaskWriteEnabledConstName << ")]]";
1289             }
1290             break;
1291 
1292         default:
1293             break;
1294     }
1295 
1296     if (IsImage(type.getBasicType()))
1297     {
1298         if (type.getArraySizeProduct() != 1)
1299         {
1300             UNIMPLEMENTED();
1301         }
1302         mOut << " [[";
1303         if (type.getLayoutQualifier().rasterOrdered)
1304         {
1305             mOut << "raster_order_group(0), ";
1306         }
1307         mOut << "texture(" << mMainTextureIndex << ")]]";
1308         TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
1309         reflection->addRWTextureBinding(type.getLayoutQualifier().binding,
1310                                         static_cast<int>(mMainTextureIndex));
1311         ++mMainTextureIndex;
1312     }
1313 }
1314 
BuildExternalAttributeIndexMap(const TCompiler & compiler,const PipelineScoped<TStructure> & structure)1315 static std::map<Name, size_t> BuildExternalAttributeIndexMap(
1316     const TCompiler &compiler,
1317     const PipelineScoped<TStructure> &structure)
1318 {
1319     ASSERT(structure.isTotallyFull());
1320 
1321     const auto &shaderVars     = compiler.getAttributes();
1322     const size_t shaderVarSize = shaderVars.size();
1323     size_t shaderVarIndex      = 0;
1324 
1325     const auto &externalFields = structure.external->fields();
1326     const size_t externalSize  = externalFields.size();
1327     size_t externalIndex       = 0;
1328 
1329     const auto &internalFields = structure.internal->fields();
1330     const size_t internalSize  = internalFields.size();
1331     size_t internalIndex       = 0;
1332 
1333     // Internal fields are never split. External fields are sometimes split.
1334     ASSERT(externalSize >= internalSize);
1335 
1336     // Structures do not contain any inactive fields.
1337     ASSERT(shaderVarSize >= internalSize);
1338 
1339     std::map<Name, size_t> externalNameToAttributeIndex;
1340     size_t attributeIndex = 0;
1341 
1342     while (internalIndex < internalSize)
1343     {
1344         const TField &internalField = *internalFields[internalIndex];
1345         const Name internalName     = Name(internalField);
1346         const TType &internalType   = *internalField.type();
1347         while (internalName.rawName() != shaderVars[shaderVarIndex].name &&
1348                internalName.rawName() != shaderVars[shaderVarIndex].mappedName)
1349         {
1350             // This case represents an inactive field.
1351 
1352             ++shaderVarIndex;
1353             ASSERT(shaderVarIndex < shaderVarSize);
1354 
1355             ++attributeIndex;  // TODO: Might need to increment more if shader var type is a matrix.
1356         }
1357 
1358         const size_t cols =
1359             (internalType.isMatrix() && !externalFields[externalIndex]->type()->isMatrix())
1360                 ? internalType.getCols()
1361                 : 1;
1362 
1363         for (size_t c = 0; c < cols; ++c)
1364         {
1365             const TField &externalField = *externalFields[externalIndex];
1366             const Name externalName     = Name(externalField);
1367 
1368             externalNameToAttributeIndex[externalName] = attributeIndex;
1369 
1370             ++externalIndex;
1371             ++attributeIndex;
1372         }
1373 
1374         ++shaderVarIndex;
1375         ++internalIndex;
1376     }
1377 
1378     ASSERT(shaderVarIndex <= shaderVarSize);
1379     ASSERT(externalIndex <= externalSize);  // less than if padding was introduced
1380     ASSERT(internalIndex == internalSize);
1381 
1382     return externalNameToAttributeIndex;
1383 }
1384 
emitAttributeDeclaration(const TField & field,FieldAnnotationIndices & annotationIndices)1385 void GenMetalTraverser::emitAttributeDeclaration(const TField &field,
1386                                                  FieldAnnotationIndices &annotationIndices)
1387 {
1388     EmitVariableDeclarationConfig evdConfig;
1389     evdConfig.disableStructSpecifier = true;
1390     emitVariableDeclaration(VarDecl(field), evdConfig);
1391     mOut << sh::kUnassignedAttributeString;
1392 }
1393 
emitUniformBufferDeclaration(const TField & field,FieldAnnotationIndices & annotationIndices)1394 void GenMetalTraverser::emitUniformBufferDeclaration(const TField &field,
1395                                                      FieldAnnotationIndices &annotationIndices)
1396 {
1397     EmitVariableDeclarationConfig evdConfig;
1398     evdConfig.disableStructSpecifier = true;
1399     evdConfig.isUBO                  = mSymbolEnv.isUBO(field);
1400     evdConfig.isPointer              = mSymbolEnv.isPointer(field);
1401     emitVariableDeclaration(VarDecl(field), evdConfig);
1402     mOut << "[[id(" << annotationIndices.attribute << ")]]";
1403 
1404     const TType &type   = *field.type();
1405     const int arraySize = type.isArray() ? type.getArraySizeProduct() : 1;
1406 
1407     TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
1408     ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1409     const TStructure *structure    = type.getStruct();
1410     const std::string originalName = reflection->getOriginalName(structure->uniqueId().get());
1411     reflection->addUniformBufferBinding(
1412         originalName,
1413         {.bindIndex = annotationIndices.attribute, .arraySize = static_cast<size_t>(arraySize)});
1414 
1415     annotationIndices.attribute += arraySize;
1416 }
1417 
emitStructDeclaration(const TType & type)1418 void GenMetalTraverser::emitStructDeclaration(const TType &type)
1419 {
1420     ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1421     ASSERT(type.isStructSpecifier());
1422 
1423     mOut << "struct ";
1424     emitBareTypeName(type, {});
1425 
1426     mOut << "\n";
1427     emitOpenBrace();
1428 
1429     const TStructure &structure = *type.getStruct();
1430     std::map<Name, size_t> fieldToAttributeIndex;
1431     const bool hasAttributeIndices      = mPipelineStructs.vertexIn.external == &structure;
1432     const bool hasUniformBufferIndicies = mPipelineStructs.uniformBuffers.external == &structure;
1433     const bool reclaimUnusedAttributeIndices = mCompiler.getShaderVersion() < 300;
1434 
1435     if (hasAttributeIndices)
1436     {
1437         // When attribute aliasing is supported, external attribute struct is filled post-link.
1438         if (mCompiler.supportsAttributeAliasing())
1439         {
1440             mtl::getTranslatorMetalReflection(&mCompiler)->hasAttributeAliasing = true;
1441             mOut << "@@Attrib-Bindings@@\n";
1442             emitCloseBrace();
1443             return;
1444         }
1445 
1446         fieldToAttributeIndex =
1447             BuildExternalAttributeIndexMap(mCompiler, mPipelineStructs.vertexIn);
1448     }
1449 
1450     FieldAnnotationIndices annotationIndices;
1451 
1452     for (const TField *field : structure.fields())
1453     {
1454         emitIndentation();
1455         if (hasAttributeIndices)
1456         {
1457             const auto it = fieldToAttributeIndex.find(Name(*field));
1458             if (it == fieldToAttributeIndex.end())
1459             {
1460                 ASSERT(field->symbolType() == SymbolType::AngleInternal);
1461                 ASSERT(field->name().beginsWith("_"));
1462                 ASSERT(angle::EndsWith(field->name().data(), "_pad"));
1463                 emitFieldDeclaration(*field, structure, annotationIndices);
1464             }
1465             else
1466             {
1467                 ASSERT(field->symbolType() != SymbolType::AngleInternal ||
1468                        !field->name().beginsWith("_") ||
1469                        !angle::EndsWith(field->name().data(), "_pad"));
1470                 if (!reclaimUnusedAttributeIndices)
1471                 {
1472                     annotationIndices.attribute = it->second;
1473                 }
1474                 emitAttributeDeclaration(*field, annotationIndices);
1475             }
1476         }
1477         else if (hasUniformBufferIndicies)
1478         {
1479             emitUniformBufferDeclaration(*field, annotationIndices);
1480         }
1481         else
1482         {
1483             emitFieldDeclaration(*field, structure, annotationIndices);
1484         }
1485         mOut << ";\n";
1486     }
1487 
1488     emitCloseBrace();
1489 }
1490 
emitOrdinaryVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1491 void GenMetalTraverser::emitOrdinaryVariableDeclaration(
1492     const VarDecl &decl,
1493     const EmitVariableDeclarationConfig &evdConfig)
1494 {
1495     EmitTypeConfig etConfig;
1496     etConfig.evdConfig = &evdConfig;
1497 
1498     const TType &type = decl.type();
1499     if (type.getQualifier() == TQualifier::EvqClipDistance)
1500     {
1501         // Clip distance output uses float[n] type instead of metal::array.
1502         // The element count is emitted after the post qualifier.
1503         ASSERT(type.getBasicType() == TBasicType::EbtFloat);
1504         mOut << "float";
1505     }
1506     else if (type.getQualifier() == TQualifier::EvqSampleID && evdConfig.isMainParameter)
1507     {
1508         // Metal's [[sample_id]] must be unsigned
1509         ASSERT(type.getBasicType() == TBasicType::EbtInt);
1510         mOut << "uint32_t";
1511     }
1512     else
1513     {
1514         emitType(type, etConfig);
1515     }
1516     if (decl.symbolType() != SymbolType::Empty)
1517     {
1518         mOut << " ";
1519         emitNameOf(decl);
1520     }
1521 }
1522 
emitVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1523 void GenMetalTraverser::emitVariableDeclaration(const VarDecl &decl,
1524                                                 const EmitVariableDeclarationConfig &evdConfig)
1525 {
1526     const SymbolType symbolType = decl.symbolType();
1527     const TType &type           = decl.type();
1528     const TBasicType basicType  = type.getBasicType();
1529 
1530     switch (basicType)
1531     {
1532         case TBasicType::EbtStruct:
1533         {
1534             if (type.isStructSpecifier() && !evdConfig.disableStructSpecifier)
1535             {
1536                 // It's invalid to declare a struct inside a function argument. When emitting a
1537                 // function parameter, the callsite should set evdConfig.disableStructSpecifier.
1538                 ASSERT(!evdConfig.isParameter);
1539                 emitStructDeclaration(type);
1540                 if (symbolType != SymbolType::Empty)
1541                 {
1542                     mOut << " ";
1543                     emitNameOf(decl);
1544                 }
1545             }
1546             else
1547             {
1548                 emitOrdinaryVariableDeclaration(decl, evdConfig);
1549             }
1550         }
1551         break;
1552 
1553         default:
1554         {
1555             ASSERT(symbolType != SymbolType::Empty || evdConfig.isParameter);
1556             emitOrdinaryVariableDeclaration(decl, evdConfig);
1557         }
1558     }
1559 
1560     if (evdConfig.emitPostQualifier)
1561     {
1562         emitPostQualifier(evdConfig, decl, type.getQualifier());
1563     }
1564 }
1565 
visitSymbol(TIntermSymbol * symbolNode)1566 void GenMetalTraverser::visitSymbol(TIntermSymbol *symbolNode)
1567 {
1568     const TVariable &var = symbolNode->variable();
1569     const TType &type    = var.getType();
1570     ASSERT(var.symbolType() != SymbolType::Empty);
1571 
1572     if (type.getBasicType() == TBasicType::EbtVoid)
1573     {
1574         mOut << "/*";
1575         emitNameOf(var);
1576         mOut << "*/";
1577     }
1578     else
1579     {
1580         emitNameOf(var);
1581     }
1582 }
1583 
emitSingleConstant(const TConstantUnion * const constUnion)1584 void GenMetalTraverser::emitSingleConstant(const TConstantUnion *const constUnion)
1585 {
1586     switch (constUnion->getType())
1587     {
1588         case TBasicType::EbtBool:
1589         {
1590             mOut << (constUnion->getBConst() ? "true" : "false");
1591         }
1592         break;
1593 
1594         case TBasicType::EbtFloat:
1595         {
1596             float value = constUnion->getFConst();
1597             if (std::isnan(value))
1598             {
1599                 mOut << "NAN";
1600             }
1601             else if (std::isinf(value))
1602             {
1603                 if (value < 0)
1604                 {
1605                     mOut << "-";
1606                 }
1607                 mOut << "INFINITY";
1608             }
1609             else
1610             {
1611                 mOut << value << "f";
1612             }
1613         }
1614         break;
1615 
1616         case TBasicType::EbtInt:
1617         {
1618             mOut << constUnion->getIConst();
1619         }
1620         break;
1621 
1622         case TBasicType::EbtUInt:
1623         {
1624             mOut << constUnion->getUConst() << "u";
1625         }
1626         break;
1627 
1628         default:
1629         {
1630             UNIMPLEMENTED();
1631         }
1632     }
1633 }
1634 
// Emits |size| consecutive scalar constants starting at |constUnion|, separated by ", ".
// Returns a pointer just past the last constant consumed.
const TConstantUnion *GenMetalTraverser::emitConstantUnionArray(
    const TConstantUnion *const constUnion,
    const size_t size)
{
    const TConstantUnion *const end = constUnion + size;
    for (const TConstantUnion *cursor = constUnion; cursor != end; ++cursor)
    {
        if (cursor != constUnion)
        {
            mOut << ", ";
        }
        emitSingleConstant(cursor);
    }
    return end;
}
1651 
// Emits a constant of |type| whose scalar components start at |constUnionBegin| and returns a
// pointer just past the consumed components.
//  - Structs: `Type{field0, field1, ...}`, each field emitted recursively.
//  - Multi-component values: `Type(c0, c1, ...)`.
//  - Single scalars: the bare literal.
const TConstantUnion *GenMetalTraverser::emitConstantUnion(const TType &type,
                                                           const TConstantUnion *constUnionBegin)
{
    const TStructure *const structure = type.getStruct();

    if (structure == nullptr)
    {
        const size_t componentCount = type.getObjectSize();
        const bool needsConstructor = componentCount > 1;
        if (needsConstructor)
        {
            emitType(type, EmitTypeConfig{nullptr});
            mOut << "(";
        }
        const TConstantUnion *const next = emitConstantUnionArray(constUnionBegin, componentCount);
        if (needsConstructor)
        {
            mOut << ")";
        }
        return next;
    }

    emitType(type, EmitTypeConfig{nullptr});
    mOut << "{";
    const TConstantUnion *cursor = constUnionBegin;
    bool isFirstField            = true;
    for (const TField *field : structure->fields())
    {
        if (!isFirstField)
        {
            mOut << ", ";
        }
        cursor       = emitConstantUnion(*field->type(), cursor);
        isFirstField = false;
    }
    mOut << "}";
    return cursor;
}
1692 
visitConstantUnion(TIntermConstantUnion * constValueNode)1693 void GenMetalTraverser::visitConstantUnion(TIntermConstantUnion *constValueNode)
1694 {
1695     emitConstantUnion(constValueNode->getType(), constValueNode->getConstantValue());
1696 }
1697 
visitSwizzle(Visit,TIntermSwizzle * swizzleNode)1698 bool GenMetalTraverser::visitSwizzle(Visit, TIntermSwizzle *swizzleNode)
1699 {
1700     groupedTraverse(*swizzleNode->getOperand());
1701     mOut << ".";
1702 
1703     {
1704 #if defined(ANGLE_ENABLE_ASSERTS)
1705         DebugSink::EscapedSink escapedOut(mOut.escape());
1706         TInfoSinkBase &out = escapedOut.get();
1707 #else
1708         TInfoSinkBase &out = mOut;
1709 #endif
1710         swizzleNode->writeOffsetsAsXYZW(&out);
1711     }
1712 
1713     return false;
1714 }
1715 
getDirectField(const TFieldListCollection & fieldListCollection,const TConstantUnion & index)1716 const TField &GenMetalTraverser::getDirectField(const TFieldListCollection &fieldListCollection,
1717                                                 const TConstantUnion &index)
1718 {
1719     ASSERT(index.getType() == TBasicType::EbtInt);
1720 
1721     const TFieldList &fieldList = fieldListCollection.fields();
1722     const int indexVal          = index.getIConst();
1723     const TField &field         = *fieldList[indexVal];
1724 
1725     return field;
1726 }
1727 
getDirectField(const TIntermTyped & fieldsNode,TIntermTyped & indexNode)1728 const TField &GenMetalTraverser::getDirectField(const TIntermTyped &fieldsNode,
1729                                                 TIntermTyped &indexNode)
1730 {
1731     const TType &fieldsType = fieldsNode.getType();
1732 
1733     const TFieldListCollection *fieldListCollection = fieldsType.getStruct();
1734     if (fieldListCollection == nullptr)
1735     {
1736         fieldListCollection = fieldsType.getInterfaceBlock();
1737     }
1738     ASSERT(fieldListCollection);
1739 
1740     const TIntermConstantUnion *indexNode_ = indexNode.getAsConstantUnion();
1741     ASSERT(indexNode_);
1742     const TConstantUnion &index = *indexNode_->getConstantValue();
1743 
1744     return getDirectField(*fieldListCollection, index);
1745 }
1746 
visitBinary(Visit,TIntermBinary * binaryNode)1747 bool GenMetalTraverser::visitBinary(Visit, TIntermBinary *binaryNode)
1748 {
1749     const TOperator op      = binaryNode->getOp();
1750     TIntermTyped &leftNode  = *binaryNode->getLeft();
1751     TIntermTyped &rightNode = *binaryNode->getRight();
1752 
1753     switch (op)
1754     {
1755         case TOperator::EOpIndexDirectStruct:
1756         case TOperator::EOpIndexDirectInterfaceBlock:
1757         {
1758             const TField &field = getDirectField(leftNode, rightNode);
1759             if (mSymbolEnv.isPointer(field) && mSymbolEnv.isUBO(field))
1760             {
1761                 emitOpeningPointerParen();
1762             }
1763             groupedTraverse(leftNode);
1764             if (!mSymbolEnv.isPointer(field))
1765             {
1766                 emitClosingPointerParen();
1767             }
1768             mOut << ".";
1769             emitNameOf(field);
1770         }
1771         break;
1772 
1773         case TOperator::EOpIndexDirect:
1774         case TOperator::EOpIndexIndirect:
1775         {
1776             TType leftType = leftNode.getType();
1777             groupedTraverse(leftNode);
1778             mOut << "[";
1779             const TConstantUnion *constIndex = rightNode.getConstantValue();
1780             // TODO(anglebug.com/8491): Convert type and bound checks to
1781             // assertions after AST validation is enabled for MSL translation.
1782             if (!leftType.isUnsizedArray() && constIndex != nullptr &&
1783                 constIndex->getType() == EbtInt && constIndex->getIConst() >= 0 &&
1784                 constIndex->getIConst() < static_cast<int>(leftType.isArray()
1785                                                                ? leftType.getOutermostArraySize()
1786                                                                : leftType.getNominalSize()))
1787             {
1788                 emitSingleConstant(constIndex);
1789             }
1790             else
1791             {
1792                 mOut << "ANGLE_int_clamp(";
1793                 groupedTraverse(rightNode);
1794                 mOut << ", 0, ";
1795                 if (leftType.isUnsizedArray())
1796                 {
1797                     groupedTraverse(leftNode);
1798                     mOut << ".size()";
1799                 }
1800                 else
1801                 {
1802                     uint32_t maxSize;
1803                     if (leftType.isArray())
1804                     {
1805                         maxSize = leftType.getOutermostArraySize() - 1;
1806                     }
1807                     else
1808                     {
1809                         maxSize = leftType.getNominalSize() - 1;
1810                     }
1811                     mOut << maxSize;
1812                 }
1813                 mOut << ")";
1814             }
1815             mOut << "]";
1816         }
1817         break;
1818 
1819         default:
1820         {
1821             const TType &resultType = binaryNode->getType();
1822             const TType &leftType   = leftNode.getType();
1823             const TType &rightType  = rightNode.getType();
1824 
1825             if (IsSymbolicOperator(op, resultType, &leftType, &rightType))
1826             {
1827                 groupedTraverse(leftNode);
1828                 if (op != TOperator::EOpComma)
1829                 {
1830                     mOut << " ";
1831                 }
1832                 else
1833                 {
1834                     emitClosingPointerParen();
1835                 }
1836                 mOut << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << " ";
1837                 groupedTraverse(rightNode);
1838             }
1839             else
1840             {
1841                 emitClosingPointerParen();
1842                 mOut << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << "(";
1843                 leftNode.traverse(this);
1844                 mOut << ", ";
1845                 rightNode.traverse(this);
1846                 mOut << ")";
1847             }
1848         }
1849     }
1850 
1851     return false;
1852 }
1853 
IsPostfix(TOperator op)1854 static bool IsPostfix(TOperator op)
1855 {
1856     switch (op)
1857     {
1858         case TOperator::EOpPostIncrement:
1859         case TOperator::EOpPostDecrement:
1860             return true;
1861 
1862         default:
1863             return false;
1864     }
1865 }
1866 
visitUnary(Visit,TIntermUnary * unaryNode)1867 bool GenMetalTraverser::visitUnary(Visit, TIntermUnary *unaryNode)
1868 {
1869     const TOperator op      = unaryNode->getOp();
1870     const TType &resultType = unaryNode->getType();
1871 
1872     TIntermTyped &arg    = *unaryNode->getOperand();
1873     const TType &argType = arg.getType();
1874 
1875     const char *name = GetOperatorString(op, resultType, &argType, nullptr, nullptr);
1876 
1877     if (IsSymbolicOperator(op, resultType, &argType, nullptr))
1878     {
1879         const bool postfix = IsPostfix(op);
1880         if (!postfix)
1881         {
1882             mOut << name;
1883         }
1884         groupedTraverse(arg);
1885         if (postfix)
1886         {
1887             mOut << name;
1888         }
1889     }
1890     else
1891     {
1892         mOut << name << "(";
1893         arg.traverse(this);
1894         mOut << ")";
1895     }
1896 
1897     return false;
1898 }
1899 
visitTernary(Visit,TIntermTernary * conditionalNode)1900 bool GenMetalTraverser::visitTernary(Visit, TIntermTernary *conditionalNode)
1901 {
1902     groupedTraverse(*conditionalNode->getCondition());
1903     mOut << " ? ";
1904     groupedTraverse(*conditionalNode->getTrueExpression());
1905     mOut << " : ";
1906     groupedTraverse(*conditionalNode->getFalseExpression());
1907 
1908     return false;
1909 }
1910 
visitIfElse(Visit,TIntermIfElse * ifThenElseNode)1911 bool GenMetalTraverser::visitIfElse(Visit, TIntermIfElse *ifThenElseNode)
1912 {
1913     TIntermTyped &condNode = *ifThenElseNode->getCondition();
1914     TIntermBlock *thenNode = ifThenElseNode->getTrueBlock();
1915     TIntermBlock *elseNode = ifThenElseNode->getFalseBlock();
1916 
1917     emitIndentation();
1918     mOut << "if (";
1919     condNode.traverse(this);
1920     mOut << ")";
1921 
1922     if (thenNode)
1923     {
1924         mOut << "\n";
1925         thenNode->traverse(this);
1926     }
1927     else
1928     {
1929         mOut << " {}";
1930     }
1931 
1932     if (elseNode)
1933     {
1934         mOut << "\n";
1935         emitIndentation();
1936         mOut << "else\n";
1937         elseNode->traverse(this);
1938     }
1939     else
1940     {
1941         // Always emit "else" even when empty block to avoid nested if-stmt issues.
1942         mOut << " else {}";
1943     }
1944 
1945     return false;
1946 }
1947 
visitSwitch(Visit,TIntermSwitch * switchNode)1948 bool GenMetalTraverser::visitSwitch(Visit, TIntermSwitch *switchNode)
1949 {
1950     emitIndentation();
1951     mOut << "switch (";
1952     switchNode->getInit()->traverse(this);
1953     mOut << ")\n";
1954 
1955     ASSERT(!mParentIsSwitch);
1956     mParentIsSwitch = true;
1957     switchNode->getStatementList()->traverse(this);
1958     mParentIsSwitch = false;
1959 
1960     return false;
1961 }
1962 
visitCase(Visit,TIntermCase * caseNode)1963 bool GenMetalTraverser::visitCase(Visit, TIntermCase *caseNode)
1964 {
1965     emitIndentation();
1966 
1967     if (caseNode->hasCondition())
1968     {
1969         TIntermTyped *condExpr = caseNode->getCondition();
1970         mOut << "case ";
1971         condExpr->traverse(this);
1972         mOut << ":";
1973     }
1974     else
1975     {
1976         mOut << "default:\n";
1977     }
1978 
1979     return false;
1980 }
1981 
emitFunctionSignature(const TFunction & func)1982 void GenMetalTraverser::emitFunctionSignature(const TFunction &func)
1983 {
1984     const bool isMain = func.isMain();
1985 
1986     emitFunctionReturn(func);
1987 
1988     mOut << " ";
1989     emitNameOf(func);
1990     if (isMain)
1991     {
1992         mOut << "0";
1993     }
1994     mOut << "(";
1995 
1996     bool emitComma          = false;
1997     const size_t paramCount = func.getParamCount();
1998     for (size_t i = 0; i < paramCount; ++i)
1999     {
2000         if (emitComma)
2001         {
2002             mOut << ", ";
2003         }
2004         emitComma = true;
2005 
2006         const TVariable &param = *func.getParam(i);
2007         emitFunctionParameter(func, param);
2008     }
2009 
2010     if (isTraversingVertexMain)
2011     {
2012         mOut << " @@XFB-Bindings@@ ";
2013     }
2014 
2015     mOut << ")";
2016 }
2017 
emitFunctionReturn(const TFunction & func)2018 void GenMetalTraverser::emitFunctionReturn(const TFunction &func)
2019 {
2020     const bool isMain       = func.isMain();
2021     bool isVertexMain       = false;
2022     const TType &returnType = func.getReturnType();
2023     if (isMain)
2024     {
2025         const TStructure *structure = returnType.getStruct();
2026         if (mPipelineStructs.fragmentOut.matches(*structure))
2027         {
2028             if (mCompiler.isEarlyFragmentTestsSpecified())
2029             {
2030                 mOut << "[[early_fragment_tests]]\n";
2031             }
2032             mOut << "fragment ";
2033         }
2034         else if (mPipelineStructs.vertexOut.matches(*structure))
2035         {
2036             ASSERT(structure != nullptr);
2037             mOut << "vertex __VERTEX_OUT(";
2038             isVertexMain = true;
2039         }
2040         else
2041         {
2042             UNIMPLEMENTED();
2043         }
2044     }
2045     emitType(returnType, EmitTypeConfig());
2046     if (isVertexMain)
2047         mOut << ") ";
2048 }
2049 
emitFunctionParameter(const TFunction & func,const TVariable & param)2050 void GenMetalTraverser::emitFunctionParameter(const TFunction &func, const TVariable &param)
2051 {
2052     const bool isMain = func.isMain();
2053 
2054     const TType &type           = param.getType();
2055     const TStructure *structure = type.getStruct();
2056 
2057     EmitVariableDeclarationConfig evdConfig;
2058     evdConfig.isParameter            = true;
2059     evdConfig.disableStructSpecifier = true;  // It's invalid to declare a struct in a function arg.
2060     evdConfig.isMainParameter        = isMain;
2061     evdConfig.emitPostQualifier      = isMain;
2062     evdConfig.isUBO                  = mSymbolEnv.isUBO(param);
2063     evdConfig.isPointer              = mSymbolEnv.isPointer(param);
2064     evdConfig.isReference            = mSymbolEnv.isReference(param);
2065     emitVariableDeclaration(VarDecl(param), evdConfig);
2066 
2067     if (isMain)
2068     {
2069         TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
2070         if (structure)
2071         {
2072             if (mPipelineStructs.fragmentIn.matches(*structure) ||
2073                 mPipelineStructs.vertexIn.matches(*structure))
2074             {
2075                 mOut << " [[stage_in]]";
2076             }
2077             else if (mPipelineStructs.angleUniforms.matches(*structure))
2078             {
2079                 mOut << " [[buffer(" << mDriverUniformsBindingIndex << ")]]";
2080             }
2081             else if (mPipelineStructs.uniformBuffers.matches(*structure))
2082             {
2083                 mOut << " [[buffer(" << mUBOArgumentBufferBindingIndex << ")]]";
2084                 reflection->hasUBOs = true;
2085             }
2086             else if (mPipelineStructs.userUniforms.matches(*structure))
2087             {
2088                 mOut << " [[buffer(" << mMainUniformBufferIndex << ")]]";
2089                 reflection->addUserUniformBufferBinding(param.name().data(),
2090                                                         mMainUniformBufferIndex);
2091                 mMainUniformBufferIndex += type.getArraySizeProduct();
2092             }
2093             else if (structure->name() == "metal::sampler")
2094             {
2095                 mOut << " [[sampler(" << (mMainSamplerIndex) << ")]]";
2096                 const std::string originalName =
2097                     reflection->getOriginalName(param.uniqueId().get());
2098                 reflection->addSamplerBinding(originalName, mMainSamplerIndex);
2099                 mMainSamplerIndex += type.getArraySizeProduct();
2100             }
2101         }
2102         else if (IsSampler(type.getBasicType()))
2103         {
2104             mOut << " [[texture(" << (mMainTextureIndex) << ")]]";
2105             const std::string originalName = reflection->getOriginalName(param.uniqueId().get());
2106             reflection->addTextureBinding(originalName, mMainTextureIndex);
2107             mMainTextureIndex += type.getArraySizeProduct();
2108         }
2109         else if (Name(param) == Pipeline{Pipeline::Type::InstanceId, nullptr}.getStructInstanceName(
2110                                     Pipeline::Variant::Modified))
2111         {
2112             mOut << " [[instance_id]]";
2113         }
2114         else if (Name(param) == kBaseInstanceName)
2115         {
2116             mOut << " [[base_instance]]";
2117         }
2118     }
2119 }
2120 
visitFunctionPrototype(TIntermFunctionPrototype * funcProtoNode)2121 void GenMetalTraverser::visitFunctionPrototype(TIntermFunctionPrototype *funcProtoNode)
2122 {
2123     const TFunction &func = *funcProtoNode->getFunction();
2124 
2125     emitIndentation();
2126     emitFunctionSignature(func);
2127 }
2128 
visitFunctionDefinition(Visit,TIntermFunctionDefinition * funcDefNode)2129 bool GenMetalTraverser::visitFunctionDefinition(Visit, TIntermFunctionDefinition *funcDefNode)
2130 {
2131     const TFunction &func = *funcDefNode->getFunction();
2132     TIntermBlock &body    = *funcDefNode->getBody();
2133     if (func.isMain())
2134     {
2135         const TType &returnType     = func.getReturnType();
2136         const TStructure *structure = returnType.getStruct();
2137         isTraversingVertexMain      = (mPipelineStructs.vertexOut.matches(*structure)) &&
2138                                  mCompiler.getShaderType() == GL_VERTEX_SHADER;
2139     }
2140     emitIndentation();
2141     emitFunctionSignature(func);
2142     mOut << "\n";
2143     body.traverse(this);
2144     if (isTraversingVertexMain)
2145     {
2146         isTraversingVertexMain = false;
2147     }
2148     return false;
2149 }
2150 
BuildFuncToName()2151 GenMetalTraverser::FuncToName GenMetalTraverser::BuildFuncToName()
2152 {
2153     FuncToName map;
2154 
2155     auto putAngle = [&](const char *nameStr) {
2156         const ImmutableString name(nameStr);
2157         ASSERT(map.find(name) == map.end());
2158         map[name] = Name(nameStr, SymbolType::AngleInternal);
2159     };
2160 
2161     putAngle("texelFetch");
2162     putAngle("texelFetchOffset");
2163     putAngle("texture");
2164     putAngle("texture1D");
2165     putAngle("texture1DLod");
2166     putAngle("texture1DProjLod");
2167     putAngle("texture2D");
2168     putAngle("texture2DGradEXT");
2169     putAngle("texture2DLod");
2170     putAngle("texture2DLodEXT");
2171     putAngle("texture2DProj");
2172     putAngle("texture2DProjGradEXT");
2173     putAngle("texture2DProjLod");
2174     putAngle("texture2DProjLodEXT");
2175     putAngle("texture2DRect");
2176     putAngle("texture2DRectProj");
2177     putAngle("texture3D");
2178     putAngle("texture3DLod");
2179     putAngle("texture3DProjLod");
2180     putAngle("textureCube");
2181     putAngle("textureCubeGradEXT");
2182     putAngle("textureCubeLod");
2183     putAngle("textureCubeLodEXT");
2184     putAngle("textureCubeProjLod");
2185     putAngle("textureGrad");
2186     putAngle("textureGradOffset");
2187     putAngle("textureLod");
2188     putAngle("textureLodOffset");
2189     putAngle("textureOffset");
2190     putAngle("textureProj");
2191     putAngle("textureProjGrad");
2192     putAngle("textureProjGradOffset");
2193     putAngle("textureProjLod");
2194     putAngle("textureProjLodOffset");
2195     putAngle("textureProjOffset");
2196     putAngle("textureSize");
2197     putAngle("imageLoad");
2198     putAngle("imageStore");
2199     putAngle("memoryBarrierImage");
2200 
2201     return map;
2202 }
2203 
visitAggregate(Visit,TIntermAggregate * aggregateNode)2204 bool GenMetalTraverser::visitAggregate(Visit, TIntermAggregate *aggregateNode)
2205 {
2206     const TIntermSequence &args = *aggregateNode->getSequence();
2207 
2208     auto emitArgList = [&](const char *open, const char *close) {
2209         mOut << open;
2210 
2211         bool emitComma = false;
2212         for (TIntermNode *arg : args)
2213         {
2214             if (emitComma)
2215             {
2216                 emitClosingPointerParen();
2217                 mOut << ", ";
2218             }
2219             emitComma = true;
2220             arg->traverse(this);
2221         }
2222 
2223         mOut << close;
2224     };
2225 
2226     const TType &retType = aggregateNode->getType();
2227 
2228     if (aggregateNode->isConstructor())
2229     {
2230         const bool isStandalone = getParentNode()->getAsBlock();
2231         if (isStandalone)
2232         {
2233             // Prevent constructor from being interpreted as a declaration by wrapping in parens.
2234             // This can happen if given something like:
2235             //      int(symbol); // <- This will be treated like `int symbol;`... don't want that.
2236             // So instead emit:
2237             //      (int(symbol));
2238             mOut << "(";
2239         }
2240 
2241         const EmitTypeConfig etConfig;
2242 
2243         if (retType.isArray())
2244         {
2245             emitType(retType, etConfig);
2246             emitArgList("{", "}");
2247         }
2248         else if (retType.getStruct())
2249         {
2250             emitType(retType, etConfig);
2251             emitArgList("{", "}");
2252         }
2253         else
2254         {
2255             emitType(retType, etConfig);
2256             emitArgList("(", ")");
2257         }
2258 
2259         if (isStandalone)
2260         {
2261             mOut << ")";
2262         }
2263 
2264         return false;
2265     }
2266     else
2267     {
2268         const TOperator op = aggregateNode->getOp();
2269         switch (op)
2270         {
2271             case TOperator::EOpCallFunctionInAST:
2272             case TOperator::EOpCallInternalRawFunction:
2273             {
2274                 const TFunction &func = *aggregateNode->getFunction();
2275                 emitNameOf(func);
2276                 //'@' symbol in name specifices a macro substitution marker.
2277                 if (!func.name().contains("@"))
2278                 {
2279                     emitArgList("(", ")");
2280                 }
2281                 else
2282                 {
2283                     mTemporarilyDisableSemicolon =
2284                         true;  // Disable semicolon for macro substitution.
2285                 }
2286                 return false;
2287             }
2288 
2289             default:
2290             {
2291                 auto getArgType = [&](size_t index) -> const TType * {
2292                     if (index < args.size())
2293                     {
2294                         TIntermTyped *arg = args[index]->getAsTyped();
2295                         ASSERT(arg);
2296                         return &arg->getType();
2297                     }
2298                     return nullptr;
2299                 };
2300 
2301                 const TType *argType0 = getArgType(0);
2302                 const TType *argType1 = getArgType(1);
2303                 const TType *argType2 = getArgType(2);
2304 
2305                 const char *opName = GetOperatorString(op, retType, argType0, argType1, argType2);
2306 
2307                 if (IsSymbolicOperator(op, retType, argType0, argType1))
2308                 {
2309                     switch (args.size())
2310                     {
2311                         case 1:
2312                         {
2313                             TIntermNode &operandNode = *aggregateNode->getChildNode(0);
2314                             if (IsPostfix(op))
2315                             {
2316                                 mOut << opName;
2317                                 groupedTraverse(operandNode);
2318                             }
2319                             else
2320                             {
2321                                 groupedTraverse(operandNode);
2322                                 mOut << opName;
2323                             }
2324                             return false;
2325                         }
2326 
2327                         case 2:
2328                         {
2329                             TIntermNode &leftNode  = *aggregateNode->getChildNode(0);
2330                             TIntermNode &rightNode = *aggregateNode->getChildNode(1);
2331                             groupedTraverse(leftNode);
2332                             mOut << " " << opName << " ";
2333                             groupedTraverse(rightNode);
2334                             return false;
2335                         }
2336 
2337                         default:
2338                             UNREACHABLE();
2339                             return false;
2340                     }
2341                 }
2342                 else if (opName == nullptr)
2343                 {
2344                     const TFunction &func = *aggregateNode->getFunction();
2345                     auto it               = mFuncToName.find(func.name());
2346                     ASSERT(it != mFuncToName.end());
2347                     EmitName(mOut, it->second);
2348                     emitArgList("(", ")");
2349                     return false;
2350                 }
2351                 else
2352                 {
2353                     mOut << opName;
2354                     emitArgList("(", ")");
2355                     return false;
2356                 }
2357             }
2358         }
2359     }
2360 }
2361 
emitOpenBrace()2362 void GenMetalTraverser::emitOpenBrace()
2363 {
2364     ASSERT(mIndentLevel >= 0);
2365 
2366     emitIndentation();
2367     mOut << "{\n";
2368     ++mIndentLevel;
2369 }
2370 
emitCloseBrace()2371 void GenMetalTraverser::emitCloseBrace()
2372 {
2373     ASSERT(mIndentLevel >= 1);
2374 
2375     --mIndentLevel;
2376     emitIndentation();
2377     mOut << "}";
2378 }
2379 
RequiresSemicolonTerminator(TIntermNode & node)2380 static bool RequiresSemicolonTerminator(TIntermNode &node)
2381 {
2382     if (node.getAsBlock())
2383     {
2384         return false;
2385     }
2386     if (node.getAsLoopNode())
2387     {
2388         return false;
2389     }
2390     if (node.getAsSwitchNode())
2391     {
2392         return false;
2393     }
2394     if (node.getAsIfElseNode())
2395     {
2396         return false;
2397     }
2398     if (node.getAsFunctionDefinition())
2399     {
2400         return false;
2401     }
2402     if (node.getAsCaseNode())
2403     {
2404         return false;
2405     }
2406 
2407     return true;
2408 }
2409 
NewlinePad(TIntermNode & node)2410 static bool NewlinePad(TIntermNode &node)
2411 {
2412     if (node.getAsFunctionDefinition())
2413     {
2414         return true;
2415     }
2416     if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
2417     {
2418         ASSERT(declNode->getChildCount() == 1);
2419         TIntermNode &childNode = *declNode->getChildNode(0);
2420         if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
2421         {
2422             const TVariable &var = symbolNode->variable();
2423             return var.getType().isStructSpecifier();
2424         }
2425         return false;
2426     }
2427     return false;
2428 }
2429 
visitBlock(Visit,TIntermBlock * blockNode)2430 bool GenMetalTraverser::visitBlock(Visit, TIntermBlock *blockNode)
2431 {
2432     ASSERT(mIndentLevel >= -1);
2433     const bool isGlobalScope  = mIndentLevel == -1;
2434     const bool parentIsSwitch = mParentIsSwitch;
2435     mParentIsSwitch           = false;
2436 
2437     if (isGlobalScope)
2438     {
2439         ++mIndentLevel;
2440     }
2441     else
2442     {
2443         emitOpenBrace();
2444         if (parentIsSwitch)
2445         {
2446             ++mIndentLevel;
2447         }
2448     }
2449 
2450     TIntermNode *prevStmtNode = nullptr;
2451 
2452     const size_t stmtCount = blockNode->getChildCount();
2453     for (size_t i = 0; i < stmtCount; ++i)
2454     {
2455         TIntermNode &stmtNode = *blockNode->getChildNode(i);
2456 
2457         if (isGlobalScope && prevStmtNode && (NewlinePad(*prevStmtNode) || NewlinePad(stmtNode)))
2458         {
2459             mOut << "\n";
2460         }
2461         const bool isCase = stmtNode.getAsCaseNode();
2462         mIndentLevel -= isCase;
2463         emitIndentation();
2464         mIndentLevel += isCase;
2465         stmtNode.traverse(this);
2466         if (RequiresSemicolonTerminator(stmtNode) && !mTemporarilyDisableSemicolon)
2467         {
2468             mOut << ";";
2469         }
2470         mTemporarilyDisableSemicolon = false;
2471         mOut << "\n";
2472 
2473         prevStmtNode = &stmtNode;
2474     }
2475 
2476     if (isGlobalScope)
2477     {
2478         ASSERT(mIndentLevel == 0);
2479         --mIndentLevel;
2480     }
2481     else
2482     {
2483         if (parentIsSwitch)
2484         {
2485             ASSERT(mIndentLevel >= 1);
2486             --mIndentLevel;
2487         }
2488         emitCloseBrace();
2489         mParentIsSwitch = parentIsSwitch;
2490     }
2491 
2492     return false;
2493 }
2494 
visitGlobalQualifierDeclaration(Visit,TIntermGlobalQualifierDeclaration *)2495 bool GenMetalTraverser::visitGlobalQualifierDeclaration(Visit, TIntermGlobalQualifierDeclaration *)
2496 {
2497     return false;
2498 }
2499 
visitDeclaration(Visit,TIntermDeclaration * declNode)2500 bool GenMetalTraverser::visitDeclaration(Visit, TIntermDeclaration *declNode)
2501 {
2502     ASSERT(declNode->getChildCount() == 1);
2503     TIntermNode &node = *declNode->getChildNode(0);
2504 
2505     EmitVariableDeclarationConfig evdConfig;
2506 
2507     if (TIntermSymbol *symbolNode = node.getAsSymbolNode())
2508     {
2509         const TVariable &var = symbolNode->variable();
2510         emitVariableDeclaration(VarDecl(var), evdConfig);
2511         if (var.getType().isArray() && var.getType().getQualifier() == EvqTemporary)
2512         {
2513             // The translator frontend injects a loop-based init for user arrays when the source
2514             // shader is using ESSL 1.00. Some Metal drivers may fail to access elements of such
2515             // arrays at runtime depending on the array size. An empty literal initializer added
2516             // to the generated MSL bypasses the issue. The frontend may be further optimized to
2517             // skip the loop-based init when targeting MSL.
2518             mOut << "{}";
2519         }
2520     }
2521     else if (TIntermBinary *initNode = node.getAsBinaryNode())
2522     {
2523         ASSERT(initNode->getOp() == TOperator::EOpInitialize);
2524         TIntermSymbol *leftSymbolNode = initNode->getLeft()->getAsSymbolNode();
2525         TIntermTyped *valueNode       = initNode->getRight()->getAsTyped();
2526         ASSERT(leftSymbolNode && valueNode);
2527 
2528         if (getRootNode() == getParentBlock())
2529         {
2530             // DeferGlobalInitializers should have turned non-const global initializers into
2531             // deferred initializers. Note that variables marked as EvqGlobal can be treated as
2532             // EvqConst in some ANGLE code but not actually have their qualifier actually changed to
2533             // EvqConst. Thus just assume all EvqGlobal are actually EvqConst for all code run after
2534             // DeferGlobalInitializers.
2535             mOut << "constant ";
2536         }
2537 
2538         const TVariable &var = leftSymbolNode->variable();
2539         const Name varName(var);
2540 
2541         if (ExpressionContainsName(varName, *valueNode))
2542         {
2543             mRenamedSymbols[&var] = mIdGen.createNewName(varName);
2544         }
2545 
2546         emitVariableDeclaration(VarDecl(var), evdConfig);
2547         mOut << " = ";
2548         groupedTraverse(*valueNode);
2549     }
2550     else
2551     {
2552         UNREACHABLE();
2553     }
2554 
2555     return false;
2556 }
2557 
visitLoop(Visit,TIntermLoop * loopNode)2558 bool GenMetalTraverser::visitLoop(Visit, TIntermLoop *loopNode)
2559 {
2560     const TLoopType loopType = loopNode->getType();
2561 
2562     switch (loopType)
2563     {
2564         case TLoopType::ELoopFor:
2565             return visitForLoop(loopNode);
2566         case TLoopType::ELoopWhile:
2567             return visitWhileLoop(loopNode);
2568         case TLoopType::ELoopDoWhile:
2569             return visitDoWhileLoop(loopNode);
2570     }
2571 }
2572 
visitForLoop(TIntermLoop * loopNode)2573 bool GenMetalTraverser::visitForLoop(TIntermLoop *loopNode)
2574 {
2575     ASSERT(loopNode->getType() == TLoopType::ELoopFor);
2576 
2577     TIntermNode *initNode  = loopNode->getInit();
2578     TIntermTyped *condNode = loopNode->getCondition();
2579     TIntermTyped *exprNode = loopNode->getExpression();
2580 
2581     mOut << "for (";
2582 
2583     if (initNode)
2584     {
2585         initNode->traverse(this);
2586     }
2587     else
2588     {
2589         mOut << " ";
2590     }
2591 
2592     mOut << "; ";
2593 
2594     if (condNode)
2595     {
2596         condNode->traverse(this);
2597     }
2598 
2599     mOut << "; ";
2600 
2601     if (exprNode)
2602     {
2603         exprNode->traverse(this);
2604     }
2605 
2606     mOut << ")\n";
2607 
2608     emitLoopBody(loopNode->getBody());
2609 
2610     return false;
2611 }
2612 
visitWhileLoop(TIntermLoop * loopNode)2613 bool GenMetalTraverser::visitWhileLoop(TIntermLoop *loopNode)
2614 {
2615     ASSERT(loopNode->getType() == TLoopType::ELoopWhile);
2616 
2617     TIntermNode *initNode  = loopNode->getInit();
2618     TIntermTyped *condNode = loopNode->getCondition();
2619     TIntermTyped *exprNode = loopNode->getExpression();
2620     ASSERT(condNode);
2621     ASSERT(!initNode && !exprNode);
2622 
2623     emitIndentation();
2624     mOut << "while (";
2625     condNode->traverse(this);
2626     mOut << ")\n";
2627     emitLoopBody(loopNode->getBody());
2628 
2629     return false;
2630 }
2631 
visitDoWhileLoop(TIntermLoop * loopNode)2632 bool GenMetalTraverser::visitDoWhileLoop(TIntermLoop *loopNode)
2633 {
2634     ASSERT(loopNode->getType() == TLoopType::ELoopDoWhile);
2635 
2636     TIntermNode *initNode  = loopNode->getInit();
2637     TIntermTyped *condNode = loopNode->getCondition();
2638     TIntermTyped *exprNode = loopNode->getExpression();
2639     ASSERT(condNode);
2640     ASSERT(!initNode && !exprNode);
2641 
2642     emitIndentation();
2643     mOut << "do\n";
2644     emitLoopBody(loopNode->getBody());
2645     mOut << "\n";
2646     emitIndentation();
2647     mOut << "while (";
2648     condNode->traverse(this);
2649     mOut << ");";
2650 
2651     return false;
2652 }
2653 
visitBranch(Visit,TIntermBranch * branchNode)2654 bool GenMetalTraverser::visitBranch(Visit, TIntermBranch *branchNode)
2655 {
2656     const TOperator flowOp = branchNode->getFlowOp();
2657     TIntermTyped *exprNode = branchNode->getExpression();
2658 
2659     emitIndentation();
2660 
2661     switch (flowOp)
2662     {
2663         case TOperator::EOpKill:
2664         {
2665             ASSERT(exprNode == nullptr);
2666             mOut << "metal::discard_fragment()";
2667         }
2668         break;
2669 
2670         case TOperator::EOpReturn:
2671         {
2672             if (isTraversingVertexMain)
2673             {
2674                 mOut << "#if TRANSFORM_FEEDBACK_ENABLED\n";
2675                 emitIndentation();
2676                 mOut << "return;\n";
2677                 emitIndentation();
2678                 mOut << "#else\n";
2679                 emitIndentation();
2680             }
2681             mOut << "return";
2682             if (exprNode)
2683             {
2684                 mOut << " ";
2685                 exprNode->traverse(this);
2686                 mOut << ";";
2687             }
2688             if (isTraversingVertexMain)
2689             {
2690                 mOut << "\n";
2691                 emitIndentation();
2692                 mOut << "#endif\n";
2693                 mTemporarilyDisableSemicolon = true;
2694             }
2695         }
2696         break;
2697 
2698         case TOperator::EOpBreak:
2699         {
2700             ASSERT(exprNode == nullptr);
2701             mOut << "break";
2702         }
2703         break;
2704 
2705         case TOperator::EOpContinue:
2706         {
2707             ASSERT(exprNode == nullptr);
2708             mOut << "continue";
2709         }
2710         break;
2711 
2712         default:
2713         {
2714             UNREACHABLE();
2715         }
2716     }
2717 
2718     return false;
2719 }
2720 
2721 static size_t emitMetalCallCount = 0;
2722 
EmitMetal(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const PipelineStructs & pipelineStructs,SymbolEnv & symbolEnv,const ProgramPreludeConfig & ppc,const ShCompileOptions & compileOptions)2723 bool sh::EmitMetal(TCompiler &compiler,
2724                    TIntermBlock &root,
2725                    IdGen &idGen,
2726                    const PipelineStructs &pipelineStructs,
2727                    SymbolEnv &symbolEnv,
2728                    const ProgramPreludeConfig &ppc,
2729                    const ShCompileOptions &compileOptions)
2730 {
2731     TInfoSinkBase &out = compiler.getInfoSink().obj;
2732 
2733     {
2734         ++emitMetalCallCount;
2735         std::string filenameProto = angle::GetEnvironmentVar("GMD_FIXED_EMIT");
2736         if (!filenameProto.empty())
2737         {
2738             if (filenameProto != "/dev/null")
2739             {
2740                 auto tryOpen = [&](char const *ext) {
2741                     auto filename = filenameProto;
2742                     filename += std::to_string(emitMetalCallCount);
2743                     filename += ".";
2744                     filename += ext;
2745                     return fopen(filename.c_str(), "rb");
2746                 };
2747                 FILE *file = tryOpen("metal");
2748                 if (!file)
2749                 {
2750                     file = tryOpen("cpp");
2751                 }
2752                 ASSERT(file);
2753 
2754                 fseek(file, 0, SEEK_END);
2755                 size_t fileSize = ftell(file);
2756                 fseek(file, 0, SEEK_SET);
2757 
2758                 std::vector<char> buff;
2759                 buff.resize(fileSize + 1);
2760                 fread(buff.data(), fileSize, 1, file);
2761                 buff.back() = '\0';
2762 
2763                 fclose(file);
2764 
2765                 out << buff.data();
2766             }
2767 
2768             return true;
2769         }
2770     }
2771 
2772     out << "\n\n";
2773 
2774     if (!EmitProgramPrelude(root, out, ppc))
2775     {
2776         return false;
2777     }
2778 
2779     {
2780 #if defined(ANGLE_ENABLE_ASSERTS)
2781         DebugSink outWrapper(out, angle::GetBoolEnvironmentVar("GMD_STDOUT"));
2782         outWrapper.watch(angle::GetEnvironmentVar("GMD_WATCH_STRING"));
2783 #else
2784         TInfoSinkBase &outWrapper = out;
2785 #endif
2786         GenMetalTraverser gen(compiler, outWrapper, idGen, pipelineStructs, symbolEnv,
2787                               compileOptions);
2788         root.traverse(&gen);
2789     }
2790 
2791     out << "\n";
2792 
2793     return true;
2794 }
2795