• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <cctype>
8 #include <map>
9 
10 #include "common/system_utils.h"
11 #include "compiler/translator/BaseTypes.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/SymbolTable.h"
14 #include "compiler/translator/msl/AstHelpers.h"
15 #include "compiler/translator/msl/DebugSink.h"
16 #include "compiler/translator/msl/EmitMetal.h"
17 #include "compiler/translator/msl/Layout.h"
18 #include "compiler/translator/msl/Name.h"
19 #include "compiler/translator/msl/ProgramPrelude.h"
20 #include "compiler/translator/msl/RewritePipelines.h"
21 #include "compiler/translator/msl/TranslatorMSL.h"
22 #include "compiler/translator/tree_util/IntermTraverse.h"
23 
24 using namespace sh;
25 
26 ////////////////////////////////////////////////////////////////////////////////
27 
28 #if defined(ANGLE_ENABLE_ASSERTS)
29 using Sink = DebugSink;
30 #else
31 using Sink = TInfoSinkBase;
32 #endif
33 
34 ////////////////////////////////////////////////////////////////////////////////
35 
36 namespace
37 {
38 
39 struct VarDecl
40 {
VarDecl__anonb64d01340111::VarDecl41     explicit VarDecl(const TVariable &var) : mVariable(&var), mIsField(false) {}
VarDecl__anonb64d01340111::VarDecl42     explicit VarDecl(const TField &field) : mField(&field), mIsField(true) {}
43 
variable__anonb64d01340111::VarDecl44     ANGLE_INLINE const TVariable &variable() const
45     {
46         ASSERT(isVariable());
47         return *mVariable;
48     }
49 
field__anonb64d01340111::VarDecl50     ANGLE_INLINE const TField &field() const
51     {
52         ASSERT(isField());
53         return *mField;
54     }
55 
isVariable__anonb64d01340111::VarDecl56     ANGLE_INLINE bool isVariable() const { return !mIsField; }
57 
isField__anonb64d01340111::VarDecl58     ANGLE_INLINE bool isField() const { return mIsField; }
59 
type__anonb64d01340111::VarDecl60     const TType &type() const { return isField() ? *field().type() : variable().getType(); }
61 
symbolType__anonb64d01340111::VarDecl62     SymbolType symbolType() const
63     {
64         return isField() ? field().symbolType() : variable().symbolType();
65     }
66 
67   private:
68     union
69     {
70         const TVariable *mVariable;
71         const TField *mField;
72     };
73     bool mIsField;
74 };
75 
76 class GenMetalTraverser : public TIntermTraverser
77 {
78   public:
79     ~GenMetalTraverser() override;
80 
81     GenMetalTraverser(const TCompiler &compiler,
82                       Sink &out,
83                       IdGen &idGen,
84                       const PipelineStructs &pipelineStructs,
85                       SymbolEnv &symbolEnv,
86                       const ShCompileOptions &compileOptions);
87 
88     void visitSymbol(TIntermSymbol *) override;
89     void visitConstantUnion(TIntermConstantUnion *) override;
90     bool visitSwizzle(Visit, TIntermSwizzle *) override;
91     bool visitBinary(Visit, TIntermBinary *) override;
92     bool visitUnary(Visit, TIntermUnary *) override;
93     bool visitTernary(Visit, TIntermTernary *) override;
94     bool visitIfElse(Visit, TIntermIfElse *) override;
95     bool visitSwitch(Visit, TIntermSwitch *) override;
96     bool visitCase(Visit, TIntermCase *) override;
97     void visitFunctionPrototype(TIntermFunctionPrototype *) override;
98     bool visitFunctionDefinition(Visit, TIntermFunctionDefinition *) override;
99     bool visitAggregate(Visit, TIntermAggregate *) override;
100     bool visitBlock(Visit, TIntermBlock *) override;
101     bool visitGlobalQualifierDeclaration(Visit, TIntermGlobalQualifierDeclaration *) override;
102     bool visitDeclaration(Visit, TIntermDeclaration *) override;
103     bool visitLoop(Visit, TIntermLoop *) override;
104     bool visitForLoop(TIntermLoop *);
105     bool visitWhileLoop(TIntermLoop *);
106     bool visitDoWhileLoop(TIntermLoop *);
107     bool visitBranch(Visit, TIntermBranch *) override;
108 
109   private:
110     using FuncToName = std::map<ImmutableString, Name>;
111     static FuncToName BuildFuncToName();
112 
113     struct EmitVariableDeclarationConfig
114     {
115         bool isParameter                = false;
116         bool isMainParameter            = false;
117         bool emitPostQualifier          = false;
118         bool isPacked                   = false;
119         bool disableStructSpecifier     = false;
120         bool isUBO                      = false;
121         const AddressSpace *isPointer   = nullptr;
122         const AddressSpace *isReference = nullptr;
123     };
124 
125     struct EmitTypeConfig
126     {
127         const EmitVariableDeclarationConfig *evdConfig = nullptr;
128     };
129 
130     void emitIndentation();
131     void emitOpeningPointerParen();
132     void emitClosingPointerParen();
133     void emitFunctionSignature(const TFunction &func);
134     void emitFunctionReturn(const TFunction &func);
135     void emitFunctionParameter(const TFunction &func, const TVariable &param);
136 
137     void emitNameOf(const TField &object);
138     void emitNameOf(const TSymbol &object);
139     void emitNameOf(const VarDecl &object);
140 
141     void emitBareTypeName(const TType &type, const EmitTypeConfig &etConfig);
142     void emitType(const TType &type, const EmitTypeConfig &etConfig);
143     void emitPostQualifier(const EmitVariableDeclarationConfig &evdConfig,
144                            const VarDecl &decl,
145                            const TQualifier qualifier);
146 
147     struct FieldAnnotationIndices
148     {
149         size_t attribute = 0;
150         size_t color     = 0;
151     };
152 
153     void emitFieldDeclaration(const TField &field,
154                               const TStructure &parent,
155                               FieldAnnotationIndices &annotationIndices);
156     void emitAttributeDeclaration(const TField &field, FieldAnnotationIndices &annotationIndices);
157     void emitUniformBufferDeclaration(const TField &field,
158                                       FieldAnnotationIndices &annotationIndices);
159     void emitStructDeclaration(const TType &type);
160     void emitOrdinaryVariableDeclaration(const VarDecl &decl,
161                                          const EmitVariableDeclarationConfig &evdConfig);
162     void emitVariableDeclaration(const VarDecl &decl,
163                                  const EmitVariableDeclarationConfig &evdConfig);
164 
165     void emitOpenBrace();
166     void emitCloseBrace();
167 
168     void groupedTraverse(TIntermNode &node);
169 
170     const TField &getDirectField(const TFieldListCollection &fieldsNode,
171                                  const TConstantUnion &index);
172     const TField &getDirectField(const TIntermTyped &fieldsNode, TIntermTyped &indexNode);
173 
174     const TConstantUnion *emitConstantUnionArray(const TConstantUnion *const constUnion,
175                                                  const size_t size);
176 
177     const TConstantUnion *emitConstantUnion(const TType &type, const TConstantUnion *constUnion);
178 
179     void emitSingleConstant(const TConstantUnion *const constUnion);
180 
181   private:
182     Sink &mOut;
183     const TCompiler &mCompiler;
184     const PipelineStructs &mPipelineStructs;
185     SymbolEnv &mSymbolEnv;
186     IdGen &mIdGen;
187     int mIndentLevel                  = -1;
188     int mLastIndentationPos           = -1;
189     int mOpenPointerParenCount        = 0;
190     bool mParentIsSwitch              = false;
191     bool isTraversingVertexMain       = false;
192     bool mTemporarilyDisableSemicolon = false;
193     std::unordered_map<const TSymbol *, Name> mRenamedSymbols;
194     const FuncToName mFuncToName          = BuildFuncToName();
195     size_t mMainTextureIndex              = 0;
196     size_t mMainSamplerIndex              = 0;
197     size_t mMainUniformBufferIndex        = 0;
198     size_t mDriverUniformsBindingIndex    = 0;
199     size_t mUBOArgumentBufferBindingIndex = 0;
200     bool mRasterOrderGroupsSupported      = false;
201 };
202 }  // anonymous namespace
203 
~GenMetalTraverser()204 GenMetalTraverser::~GenMetalTraverser()
205 {
206     ASSERT(mIndentLevel == -1);
207     ASSERT(!mParentIsSwitch);
208     ASSERT(mOpenPointerParenCount == 0);
209 }
210 
GenMetalTraverser(const TCompiler & compiler,Sink & out,IdGen & idGen,const PipelineStructs & pipelineStructs,SymbolEnv & symbolEnv,const ShCompileOptions & compileOptions)211 GenMetalTraverser::GenMetalTraverser(const TCompiler &compiler,
212                                      Sink &out,
213                                      IdGen &idGen,
214                                      const PipelineStructs &pipelineStructs,
215                                      SymbolEnv &symbolEnv,
216                                      const ShCompileOptions &compileOptions)
217     : TIntermTraverser(true, false, false),
218       mOut(out),
219       mCompiler(compiler),
220       mPipelineStructs(pipelineStructs),
221       mSymbolEnv(symbolEnv),
222       mIdGen(idGen),
223       mMainUniformBufferIndex(compileOptions.metal.defaultUniformsBindingIndex),
224       mDriverUniformsBindingIndex(compileOptions.metal.driverUniformsBindingIndex),
225       mUBOArgumentBufferBindingIndex(compileOptions.metal.UBOArgumentBufferBindingIndex),
226       mRasterOrderGroupsSupported(compileOptions.pls.fragmentSyncType ==
227                                   ShFragmentSynchronizationType::RasterOrderGroups_Metal)
228 {}
229 
emitIndentation()230 void GenMetalTraverser::emitIndentation()
231 {
232     ASSERT(mIndentLevel >= 0);
233 
234     if (mLastIndentationPos == mOut.size())
235     {
236         return;  // Line is already indented.
237     }
238 
239     for (int i = 0; i < mIndentLevel; ++i)
240     {
241         mOut << "  ";
242     }
243 
244     mLastIndentationPos = mOut.size();
245 }
246 
emitOpeningPointerParen()247 void GenMetalTraverser::emitOpeningPointerParen()
248 {
249     mOut << "(*";
250     mOpenPointerParenCount++;
251 }
252 
emitClosingPointerParen()253 void GenMetalTraverser::emitClosingPointerParen()
254 {
255     if (mOpenPointerParenCount > 0)
256     {
257         mOut << ")";
258         mOpenPointerParenCount--;
259     }
260 }
261 
GetOperatorString(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1,const TType * argType2)262 static const char *GetOperatorString(TOperator op,
263                                      const TType &resultType,
264                                      const TType *argType0,
265                                      const TType *argType1,
266                                      const TType *argType2)
267 {
268     switch (op)
269     {
270         case TOperator::EOpComma:
271             return ",";
272         case TOperator::EOpAssign:
273             return "=";
274         case TOperator::EOpInitialize:
275             return "=";
276         case TOperator::EOpAddAssign:
277             return "+=";
278         case TOperator::EOpSubAssign:
279             return "-=";
280         case TOperator::EOpMulAssign:
281             return "*=";
282         case TOperator::EOpDivAssign:
283             return "/=";
284         case TOperator::EOpIModAssign:
285             return "%=";
286         case TOperator::EOpBitShiftLeftAssign:
287             return "<<=";  // TODO: Check logical vs arithmetic shifting.
288         case TOperator::EOpBitShiftRightAssign:
289             return ">>=";  // TODO: Check logical vs arithmetic shifting.
290         case TOperator::EOpBitwiseAndAssign:
291             return "&=";
292         case TOperator::EOpBitwiseXorAssign:
293             return "^=";
294         case TOperator::EOpBitwiseOrAssign:
295             return "|=";
296         case TOperator::EOpAdd:
297             return "+";
298         case TOperator::EOpSub:
299             return "-";
300         case TOperator::EOpMul:
301             return "*";
302         case TOperator::EOpDiv:
303             return "/";
304         case TOperator::EOpIMod:
305             return "%";
306         case TOperator::EOpBitShiftLeft:
307             return "<<";  // TODO: Check logical vs arithmetic shifting.
308         case TOperator::EOpBitShiftRight:
309             return ">>";  // TODO: Check logical vs arithmetic shifting.
310         case TOperator::EOpBitwiseAnd:
311             return "&";
312         case TOperator::EOpBitwiseXor:
313             return "^";
314         case TOperator::EOpBitwiseOr:
315             return "|";
316         case TOperator::EOpLessThan:
317             return "<";
318         case TOperator::EOpGreaterThan:
319             return ">";
320         case TOperator::EOpLessThanEqual:
321             return "<=";
322         case TOperator::EOpGreaterThanEqual:
323             return ">=";
324         case TOperator::EOpLessThanComponentWise:
325             return "<";
326         case TOperator::EOpLessThanEqualComponentWise:
327             return "<=";
328         case TOperator::EOpGreaterThanEqualComponentWise:
329             return ">=";
330         case TOperator::EOpGreaterThanComponentWise:
331             return ">";
332         case TOperator::EOpLogicalOr:
333             return "||";
334         case TOperator::EOpLogicalXor:
335             return "!=/*xor*/";  // XXX: This might need to be handled differently for some obtuse
336                                  // use case.
337         case TOperator::EOpLogicalAnd:
338             return "&&";
339         case TOperator::EOpNegative:
340             return "-";
341         case TOperator::EOpPositive:
342             if (argType0->isMatrix())
343             {
344                 return "";
345             }
346             return "+";
347         case TOperator::EOpLogicalNot:
348             return "!";
349         case TOperator::EOpNotComponentWise:
350             return "!";
351         case TOperator::EOpBitwiseNot:
352             return "~";
353         case TOperator::EOpPostIncrement:
354             return "++";
355         case TOperator::EOpPostDecrement:
356             return "--";
357         case TOperator::EOpPreIncrement:
358             return "++";
359         case TOperator::EOpPreDecrement:
360             return "--";
361         case TOperator::EOpVectorTimesScalarAssign:
362             return "*=";
363         case TOperator::EOpVectorTimesMatrixAssign:
364             return "*=";
365         case TOperator::EOpMatrixTimesScalarAssign:
366             return "*=";
367         case TOperator::EOpMatrixTimesMatrixAssign:
368             return "*=";
369         case TOperator::EOpVectorTimesScalar:
370             return "*";
371         case TOperator::EOpVectorTimesMatrix:
372             return "*";
373         case TOperator::EOpMatrixTimesVector:
374             return "*";
375         case TOperator::EOpMatrixTimesScalar:
376             return "*";
377         case TOperator::EOpMatrixTimesMatrix:
378             return "*";
379         case TOperator::EOpEqualComponentWise:
380             return "==";
381         case TOperator::EOpNotEqualComponentWise:
382             return "!=";
383 
384         case TOperator::EOpEqual:
385             if ((argType0->getStruct() && argType1->getStruct()) &&
386                 (argType0->isArray() && argType1->isArray()))
387             {
388                 return "ANGLE_equalStructArray";
389             }
390 
391             if ((argType0->isVector() && argType1->isVector()) ||
392                 (argType0->getStruct() && argType1->getStruct()) ||
393                 (argType0->isArray() && argType1->isArray()) ||
394                 (argType0->isMatrix() && argType1->isMatrix()))
395 
396             {
397                 return "ANGLE_equal";
398             }
399 
400             return "==";
401 
402         case TOperator::EOpNotEqual:
403             if ((argType0->getStruct() && argType1->getStruct()) &&
404                 (argType0->isArray() && argType1->isArray()))
405             {
406                 return "ANGLE_notEqualStructArray";
407             }
408 
409             if ((argType0->isVector() && argType1->isVector()) ||
410                 (argType0->isArray() && argType1->isArray()) ||
411                 (argType0->isMatrix() && argType1->isMatrix()))
412             {
413                 return "ANGLE_notEqual";
414             }
415             else if (argType0->getStruct() && argType1->getStruct())
416             {
417                 return "ANGLE_notEqualStruct";
418             }
419             return "!=";
420 
421         case TOperator::EOpKill:
422             UNIMPLEMENTED();
423             return "kill";
424         case TOperator::EOpReturn:
425             return "return";
426         case TOperator::EOpBreak:
427             return "break";
428         case TOperator::EOpContinue:
429             return "continue";
430 
431         case TOperator::EOpRadians:
432             return "ANGLE_radians";
433         case TOperator::EOpDegrees:
434             return "ANGLE_degrees";
435         case TOperator::EOpAtan:
436             return "ANGLE_atan";
437         case TOperator::EOpMod:
438             return "ANGLE_mod";  // differs from metal::mod
439         case TOperator::EOpRefract:
440             return "ANGLE_refract";
441         case TOperator::EOpDistance:
442             return "ANGLE_distance";
443         case TOperator::EOpLength:
444             return "ANGLE_length";
445         case TOperator::EOpDot:
446             return "ANGLE_dot";
447         case TOperator::EOpNormalize:
448             return "ANGLE_normalize";
449         case TOperator::EOpFaceforward:
450             return "ANGLE_faceforward";
451         case TOperator::EOpReflect:
452             return "ANGLE_reflect";
453         case TOperator::EOpMatrixCompMult:
454             return "ANGLE_componentWiseMultiply";
455         case TOperator::EOpOuterProduct:
456             return "ANGLE_outerProduct";
457         case TOperator::EOpSign:
458             return "ANGLE_sign";
459 
460         case TOperator::EOpAbs:
461             return "metal::abs";
462         case TOperator::EOpAll:
463             return "metal::all";
464         case TOperator::EOpAny:
465             return "metal::any";
466         case TOperator::EOpSin:
467             return "metal::sin";
468         case TOperator::EOpCos:
469             return "metal::cos";
470         case TOperator::EOpTan:
471             return "metal::tan";
472         case TOperator::EOpAsin:
473             return "metal::asin";
474         case TOperator::EOpAcos:
475             return "metal::acos";
476         case TOperator::EOpSinh:
477             return "metal::sinh";
478         case TOperator::EOpCosh:
479             return "metal::cosh";
480         case TOperator::EOpTanh:
481             return "metal::tanh";
482         case TOperator::EOpAsinh:
483             return "metal::asinh";
484         case TOperator::EOpAcosh:
485             return "metal::acosh";
486         case TOperator::EOpAtanh:
487             return "metal::atanh";
488         case TOperator::EOpFma:
489             return "metal::fma";
490         case TOperator::EOpPow:
491             return "metal::pow";
492         case TOperator::EOpExp:
493             return "metal::exp";
494         case TOperator::EOpExp2:
495             return "metal::exp2";
496         case TOperator::EOpLog:
497             return "metal::log";
498         case TOperator::EOpLog2:
499             return "metal::log2";
500         case TOperator::EOpSqrt:
501             return "metal::sqrt";
502         case TOperator::EOpFloor:
503             return "metal::floor";
504         case TOperator::EOpTrunc:
505             return "metal::trunc";
506         case TOperator::EOpCeil:
507             return "metal::ceil";
508         case TOperator::EOpFract:
509             return "metal::fract";
510         case TOperator::EOpMin:
511             return "metal::min";
512         case TOperator::EOpMax:
513             return "metal::max";
514         case TOperator::EOpRound:
515             return "metal::round";
516         case TOperator::EOpRoundEven:
517             return "metal::rint";
518         case TOperator::EOpClamp:
519             return "metal::clamp";  // TODO fast vs precise namespace
520         case TOperator::EOpSaturate:
521             return "metal::saturate";  // TODO fast vs precise namespace
522         case TOperator::EOpMix:
523             if (argType2 && argType2->getBasicType() == EbtBool)
524                 return "ANGLE_mix_bool";
525             return "metal::mix";
526         case TOperator::EOpStep:
527             return "metal::step";
528         case TOperator::EOpSmoothstep:
529             return "metal::smoothstep";
530         case TOperator::EOpModf:
531             return "metal::modf";
532         case TOperator::EOpIsnan:
533             return "metal::isnan";
534         case TOperator::EOpIsinf:
535             return "metal::isinf";
536         case TOperator::EOpLdexp:
537             return "metal::ldexp";
538         case TOperator::EOpFrexp:
539             return "metal::frexp";
540         case TOperator::EOpInversesqrt:
541             return "metal::rsqrt";
542         case TOperator::EOpCross:
543             return "metal::cross";
544         case TOperator::EOpDFdx:
545             return "metal::dfdx";
546         case TOperator::EOpDFdy:
547             return "metal::dfdy";
548         case TOperator::EOpFwidth:
549             return "metal::fwidth";
550         case TOperator::EOpTranspose:
551             return "metal::transpose";
552         case TOperator::EOpDeterminant:
553             return "metal::determinant";
554 
555         case TOperator::EOpInverse:
556             return "ANGLE_inverse";
557 
558         case TOperator::EOpInterpolateAtCentroid:
559             return "ANGLE_interpolateAtCentroid";
560         case TOperator::EOpInterpolateAtSample:
561             return "ANGLE_interpolateAtSample";
562         case TOperator::EOpInterpolateAtOffset:
563             return "ANGLE_interpolateAtOffset";
564         case TOperator::EOpInterpolateAtCenter:
565             return "ANGLE_interpolateAtCenter";
566 
567         case TOperator::EOpFloatBitsToInt:
568         case TOperator::EOpFloatBitsToUint:
569         case TOperator::EOpIntBitsToFloat:
570         case TOperator::EOpUintBitsToFloat:
571         {
572 #define RETURN_AS_TYPE_SCALAR()             \
573     do                                      \
574         switch (resultType.getBasicType())  \
575         {                                   \
576             case TBasicType::EbtInt:        \
577                 return "as_type<int>";      \
578             case TBasicType::EbtUInt:       \
579                 return "as_type<uint32_t>"; \
580             case TBasicType::EbtFloat:      \
581                 return "as_type<float>";    \
582             default:                        \
583                 UNIMPLEMENTED();            \
584                 return "TOperator_TODO";    \
585         }                                   \
586     while (false)
587 
588 #define RETURN_AS_TYPE(post)                     \
589     do                                           \
590         switch (resultType.getBasicType())       \
591         {                                        \
592             case TBasicType::EbtInt:             \
593                 return "as_type<int" post ">";   \
594             case TBasicType::EbtUInt:            \
595                 return "as_type<uint" post ">";  \
596             case TBasicType::EbtFloat:           \
597                 return "as_type<float" post ">"; \
598             default:                             \
599                 UNIMPLEMENTED();                 \
600                 return "TOperator_TODO";         \
601         }                                        \
602     while (false)
603 
604             if (resultType.isScalar())
605             {
606                 RETURN_AS_TYPE_SCALAR();
607             }
608             else if (resultType.isVector())
609             {
610                 switch (resultType.getNominalSize())
611                 {
612                     case 2:
613                         RETURN_AS_TYPE("2");
614                     case 3:
615                         RETURN_AS_TYPE("3");
616                     case 4:
617                         RETURN_AS_TYPE("4");
618                     default:
619                         UNREACHABLE();
620                         return nullptr;
621                 }
622             }
623             else
624             {
625                 UNIMPLEMENTED();
626                 return "TOperator_TODO";
627             }
628 
629 #undef RETURN_AS_TYPE
630 #undef RETURN_AS_TYPE_SCALAR
631         }
632 
633         case TOperator::EOpPackUnorm2x16:
634             return "metal::pack_float_to_unorm2x16";
635         case TOperator::EOpPackSnorm2x16:
636             return "metal::pack_float_to_snorm2x16";
637 
638         case TOperator::EOpPackUnorm4x8:
639             return "metal::pack_float_to_unorm4x8";
640         case TOperator::EOpPackSnorm4x8:
641             return "metal::pack_float_to_snorm4x8";
642 
643         case TOperator::EOpUnpackUnorm2x16:
644             return "metal::unpack_unorm2x16_to_float";
645         case TOperator::EOpUnpackSnorm2x16:
646             return "metal::unpack_snorm2x16_to_float";
647 
648         case TOperator::EOpUnpackUnorm4x8:
649             return "metal::unpack_unorm4x8_to_float";
650         case TOperator::EOpUnpackSnorm4x8:
651             return "metal::unpack_snorm4x8_to_float";
652 
653         case TOperator::EOpPackHalf2x16:
654             return "ANGLE_pack_half_2x16";
655         case TOperator::EOpUnpackHalf2x16:
656             return "ANGLE_unpack_half_2x16";
657 
658         case TOperator::EOpNumSamples:
659             return "metal::get_num_samples";
660         case TOperator::EOpSamplePosition:
661             return "metal::get_sample_position";
662 
663         case TOperator::EOpBitfieldExtract:
664         case TOperator::EOpBitfieldInsert:
665         case TOperator::EOpBitfieldReverse:
666         case TOperator::EOpBitCount:
667         case TOperator::EOpFindLSB:
668         case TOperator::EOpFindMSB:
669         case TOperator::EOpUaddCarry:
670         case TOperator::EOpUsubBorrow:
671         case TOperator::EOpUmulExtended:
672         case TOperator::EOpImulExtended:
673         case TOperator::EOpBarrier:
674         case TOperator::EOpMemoryBarrier:
675         case TOperator::EOpMemoryBarrierAtomicCounter:
676         case TOperator::EOpMemoryBarrierBuffer:
677         case TOperator::EOpMemoryBarrierShared:
678         case TOperator::EOpGroupMemoryBarrier:
679         case TOperator::EOpAtomicAdd:
680         case TOperator::EOpAtomicMin:
681         case TOperator::EOpAtomicMax:
682         case TOperator::EOpAtomicAnd:
683         case TOperator::EOpAtomicOr:
684         case TOperator::EOpAtomicXor:
685         case TOperator::EOpAtomicExchange:
686         case TOperator::EOpAtomicCompSwap:
687         case TOperator::EOpEmitVertex:
688         case TOperator::EOpEndPrimitive:
689         case TOperator::EOpFtransform:
690         case TOperator::EOpPackDouble2x32:
691         case TOperator::EOpUnpackDouble2x32:
692         case TOperator::EOpArrayLength:
693             UNIMPLEMENTED();
694             return "TOperator_TODO";
695 
696         case TOperator::EOpNull:
697         case TOperator::EOpConstruct:
698         case TOperator::EOpCallFunctionInAST:
699         case TOperator::EOpCallInternalRawFunction:
700         case TOperator::EOpIndexDirect:
701         case TOperator::EOpIndexIndirect:
702         case TOperator::EOpIndexDirectStruct:
703         case TOperator::EOpIndexDirectInterfaceBlock:
704             UNREACHABLE();
705             return nullptr;
706         default:
707             // Any other built-in function.
708             return nullptr;
709     }
710 }
711 
IsSymbolicOperator(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1)712 static bool IsSymbolicOperator(TOperator op,
713                                const TType &resultType,
714                                const TType *argType0,
715                                const TType *argType1)
716 {
717     const char *operatorString = GetOperatorString(op, resultType, argType0, argType1, nullptr);
718     if (operatorString == nullptr)
719     {
720         return false;
721     }
722     return !std::isalnum(operatorString[0]);
723 }
724 
AsSpecificBinaryNode(TIntermNode & node,TOperator op)725 static TIntermBinary *AsSpecificBinaryNode(TIntermNode &node, TOperator op)
726 {
727     TIntermBinary *binaryNode = node.getAsBinaryNode();
728     if (binaryNode)
729     {
730         return binaryNode->getOp() == op ? binaryNode : nullptr;
731     }
732     return nullptr;
733 }
734 
Parenthesize(TIntermNode & node)735 static bool Parenthesize(TIntermNode &node)
736 {
737     if (node.getAsSymbolNode())
738     {
739         return false;
740     }
741     if (node.getAsConstantUnion())
742     {
743         return false;
744     }
745     if (node.getAsAggregate())
746     {
747         return false;
748     }
749     if (node.getAsSwizzleNode())
750     {
751         return false;
752     }
753 
754     if (TIntermUnary *unaryNode = node.getAsUnaryNode())
755     {
756         // TODO: Use a precedence and associativity rules instead of this ad-hoc impl.
757         const TType &resultType = unaryNode->getType();
758         const TType &argType    = unaryNode->getOperand()->getType();
759         return IsSymbolicOperator(unaryNode->getOp(), resultType, &argType, nullptr);
760     }
761 
762     if (TIntermBinary *binaryNode = node.getAsBinaryNode())
763     {
764         // TODO: Use a precedence and associativity rules instead of this ad-hoc impl.
765         const TOperator op = binaryNode->getOp();
766         switch (op)
767         {
768             case TOperator::EOpIndexDirectStruct:
769             case TOperator::EOpIndexDirectInterfaceBlock:
770             case TOperator::EOpIndexDirect:
771             case TOperator::EOpIndexIndirect:
772                 return Parenthesize(*binaryNode->getLeft());
773 
774             case TOperator::EOpAssign:
775             case TOperator::EOpInitialize:
776                 return AsSpecificBinaryNode(*binaryNode->getRight(), TOperator::EOpComma);
777 
778             default:
779             {
780                 const TType &resultType = binaryNode->getType();
781                 const TType &leftType   = binaryNode->getLeft()->getType();
782                 const TType &rightType  = binaryNode->getRight()->getType();
783                 return IsSymbolicOperator(binaryNode->getOp(), resultType, &leftType, &rightType);
784             }
785         }
786     }
787 
788     return true;
789 }
790 
groupedTraverse(TIntermNode & node)791 void GenMetalTraverser::groupedTraverse(TIntermNode &node)
792 {
793     const bool emitParens = Parenthesize(node);
794 
795     if (emitParens)
796     {
797         mOut << "(";
798     }
799 
800     node.traverse(this);
801 
802     if (emitParens)
803     {
804         mOut << ")";
805     }
806 }
807 
emitPostQualifier(const EmitVariableDeclarationConfig & evdConfig,const VarDecl & decl,const TQualifier qualifier)808 void GenMetalTraverser::emitPostQualifier(const EmitVariableDeclarationConfig &evdConfig,
809                                           const VarDecl &decl,
810                                           const TQualifier qualifier)
811 {
812     bool isInvariant = false;
813     switch (qualifier)
814     {
815         case TQualifier::EvqPosition:
816             isInvariant = decl.type().isInvariant();
817             [[fallthrough]];
818         case TQualifier::EvqFragCoord:
819             mOut << " [[position]]";
820             break;
821 
822         case TQualifier::EvqClipDistance:
823             mOut << " [[clip_distance]] [" << decl.type().getOutermostArraySize() << "]";
824             break;
825 
826         case TQualifier::EvqPointSize:
827             mOut << " [[point_size]]";
828             break;
829 
830         case TQualifier::EvqVertexID:
831             if (evdConfig.isMainParameter)
832             {
833                 mOut << " [[vertex_id]]";
834             }
835             break;
836 
837         case TQualifier::EvqPointCoord:
838             if (evdConfig.isMainParameter)
839             {
840                 mOut << " [[point_coord]]";
841             }
842             break;
843 
844         case TQualifier::EvqFrontFacing:
845             if (evdConfig.isMainParameter)
846             {
847                 mOut << " [[front_facing]]";
848             }
849             break;
850 
851         case TQualifier::EvqSampleID:
852             if (evdConfig.isMainParameter)
853             {
854                 mOut << " [[sample_id]]";
855             }
856             break;
857 
858         case TQualifier::EvqSampleMaskIn:
859             if (evdConfig.isMainParameter)
860             {
861                 mOut << " [[sample_mask]]";
862             }
863             break;
864 
865         default:
866             break;
867     }
868 
869     if (isInvariant)
870     {
871         mOut << " [[invariant]]";
872 
873         TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
874         reflection->hasInvariance             = true;
875     }
876 }
877 
EmitName(Sink & out,const Name & name)878 static void EmitName(Sink &out, const Name &name)
879 {
880 #if defined(ANGLE_ENABLE_ASSERTS)
881     DebugSink::EscapedSink escapedOut(out.escape());
882 #else
883     TInfoSinkBase &escapedOut = out;
884 #endif
885     name.emit(escapedOut);
886 }
887 
emitNameOf(const TField & object)888 void GenMetalTraverser::emitNameOf(const TField &object)
889 {
890     EmitName(mOut, Name(object));
891 }
892 
emitNameOf(const TSymbol & object)893 void GenMetalTraverser::emitNameOf(const TSymbol &object)
894 {
895     auto it = mRenamedSymbols.find(&object);
896     if (it == mRenamedSymbols.end())
897     {
898         EmitName(mOut, Name(object));
899     }
900     else
901     {
902         EmitName(mOut, it->second);
903     }
904 }
905 
emitNameOf(const VarDecl & object)906 void GenMetalTraverser::emitNameOf(const VarDecl &object)
907 {
908     if (object.isField())
909     {
910         emitNameOf(object.field());
911     }
912     else
913     {
914         emitNameOf(object.variable());
915     }
916 }
917 
emitBareTypeName(const TType & type,const EmitTypeConfig & etConfig)918 void GenMetalTraverser::emitBareTypeName(const TType &type, const EmitTypeConfig &etConfig)
919 {
920     const TBasicType basicType = type.getBasicType();
921 
922     switch (basicType)
923     {
924         case TBasicType::EbtVoid:
925         case TBasicType::EbtBool:
926         case TBasicType::EbtFloat:
927         case TBasicType::EbtInt:
928         {
929             mOut << type.getBasicString();
930             break;
931         }
932         case TBasicType::EbtUInt:
933         {
934             if (type.isScalar())
935             {
936                 mOut << "uint32_t";
937             }
938             else
939             {
940                 mOut << type.getBasicString();
941             }
942         }
943         break;
944 
945         case TBasicType::EbtStruct:
946         {
947             const TStructure &structure = *type.getStruct();
948             emitNameOf(structure);
949         }
950         break;
951 
952         case TBasicType::EbtInterfaceBlock:
953         {
954             const TInterfaceBlock &interfaceBlock = *type.getInterfaceBlock();
955             emitNameOf(interfaceBlock);
956         }
957         break;
958 
959         default:
960         {
961             if (IsSampler(basicType))
962             {
963                 if (etConfig.evdConfig && etConfig.evdConfig->isMainParameter)
964                 {
965                     EmitName(mOut, GetTextureTypeName(basicType));
966                 }
967                 else
968                 {
969                     const TStructure &env = mSymbolEnv.getTextureEnv(basicType);
970                     emitNameOf(env);
971                 }
972             }
973             else if (IsImage(basicType))
974             {
975                 mOut << "metal::texture2d<";
976                 switch (type.getBasicType())
977                 {
978                     case EbtImage2D:
979                         mOut << "float";
980                         break;
981                     case EbtIImage2D:
982                         mOut << "int";
983                         break;
984                     case EbtUImage2D:
985                         mOut << "uint";
986                         break;
987                     default:
988                         UNIMPLEMENTED();
989                         break;
990                 }
991                 if (type.getMemoryQualifier().readonly || type.getMemoryQualifier().writeonly)
992                 {
993                     UNIMPLEMENTED();
994                 }
995                 mOut << ", metal::access::read_write>";
996             }
997             else
998             {
999                 UNIMPLEMENTED();
1000             }
1001         }
1002     }
1003 }
1004 
emitType(const TType & type,const EmitTypeConfig & etConfig)1005 void GenMetalTraverser::emitType(const TType &type, const EmitTypeConfig &etConfig)
1006 {
1007     const bool isUBO = etConfig.evdConfig ? etConfig.evdConfig->isUBO : false;
1008     if (etConfig.evdConfig)
1009     {
1010         const auto &evdConfig = *etConfig.evdConfig;
1011         if (isUBO)
1012         {
1013             if (type.isArray())
1014             {
1015                 mOut << "ANGLE_tensor<";
1016             }
1017         }
1018         if (evdConfig.isPointer)
1019         {
1020             mOut << toString(*evdConfig.isPointer);
1021             mOut << " ";
1022         }
1023         else if (evdConfig.isReference)
1024         {
1025             mOut << toString(*evdConfig.isReference);
1026             mOut << " ";
1027         }
1028     }
1029 
1030     if (!isUBO)
1031     {
1032         if (type.isArray())
1033         {
1034             mOut << "ANGLE_tensor<";
1035         }
1036     }
1037 
1038     if (type.isInterpolant())
1039     {
1040         mOut << "metal::interpolant<";
1041     }
1042 
1043     if (type.isVector() || type.isMatrix())
1044     {
1045         mOut << "metal::";
1046     }
1047 
1048     if (etConfig.evdConfig && etConfig.evdConfig->isPacked)
1049     {
1050         mOut << "packed_";
1051     }
1052 
1053     emitBareTypeName(type, etConfig);
1054 
1055     if (type.isVector())
1056     {
1057         mOut << static_cast<uint32_t>(type.getNominalSize());
1058     }
1059     else if (type.isMatrix())
1060     {
1061         mOut << static_cast<uint32_t>(type.getCols()) << "x"
1062              << static_cast<uint32_t>(type.getRows());
1063     }
1064 
1065     if (type.isInterpolant())
1066     {
1067         mOut << ", metal::interpolation::";
1068         switch (type.getQualifier())
1069         {
1070             case EvqNoPerspectiveIn:
1071             case EvqNoPerspectiveCentroidIn:
1072             case EvqNoPerspectiveSampleIn:
1073                 mOut << "no_";
1074                 break;
1075             default:
1076                 break;
1077         }
1078         mOut << "perspective>";
1079     }
1080 
1081     if (!isUBO)
1082     {
1083         if (type.isArray())
1084         {
1085             for (auto size : type.getArraySizes())
1086             {
1087                 mOut << ", " << size;
1088             }
1089             mOut << ">";
1090         }
1091     }
1092 
1093     if (etConfig.evdConfig)
1094     {
1095         const auto &evdConfig = *etConfig.evdConfig;
1096         if (evdConfig.isPointer)
1097         {
1098             mOut << " *";
1099         }
1100         else if (evdConfig.isReference)
1101         {
1102             mOut << " &";
1103         }
1104         if (isUBO)
1105         {
1106             if (type.isArray())
1107             {
1108                 for (auto size : type.getArraySizes())
1109                 {
1110                     mOut << ", " << size;
1111                 }
1112                 mOut << ">";
1113             }
1114         }
1115     }
1116 }
1117 
emitFieldDeclaration(const TField & field,const TStructure & parent,FieldAnnotationIndices & annotationIndices)1118 void GenMetalTraverser::emitFieldDeclaration(const TField &field,
1119                                              const TStructure &parent,
1120                                              FieldAnnotationIndices &annotationIndices)
1121 {
1122     const TType &type      = *field.type();
1123     const TBasicType basic = type.getBasicType();
1124 
1125     EmitVariableDeclarationConfig evdConfig;
1126     evdConfig.emitPostQualifier      = true;
1127     evdConfig.disableStructSpecifier = true;
1128     evdConfig.isPacked               = mSymbolEnv.isPacked(field);
1129     evdConfig.isUBO                  = mSymbolEnv.isUBO(field);
1130     evdConfig.isPointer              = mSymbolEnv.isPointer(field);
1131     evdConfig.isReference            = mSymbolEnv.isReference(field);
1132     emitVariableDeclaration(VarDecl(field), evdConfig);
1133 
1134     const TQualifier qual = type.getQualifier();
1135     switch (qual)
1136     {
1137         case TQualifier::EvqFlatIn:
1138             if (mPipelineStructs.fragmentIn.external == &parent)
1139             {
1140                 mOut << " [[flat]]";
1141                 TranslatorMetalReflection *reflection =
1142                     mtl::getTranslatorMetalReflection(&mCompiler);
1143                 reflection->hasFlatInput = true;
1144             }
1145             break;
1146 
1147         case TQualifier::EvqNoPerspectiveIn:
1148             if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1149             {
1150                 mOut << " [[center_no_perspective]]";
1151             }
1152             break;
1153 
1154         case TQualifier::EvqCentroidIn:
1155             if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1156             {
1157                 mOut << " [[centroid_perspective]]";
1158             }
1159             break;
1160 
1161         case TQualifier::EvqSampleIn:
1162             if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1163             {
1164                 mOut << " [[sample_perspective]]";
1165             }
1166             break;
1167 
1168         case TQualifier::EvqNoPerspectiveCentroidIn:
1169             if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1170             {
1171                 mOut << " [[centroid_no_perspective]]";
1172             }
1173             break;
1174 
1175         case TQualifier::EvqNoPerspectiveSampleIn:
1176             if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1177             {
1178                 mOut << " [[sample_no_perspective]]";
1179             }
1180             break;
1181 
1182         case TQualifier::EvqFragColor:
1183             mOut << " [[color(0)]]";
1184             break;
1185 
1186         case TQualifier::EvqSecondaryFragColorEXT:
1187         case TQualifier::EvqSecondaryFragDataEXT:
1188             mOut << " [[color(0), index(1)]]";
1189             break;
1190 
1191         case TQualifier::EvqFragmentOut:
1192         case TQualifier::EvqFragmentInOut:
1193         case TQualifier::EvqFragData:
1194             if (mPipelineStructs.fragmentOut.external == &parent)
1195             {
1196                 if ((type.isVector() &&
1197                      (basic == TBasicType::EbtInt || basic == TBasicType::EbtUInt ||
1198                       basic == TBasicType::EbtFloat)) ||
1199                     qual == EvqFragData)
1200                 {
1201                     // The OpenGL ES 3.0 spec says locations must be specified
1202                     // unless there is only a single output, in which case the
1203                     // location is 0. So, when we get to this point the shader
1204                     // will have been rejected if locations are not specified
1205                     // and there is more than one output.
1206                     const TLayoutQualifier &layoutQualifier = type.getLayoutQualifier();
1207                     if (layoutQualifier.locationsSpecified)
1208                     {
1209                         mOut << " [[color(" << layoutQualifier.location << ")";
1210                         ASSERT(layoutQualifier.index >= -1 && layoutQualifier.index <= 1);
1211                         if (layoutQualifier.index == 1)
1212                         {
1213                             mOut << ", index(1)";
1214                         }
1215                     }
1216                     else if (qual == EvqFragData)
1217                     {
1218                         mOut << " [[color(" << annotationIndices.color++ << ")";
1219                     }
1220                     else
1221                     {
1222                         // Either the only output or EXT_blend_func_extended is used;
1223                         // actual assignment will happen in UpdateFragmentShaderOutputs.
1224                         mOut << " [[" << sh::kUnassignedFragmentOutputString;
1225                     }
1226                     if (mRasterOrderGroupsSupported && qual == TQualifier::EvqFragmentInOut)
1227                     {
1228                         // Put fragment inouts in their own raster order group for better
1229                         // parallelism.
1230                         // NOTE: this is not required for the reads to be ordered and coherent.
1231                         // TODO(anglebug.com/7279): Consider making raster order groups a PLS layout
1232                         // qualifier?
1233                         mOut << ", raster_order_group(0)";
1234                     }
1235                     mOut << "]]";
1236                 }
1237             }
1238             break;
1239 
1240         case TQualifier::EvqFragDepth:
1241             mOut << " [[depth(";
1242             switch (type.getLayoutQualifier().depth)
1243             {
1244                 case EdGreater:
1245                     mOut << "greater";
1246                     break;
1247                 case EdLess:
1248                     mOut << "less";
1249                     break;
1250                 default:
1251                     mOut << "any";
1252                     break;
1253             }
1254             mOut << "), function_constant(" << sh::mtl::kDepthWriteEnabledConstName << ")]]";
1255             break;
1256 
1257         case TQualifier::EvqSampleMask:
1258             if (field.symbolType() == SymbolType::AngleInternal)
1259             {
1260                 mOut << " [[sample_mask, function_constant("
1261                      << sh::mtl::kMultisampledRenderingConstName << ")]]";
1262             }
1263             break;
1264 
1265         default:
1266             break;
1267     }
1268 
1269     if (IsImage(type.getBasicType()))
1270     {
1271         if (type.getArraySizeProduct() != 1)
1272         {
1273             UNIMPLEMENTED();
1274         }
1275         mOut << " [[";
1276         if (type.getLayoutQualifier().rasterOrdered)
1277         {
1278             mOut << "raster_order_group(0), ";
1279         }
1280         mOut << "texture(" << mMainTextureIndex << ")]]";
1281         TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
1282         reflection->addRWTextureBinding(type.getLayoutQualifier().binding,
1283                                         static_cast<int>(mMainTextureIndex));
1284         ++mMainTextureIndex;
1285     }
1286 }
1287 
BuildExternalAttributeIndexMap(const TCompiler & compiler,const PipelineScoped<TStructure> & structure)1288 static std::map<Name, size_t> BuildExternalAttributeIndexMap(
1289     const TCompiler &compiler,
1290     const PipelineScoped<TStructure> &structure)
1291 {
1292     ASSERT(structure.isTotallyFull());
1293 
1294     const auto &shaderVars     = compiler.getAttributes();
1295     const size_t shaderVarSize = shaderVars.size();
1296     size_t shaderVarIndex      = 0;
1297 
1298     const auto &externalFields = structure.external->fields();
1299     const size_t externalSize  = externalFields.size();
1300     size_t externalIndex       = 0;
1301 
1302     const auto &internalFields = structure.internal->fields();
1303     const size_t internalSize  = internalFields.size();
1304     size_t internalIndex       = 0;
1305 
1306     // Internal fields are never split. External fields are sometimes split.
1307     ASSERT(externalSize >= internalSize);
1308 
1309     // Structures do not contain any inactive fields.
1310     ASSERT(shaderVarSize >= internalSize);
1311 
1312     std::map<Name, size_t> externalNameToAttributeIndex;
1313     size_t attributeIndex = 0;
1314 
1315     while (internalIndex < internalSize)
1316     {
1317         const TField &internalField = *internalFields[internalIndex];
1318         const Name internalName     = Name(internalField);
1319         const TType &internalType   = *internalField.type();
1320         while (internalName.rawName() != shaderVars[shaderVarIndex].name &&
1321                internalName.rawName() != shaderVars[shaderVarIndex].mappedName)
1322         {
1323             // This case represents an inactive field.
1324 
1325             ++shaderVarIndex;
1326             ASSERT(shaderVarIndex < shaderVarSize);
1327 
1328             ++attributeIndex;  // TODO: Might need to increment more if shader var type is a matrix.
1329         }
1330 
1331         const size_t cols =
1332             (internalType.isMatrix() && !externalFields[externalIndex]->type()->isMatrix())
1333                 ? internalType.getCols()
1334                 : 1;
1335 
1336         for (size_t c = 0; c < cols; ++c)
1337         {
1338             const TField &externalField = *externalFields[externalIndex];
1339             const Name externalName     = Name(externalField);
1340 
1341             externalNameToAttributeIndex[externalName] = attributeIndex;
1342 
1343             ++externalIndex;
1344             ++attributeIndex;
1345         }
1346 
1347         ++shaderVarIndex;
1348         ++internalIndex;
1349     }
1350 
1351     ASSERT(shaderVarIndex <= shaderVarSize);
1352     ASSERT(externalIndex <= externalSize);  // less than if padding was introduced
1353     ASSERT(internalIndex == internalSize);
1354 
1355     return externalNameToAttributeIndex;
1356 }
1357 
emitAttributeDeclaration(const TField & field,FieldAnnotationIndices & annotationIndices)1358 void GenMetalTraverser::emitAttributeDeclaration(const TField &field,
1359                                                  FieldAnnotationIndices &annotationIndices)
1360 {
1361     EmitVariableDeclarationConfig evdConfig;
1362     evdConfig.disableStructSpecifier = true;
1363     emitVariableDeclaration(VarDecl(field), evdConfig);
1364     mOut << sh::kUnassignedAttributeString;
1365 }
1366 
emitUniformBufferDeclaration(const TField & field,FieldAnnotationIndices & annotationIndices)1367 void GenMetalTraverser::emitUniformBufferDeclaration(const TField &field,
1368                                                      FieldAnnotationIndices &annotationIndices)
1369 {
1370     EmitVariableDeclarationConfig evdConfig;
1371     evdConfig.disableStructSpecifier = true;
1372     evdConfig.isUBO                  = mSymbolEnv.isUBO(field);
1373     evdConfig.isPointer              = mSymbolEnv.isPointer(field);
1374     emitVariableDeclaration(VarDecl(field), evdConfig);
1375     mOut << "[[id(" << annotationIndices.attribute << ")]]";
1376 
1377     const TType &type   = *field.type();
1378     const int arraySize = type.isArray() ? type.getArraySizeProduct() : 1;
1379 
1380     TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
1381     ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1382     const TStructure *structure    = type.getStruct();
1383     const std::string originalName = reflection->getOriginalName(structure->uniqueId().get());
1384     reflection->addUniformBufferBinding(
1385         originalName,
1386         {.bindIndex = annotationIndices.attribute, .arraySize = static_cast<size_t>(arraySize)});
1387 
1388     annotationIndices.attribute += arraySize;
1389 }
1390 
emitStructDeclaration(const TType & type)1391 void GenMetalTraverser::emitStructDeclaration(const TType &type)
1392 {
1393     ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1394     ASSERT(type.isStructSpecifier());
1395 
1396     mOut << "struct ";
1397     emitBareTypeName(type, {});
1398 
1399     mOut << "\n";
1400     emitOpenBrace();
1401 
1402     const TStructure &structure = *type.getStruct();
1403     std::map<Name, size_t> fieldToAttributeIndex;
1404     const bool hasAttributeIndices      = mPipelineStructs.vertexIn.external == &structure;
1405     const bool hasUniformBufferIndicies = mPipelineStructs.uniformBuffers.external == &structure;
1406     const bool reclaimUnusedAttributeIndices = mCompiler.getShaderVersion() < 300;
1407 
1408     if (hasAttributeIndices)
1409     {
1410         fieldToAttributeIndex =
1411             BuildExternalAttributeIndexMap(mCompiler, mPipelineStructs.vertexIn);
1412     }
1413 
1414     FieldAnnotationIndices annotationIndices;
1415 
1416     for (const TField *field : structure.fields())
1417     {
1418         emitIndentation();
1419         if (hasAttributeIndices)
1420         {
1421             const auto it = fieldToAttributeIndex.find(Name(*field));
1422             if (it == fieldToAttributeIndex.end())
1423             {
1424                 ASSERT(field->symbolType() == SymbolType::AngleInternal);
1425                 ASSERT(field->name().beginsWith("_"));
1426                 ASSERT(angle::EndsWith(field->name().data(), "_pad"));
1427                 emitFieldDeclaration(*field, structure, annotationIndices);
1428             }
1429             else
1430             {
1431                 ASSERT(field->symbolType() != SymbolType::AngleInternal ||
1432                        !field->name().beginsWith("_") ||
1433                        !angle::EndsWith(field->name().data(), "_pad"));
1434                 if (!reclaimUnusedAttributeIndices)
1435                 {
1436                     annotationIndices.attribute = it->second;
1437                 }
1438                 emitAttributeDeclaration(*field, annotationIndices);
1439             }
1440         }
1441         else if (hasUniformBufferIndicies)
1442         {
1443             emitUniformBufferDeclaration(*field, annotationIndices);
1444         }
1445         else
1446         {
1447             emitFieldDeclaration(*field, structure, annotationIndices);
1448         }
1449         mOut << ";\n";
1450     }
1451 
1452     emitCloseBrace();
1453 }
1454 
emitOrdinaryVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1455 void GenMetalTraverser::emitOrdinaryVariableDeclaration(
1456     const VarDecl &decl,
1457     const EmitVariableDeclarationConfig &evdConfig)
1458 {
1459     EmitTypeConfig etConfig;
1460     etConfig.evdConfig = &evdConfig;
1461 
1462     const TType &type = decl.type();
1463     if (type.getQualifier() == TQualifier::EvqClipDistance)
1464     {
1465         // Clip distance output uses float[n] type instead of metal::array.
1466         // The element count is emitted after the post qualifier.
1467         ASSERT(type.getBasicType() == TBasicType::EbtFloat);
1468         mOut << "float";
1469     }
1470     else if (type.getQualifier() == TQualifier::EvqSampleID && evdConfig.isMainParameter)
1471     {
1472         // Metal's [[sample_id]] must be unsigned
1473         ASSERT(type.getBasicType() == TBasicType::EbtInt);
1474         mOut << "uint32_t";
1475     }
1476     else
1477     {
1478         emitType(type, etConfig);
1479     }
1480     if (decl.symbolType() != SymbolType::Empty)
1481     {
1482         mOut << " ";
1483         emitNameOf(decl);
1484     }
1485 }
1486 
emitVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1487 void GenMetalTraverser::emitVariableDeclaration(const VarDecl &decl,
1488                                                 const EmitVariableDeclarationConfig &evdConfig)
1489 {
1490     const SymbolType symbolType = decl.symbolType();
1491     const TType &type           = decl.type();
1492     const TBasicType basicType  = type.getBasicType();
1493 
1494     switch (basicType)
1495     {
1496         case TBasicType::EbtStruct:
1497         {
1498             if (type.isStructSpecifier() && !evdConfig.disableStructSpecifier)
1499             {
1500                 // It's invalid to declare a struct inside a function argument. When emitting a
1501                 // function parameter, the callsite should set evdConfig.disableStructSpecifier.
1502                 ASSERT(!evdConfig.isParameter);
1503                 emitStructDeclaration(type);
1504                 if (symbolType != SymbolType::Empty)
1505                 {
1506                     mOut << " ";
1507                     emitNameOf(decl);
1508                 }
1509             }
1510             else
1511             {
1512                 emitOrdinaryVariableDeclaration(decl, evdConfig);
1513             }
1514         }
1515         break;
1516 
1517         default:
1518         {
1519             ASSERT(symbolType != SymbolType::Empty || evdConfig.isParameter);
1520             emitOrdinaryVariableDeclaration(decl, evdConfig);
1521         }
1522     }
1523 
1524     if (evdConfig.emitPostQualifier)
1525     {
1526         emitPostQualifier(evdConfig, decl, type.getQualifier());
1527     }
1528 }
1529 
visitSymbol(TIntermSymbol * symbolNode)1530 void GenMetalTraverser::visitSymbol(TIntermSymbol *symbolNode)
1531 {
1532     const TVariable &var = symbolNode->variable();
1533     const TType &type    = var.getType();
1534     ASSERT(var.symbolType() != SymbolType::Empty);
1535 
1536     if (type.getBasicType() == TBasicType::EbtVoid)
1537     {
1538         mOut << "/*";
1539         emitNameOf(var);
1540         mOut << "*/";
1541     }
1542     else
1543     {
1544         emitNameOf(var);
1545     }
1546 }
1547 
emitSingleConstant(const TConstantUnion * const constUnion)1548 void GenMetalTraverser::emitSingleConstant(const TConstantUnion *const constUnion)
1549 {
1550     switch (constUnion->getType())
1551     {
1552         case TBasicType::EbtBool:
1553         {
1554             mOut << (constUnion->getBConst() ? "true" : "false");
1555         }
1556         break;
1557 
1558         case TBasicType::EbtFloat:
1559         {
1560             float value = constUnion->getFConst();
1561             if (std::isnan(value))
1562             {
1563                 mOut << "NAN";
1564             }
1565             else if (std::isinf(value))
1566             {
1567                 if (value < 0)
1568                 {
1569                     mOut << "-";
1570                 }
1571                 mOut << "INFINITY";
1572             }
1573             else
1574             {
1575                 mOut << value << "f";
1576             }
1577         }
1578         break;
1579 
1580         case TBasicType::EbtInt:
1581         {
1582             mOut << constUnion->getIConst();
1583         }
1584         break;
1585 
1586         case TBasicType::EbtUInt:
1587         {
1588             mOut << constUnion->getUConst() << "u";
1589         }
1590         break;
1591 
1592         default:
1593         {
1594             UNIMPLEMENTED();
1595         }
1596     }
1597 }
1598 
// Emits `size` consecutive scalar constants starting at `constUnion`, separated
// by ", ". Returns a pointer to the first constant after the emitted range.
const TConstantUnion *GenMetalTraverser::emitConstantUnionArray(
    const TConstantUnion *const constUnion,
    const size_t size)
{
    for (size_t offset = 0; offset < size; ++offset)
    {
        if (offset > 0)
        {
            mOut << ", ";
        }
        emitSingleConstant(constUnion + offset);
    }
    return constUnion + size;
}
1615 
// Emits the constant value of `type` whose data begins at `constUnionBegin`,
// and returns a pointer just past the consumed constants.
// - Structs are emitted as "Type{field0, field1, ...}", recursing into each field.
// - Multi-component values are emitted as "Type(c0, c1, ...)".
// - Single scalars are emitted bare.
const TConstantUnion *GenMetalTraverser::emitConstantUnion(const TType &type,
                                                           const TConstantUnion *constUnionBegin)
{
    const EmitTypeConfig plainTypeConfig{nullptr};
    const TStructure *structure = type.getStruct();

    if (structure == nullptr)
    {
        const size_t componentCount = type.getObjectSize();
        if (componentCount <= 1)
        {
            return emitConstantUnionArray(constUnionBegin, componentCount);
        }
        emitType(type, plainTypeConfig);
        mOut << "(";
        const TConstantUnion *next = emitConstantUnionArray(constUnionBegin, componentCount);
        mOut << ")";
        return next;
    }

    emitType(type, plainTypeConfig);
    mOut << "{";
    const TConstantUnion *cursor = constUnionBegin;
    const TFieldList &fields     = structure->fields();
    for (size_t fieldIndex = 0; fieldIndex < fields.size(); ++fieldIndex)
    {
        if (fieldIndex > 0)
        {
            mOut << ", ";
        }
        cursor = emitConstantUnion(*fields[fieldIndex]->type(), cursor);
    }
    mOut << "}";
    return cursor;
}
1656 
visitConstantUnion(TIntermConstantUnion * constValueNode)1657 void GenMetalTraverser::visitConstantUnion(TIntermConstantUnion *constValueNode)
1658 {
1659     emitConstantUnion(constValueNode->getType(), constValueNode->getConstantValue());
1660 }
1661 
visitSwizzle(Visit,TIntermSwizzle * swizzleNode)1662 bool GenMetalTraverser::visitSwizzle(Visit, TIntermSwizzle *swizzleNode)
1663 {
1664     groupedTraverse(*swizzleNode->getOperand());
1665     mOut << ".";
1666 
1667     {
1668 #if defined(ANGLE_ENABLE_ASSERTS)
1669         DebugSink::EscapedSink escapedOut(mOut.escape());
1670         TInfoSinkBase &out = escapedOut.get();
1671 #else
1672         TInfoSinkBase &out        = mOut;
1673 #endif
1674         swizzleNode->writeOffsetsAsXYZW(&out);
1675     }
1676 
1677     return false;
1678 }
1679 
getDirectField(const TFieldListCollection & fieldListCollection,const TConstantUnion & index)1680 const TField &GenMetalTraverser::getDirectField(const TFieldListCollection &fieldListCollection,
1681                                                 const TConstantUnion &index)
1682 {
1683     ASSERT(index.getType() == TBasicType::EbtInt);
1684 
1685     const TFieldList &fieldList = fieldListCollection.fields();
1686     const int indexVal          = index.getIConst();
1687     const TField &field         = *fieldList[indexVal];
1688 
1689     return field;
1690 }
1691 
getDirectField(const TIntermTyped & fieldsNode,TIntermTyped & indexNode)1692 const TField &GenMetalTraverser::getDirectField(const TIntermTyped &fieldsNode,
1693                                                 TIntermTyped &indexNode)
1694 {
1695     const TType &fieldsType = fieldsNode.getType();
1696 
1697     const TFieldListCollection *fieldListCollection = fieldsType.getStruct();
1698     if (fieldListCollection == nullptr)
1699     {
1700         fieldListCollection = fieldsType.getInterfaceBlock();
1701     }
1702     ASSERT(fieldListCollection);
1703 
1704     const TIntermConstantUnion *indexNode_ = indexNode.getAsConstantUnion();
1705     ASSERT(indexNode_);
1706     const TConstantUnion &index = *indexNode_->getConstantValue();
1707 
1708     return getDirectField(*fieldListCollection, index);
1709 }
1710 
visitBinary(Visit,TIntermBinary * binaryNode)1711 bool GenMetalTraverser::visitBinary(Visit, TIntermBinary *binaryNode)
1712 {
1713     const TOperator op      = binaryNode->getOp();
1714     TIntermTyped &leftNode  = *binaryNode->getLeft();
1715     TIntermTyped &rightNode = *binaryNode->getRight();
1716 
1717     switch (op)
1718     {
1719         case TOperator::EOpIndexDirectStruct:
1720         case TOperator::EOpIndexDirectInterfaceBlock:
1721         {
1722             const TField &field = getDirectField(leftNode, rightNode);
1723             if (mSymbolEnv.isPointer(field) && mSymbolEnv.isUBO(field))
1724             {
1725                 emitOpeningPointerParen();
1726             }
1727             groupedTraverse(leftNode);
1728             if (!mSymbolEnv.isPointer(field))
1729             {
1730                 emitClosingPointerParen();
1731             }
1732             mOut << ".";
1733             emitNameOf(field);
1734         }
1735         break;
1736 
1737         case TOperator::EOpIndexDirect:
1738         case TOperator::EOpIndexIndirect:
1739         {
1740             TType leftType = leftNode.getType();
1741             groupedTraverse(leftNode);
1742             mOut << "[";
1743             {
1744                 mOut << "ANGLE_int_clamp(";
1745                 groupedTraverse(rightNode);
1746                 mOut << ", 0, ";
1747                 if (leftType.isUnsizedArray())
1748                 {
1749                     groupedTraverse(leftNode);
1750                     mOut << ".size()";
1751                 }
1752                 else
1753                 {
1754                     uint32_t maxSize;
1755                     if (leftType.isArray())
1756                     {
1757                         maxSize = leftType.getOutermostArraySize() - 1;
1758                     }
1759                     else
1760                     {
1761                         maxSize = leftType.getNominalSize() - 1;
1762                     }
1763                     mOut << maxSize;
1764                 }
1765                 mOut << ")";
1766             }
1767             mOut << "]";
1768         }
1769         break;
1770 
1771         default:
1772         {
1773             const TType &resultType = binaryNode->getType();
1774             const TType &leftType   = leftNode.getType();
1775             const TType &rightType  = rightNode.getType();
1776 
1777             if (IsSymbolicOperator(op, resultType, &leftType, &rightType))
1778             {
1779                 groupedTraverse(leftNode);
1780                 if (op != TOperator::EOpComma)
1781                 {
1782                     mOut << " ";
1783                 }
1784                 else
1785                 {
1786                     emitClosingPointerParen();
1787                 }
1788                 mOut << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << " ";
1789                 groupedTraverse(rightNode);
1790             }
1791             else
1792             {
1793                 emitClosingPointerParen();
1794                 mOut << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << "(";
1795                 leftNode.traverse(this);
1796                 mOut << ", ";
1797                 rightNode.traverse(this);
1798                 mOut << ")";
1799             }
1800         }
1801     }
1802 
1803     return false;
1804 }
1805 
IsPostfix(TOperator op)1806 static bool IsPostfix(TOperator op)
1807 {
1808     switch (op)
1809     {
1810         case TOperator::EOpPostIncrement:
1811         case TOperator::EOpPostDecrement:
1812             return true;
1813 
1814         default:
1815             return false;
1816     }
1817 }
1818 
visitUnary(Visit,TIntermUnary * unaryNode)1819 bool GenMetalTraverser::visitUnary(Visit, TIntermUnary *unaryNode)
1820 {
1821     const TOperator op      = unaryNode->getOp();
1822     const TType &resultType = unaryNode->getType();
1823 
1824     TIntermTyped &arg    = *unaryNode->getOperand();
1825     const TType &argType = arg.getType();
1826 
1827     const char *name = GetOperatorString(op, resultType, &argType, nullptr, nullptr);
1828 
1829     if (IsSymbolicOperator(op, resultType, &argType, nullptr))
1830     {
1831         const bool postfix = IsPostfix(op);
1832         if (!postfix)
1833         {
1834             mOut << name;
1835         }
1836         groupedTraverse(arg);
1837         if (postfix)
1838         {
1839             mOut << name;
1840         }
1841     }
1842     else
1843     {
1844         mOut << name << "(";
1845         arg.traverse(this);
1846         mOut << ")";
1847     }
1848 
1849     return false;
1850 }
1851 
visitTernary(Visit,TIntermTernary * conditionalNode)1852 bool GenMetalTraverser::visitTernary(Visit, TIntermTernary *conditionalNode)
1853 {
1854     groupedTraverse(*conditionalNode->getCondition());
1855     mOut << " ? ";
1856     groupedTraverse(*conditionalNode->getTrueExpression());
1857     mOut << " : ";
1858     groupedTraverse(*conditionalNode->getFalseExpression());
1859 
1860     return false;
1861 }
1862 
visitIfElse(Visit,TIntermIfElse * ifThenElseNode)1863 bool GenMetalTraverser::visitIfElse(Visit, TIntermIfElse *ifThenElseNode)
1864 {
1865     TIntermTyped &condNode = *ifThenElseNode->getCondition();
1866     TIntermBlock *thenNode = ifThenElseNode->getTrueBlock();
1867     TIntermBlock *elseNode = ifThenElseNode->getFalseBlock();
1868 
1869     emitIndentation();
1870     mOut << "if (";
1871     condNode.traverse(this);
1872     mOut << ")";
1873 
1874     if (thenNode)
1875     {
1876         mOut << "\n";
1877         thenNode->traverse(this);
1878     }
1879     else
1880     {
1881         mOut << " {}";
1882     }
1883 
1884     if (elseNode)
1885     {
1886         mOut << "\n";
1887         emitIndentation();
1888         mOut << "else\n";
1889         elseNode->traverse(this);
1890     }
1891     else
1892     {
1893         // Always emit "else" even when empty block to avoid nested if-stmt issues.
1894         mOut << " else {}";
1895     }
1896 
1897     return false;
1898 }
1899 
visitSwitch(Visit,TIntermSwitch * switchNode)1900 bool GenMetalTraverser::visitSwitch(Visit, TIntermSwitch *switchNode)
1901 {
1902     emitIndentation();
1903     mOut << "switch (";
1904     switchNode->getInit()->traverse(this);
1905     mOut << ")\n";
1906 
1907     ASSERT(!mParentIsSwitch);
1908     mParentIsSwitch = true;
1909     switchNode->getStatementList()->traverse(this);
1910     mParentIsSwitch = false;
1911 
1912     return false;
1913 }
1914 
visitCase(Visit,TIntermCase * caseNode)1915 bool GenMetalTraverser::visitCase(Visit, TIntermCase *caseNode)
1916 {
1917     emitIndentation();
1918 
1919     if (caseNode->hasCondition())
1920     {
1921         TIntermTyped *condExpr = caseNode->getCondition();
1922         mOut << "case ";
1923         condExpr->traverse(this);
1924         mOut << ":";
1925     }
1926     else
1927     {
1928         mOut << "default:\n";
1929     }
1930 
1931     return false;
1932 }
1933 
emitFunctionSignature(const TFunction & func)1934 void GenMetalTraverser::emitFunctionSignature(const TFunction &func)
1935 {
1936     const bool isMain = func.isMain();
1937 
1938     emitFunctionReturn(func);
1939 
1940     mOut << " ";
1941     emitNameOf(func);
1942     if (isMain)
1943     {
1944         mOut << "0";
1945     }
1946     mOut << "(";
1947 
1948     bool emitComma          = false;
1949     const size_t paramCount = func.getParamCount();
1950     for (size_t i = 0; i < paramCount; ++i)
1951     {
1952         if (emitComma)
1953         {
1954             mOut << ", ";
1955         }
1956         emitComma = true;
1957 
1958         const TVariable &param = *func.getParam(i);
1959         emitFunctionParameter(func, param);
1960     }
1961 
1962     if (isTraversingVertexMain)
1963     {
1964         mOut << " @@XFB-Bindings@@ ";
1965     }
1966 
1967     mOut << ")";
1968 }
1969 
emitFunctionReturn(const TFunction & func)1970 void GenMetalTraverser::emitFunctionReturn(const TFunction &func)
1971 {
1972     const bool isMain       = func.isMain();
1973     bool isVertexMain       = false;
1974     const TType &returnType = func.getReturnType();
1975     if (isMain)
1976     {
1977         const TStructure *structure = returnType.getStruct();
1978         if (mPipelineStructs.fragmentOut.matches(*structure))
1979         {
1980             if (mCompiler.isEarlyFragmentTestsSpecified())
1981             {
1982                 mOut << "[[early_fragment_tests]]\n";
1983             }
1984             mOut << "fragment ";
1985         }
1986         else if (mPipelineStructs.vertexOut.matches(*structure))
1987         {
1988             ASSERT(structure != nullptr);
1989             mOut << "vertex __VERTEX_OUT(";
1990             isVertexMain = true;
1991         }
1992         else
1993         {
1994             UNIMPLEMENTED();
1995         }
1996     }
1997     emitType(returnType, EmitTypeConfig());
1998     if (isVertexMain)
1999         mOut << ") ";
2000 }
2001 
emitFunctionParameter(const TFunction & func,const TVariable & param)2002 void GenMetalTraverser::emitFunctionParameter(const TFunction &func, const TVariable &param)
2003 {
2004     const bool isMain = func.isMain();
2005 
2006     const TType &type           = param.getType();
2007     const TStructure *structure = type.getStruct();
2008 
2009     EmitVariableDeclarationConfig evdConfig;
2010     evdConfig.isParameter            = true;
2011     evdConfig.disableStructSpecifier = true;  // It's invalid to declare a struct in a function arg.
2012     evdConfig.isMainParameter        = isMain;
2013     evdConfig.emitPostQualifier      = isMain;
2014     evdConfig.isUBO                  = mSymbolEnv.isUBO(param);
2015     evdConfig.isPointer              = mSymbolEnv.isPointer(param);
2016     evdConfig.isReference            = mSymbolEnv.isReference(param);
2017     emitVariableDeclaration(VarDecl(param), evdConfig);
2018 
2019     if (isMain)
2020     {
2021         TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
2022         if (structure)
2023         {
2024             if (mPipelineStructs.fragmentIn.matches(*structure) ||
2025                 mPipelineStructs.vertexIn.matches(*structure))
2026             {
2027                 mOut << " [[stage_in]]";
2028             }
2029             else if (mPipelineStructs.angleUniforms.matches(*structure))
2030             {
2031                 mOut << " [[buffer(" << mDriverUniformsBindingIndex << ")]]";
2032             }
2033             else if (mPipelineStructs.uniformBuffers.matches(*structure))
2034             {
2035                 mOut << " [[buffer(" << mUBOArgumentBufferBindingIndex << ")]]";
2036                 reflection->hasUBOs = true;
2037             }
2038             else if (mPipelineStructs.userUniforms.matches(*structure))
2039             {
2040                 mOut << " [[buffer(" << mMainUniformBufferIndex << ")]]";
2041                 reflection->addUserUniformBufferBinding(param.name().data(),
2042                                                         mMainUniformBufferIndex);
2043                 mMainUniformBufferIndex += type.getArraySizeProduct();
2044             }
2045             else if (structure->name() == "metal::sampler")
2046             {
2047                 mOut << " [[sampler(" << (mMainSamplerIndex) << ")]]";
2048                 const std::string originalName =
2049                     reflection->getOriginalName(param.uniqueId().get());
2050                 reflection->addSamplerBinding(originalName, mMainSamplerIndex);
2051                 mMainSamplerIndex += type.getArraySizeProduct();
2052             }
2053         }
2054         else if (IsSampler(type.getBasicType()))
2055         {
2056             mOut << " [[texture(" << (mMainTextureIndex) << ")]]";
2057             const std::string originalName = reflection->getOriginalName(param.uniqueId().get());
2058             reflection->addTextureBinding(originalName, mMainTextureIndex);
2059             mMainTextureIndex += type.getArraySizeProduct();
2060         }
2061         else if (Name(param) == Pipeline{Pipeline::Type::InstanceId, nullptr}.getStructInstanceName(
2062                                     Pipeline::Variant::Modified))
2063         {
2064             mOut << " [[instance_id]]";
2065         }
2066         else if (Name(param) == kBaseInstanceName)
2067         {
2068             mOut << " [[base_instance]]";
2069         }
2070     }
2071 }
2072 
visitFunctionPrototype(TIntermFunctionPrototype * funcProtoNode)2073 void GenMetalTraverser::visitFunctionPrototype(TIntermFunctionPrototype *funcProtoNode)
2074 {
2075     const TFunction &func = *funcProtoNode->getFunction();
2076 
2077     emitIndentation();
2078     emitFunctionSignature(func);
2079 }
2080 
visitFunctionDefinition(Visit,TIntermFunctionDefinition * funcDefNode)2081 bool GenMetalTraverser::visitFunctionDefinition(Visit, TIntermFunctionDefinition *funcDefNode)
2082 {
2083     const TFunction &func = *funcDefNode->getFunction();
2084     TIntermBlock &body    = *funcDefNode->getBody();
2085     if (func.isMain())
2086     {
2087         const TType &returnType     = func.getReturnType();
2088         const TStructure *structure = returnType.getStruct();
2089         isTraversingVertexMain      = (mPipelineStructs.vertexOut.matches(*structure)) &&
2090                                  mCompiler.getShaderType() == GL_VERTEX_SHADER;
2091     }
2092     emitIndentation();
2093     emitFunctionSignature(func);
2094     mOut << "\n";
2095     body.traverse(this);
2096     if (isTraversingVertexMain)
2097     {
2098         isTraversingVertexMain = false;
2099     }
2100     return false;
2101 }
2102 
BuildFuncToName()2103 GenMetalTraverser::FuncToName GenMetalTraverser::BuildFuncToName()
2104 {
2105     FuncToName map;
2106 
2107     auto putAngle = [&](const char *nameStr) {
2108         const ImmutableString name(nameStr);
2109         ASSERT(map.find(name) == map.end());
2110         map[name] = Name(nameStr, SymbolType::AngleInternal);
2111     };
2112 
2113     putAngle("texelFetch");
2114     putAngle("texelFetchOffset");
2115     putAngle("texture");
2116     putAngle("texture1D");
2117     putAngle("texture1DLod");
2118     putAngle("texture1DProjLod");
2119     putAngle("texture2D");
2120     putAngle("texture2DGradEXT");
2121     putAngle("texture2DLod");
2122     putAngle("texture2DLodEXT");
2123     putAngle("texture2DProj");
2124     putAngle("texture2DProjGradEXT");
2125     putAngle("texture2DProjLod");
2126     putAngle("texture2DProjLodEXT");
2127     putAngle("texture2DRect");
2128     putAngle("texture2DRectProj");
2129     putAngle("texture3D");
2130     putAngle("texture3DLod");
2131     putAngle("texture3DProjLod");
2132     putAngle("textureCube");
2133     putAngle("textureCubeGradEXT");
2134     putAngle("textureCubeLod");
2135     putAngle("textureCubeLodEXT");
2136     putAngle("textureCubeProjLod");
2137     putAngle("textureGrad");
2138     putAngle("textureGradOffset");
2139     putAngle("textureLod");
2140     putAngle("textureLodOffset");
2141     putAngle("textureOffset");
2142     putAngle("textureProj");
2143     putAngle("textureProjGrad");
2144     putAngle("textureProjGradOffset");
2145     putAngle("textureProjLod");
2146     putAngle("textureProjLodOffset");
2147     putAngle("textureProjOffset");
2148     putAngle("textureSize");
2149     putAngle("imageLoad");
2150     putAngle("imageStore");
2151     putAngle("memoryBarrierImage");
2152 
2153     return map;
2154 }
2155 
visitAggregate(Visit,TIntermAggregate * aggregateNode)2156 bool GenMetalTraverser::visitAggregate(Visit, TIntermAggregate *aggregateNode)
2157 {
2158     const TIntermSequence &args = *aggregateNode->getSequence();
2159 
2160     auto emitArgList = [&](const char *open, const char *close) {
2161         mOut << open;
2162 
2163         bool emitComma = false;
2164         for (TIntermNode *arg : args)
2165         {
2166             if (emitComma)
2167             {
2168                 emitClosingPointerParen();
2169                 mOut << ", ";
2170             }
2171             emitComma = true;
2172             arg->traverse(this);
2173         }
2174 
2175         mOut << close;
2176     };
2177 
2178     const TType &retType = aggregateNode->getType();
2179 
2180     if (aggregateNode->isConstructor())
2181     {
2182         const bool isStandalone = getParentNode()->getAsBlock();
2183         if (isStandalone)
2184         {
2185             // Prevent constructor from being interpreted as a declaration by wrapping in parens.
2186             // This can happen if given something like:
2187             //      int(symbol); // <- This will be treated like `int symbol;`... don't want that.
2188             // So instead emit:
2189             //      (int(symbol));
2190             mOut << "(";
2191         }
2192 
2193         const EmitTypeConfig etConfig;
2194 
2195         if (retType.isArray())
2196         {
2197             emitType(retType, etConfig);
2198             emitArgList("{", "}");
2199         }
2200         else if (retType.getStruct())
2201         {
2202             emitType(retType, etConfig);
2203             emitArgList("{", "}");
2204         }
2205         else
2206         {
2207             emitType(retType, etConfig);
2208             emitArgList("(", ")");
2209         }
2210 
2211         if (isStandalone)
2212         {
2213             mOut << ")";
2214         }
2215 
2216         return false;
2217     }
2218     else
2219     {
2220         const TOperator op = aggregateNode->getOp();
2221         if (op == EOpAtan)
2222         {
2223             TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
2224             reflection->hasAtan                   = true;
2225         }
2226         switch (op)
2227         {
2228             case TOperator::EOpCallFunctionInAST:
2229             case TOperator::EOpCallInternalRawFunction:
2230             {
2231                 const TFunction &func = *aggregateNode->getFunction();
2232                 emitNameOf(func);
2233                 //'@' symbol in name specifices a macro substitution marker.
2234                 if (!func.name().contains("@"))
2235                 {
2236                     emitArgList("(", ")");
2237                 }
2238                 else
2239                 {
2240                     mTemporarilyDisableSemicolon =
2241                         true;  // Disable semicolon for macro substitution.
2242                 }
2243                 return false;
2244             }
2245 
2246             default:
2247             {
2248                 auto getArgType = [&](size_t index) -> const TType * {
2249                     if (index < args.size())
2250                     {
2251                         TIntermTyped *arg = args[index]->getAsTyped();
2252                         ASSERT(arg);
2253                         return &arg->getType();
2254                     }
2255                     return nullptr;
2256                 };
2257 
2258                 const TType *argType0 = getArgType(0);
2259                 const TType *argType1 = getArgType(1);
2260                 const TType *argType2 = getArgType(2);
2261 
2262                 const char *opName = GetOperatorString(op, retType, argType0, argType1, argType2);
2263 
2264                 if (IsSymbolicOperator(op, retType, argType0, argType1))
2265                 {
2266                     switch (args.size())
2267                     {
2268                         case 1:
2269                         {
2270                             TIntermNode &operandNode = *aggregateNode->getChildNode(0);
2271                             if (IsPostfix(op))
2272                             {
2273                                 mOut << opName;
2274                                 groupedTraverse(operandNode);
2275                             }
2276                             else
2277                             {
2278                                 groupedTraverse(operandNode);
2279                                 mOut << opName;
2280                             }
2281                             return false;
2282                         }
2283 
2284                         case 2:
2285                         {
2286                             TIntermNode &leftNode  = *aggregateNode->getChildNode(0);
2287                             TIntermNode &rightNode = *aggregateNode->getChildNode(1);
2288                             groupedTraverse(leftNode);
2289                             mOut << " " << opName << " ";
2290                             groupedTraverse(rightNode);
2291                             return false;
2292                         }
2293 
2294                         default:
2295                             UNREACHABLE();
2296                             return false;
2297                     }
2298                 }
2299                 else if (opName == nullptr)
2300                 {
2301                     const TFunction &func = *aggregateNode->getFunction();
2302                     auto it               = mFuncToName.find(func.name());
2303                     ASSERT(it != mFuncToName.end());
2304                     EmitName(mOut, it->second);
2305                     emitArgList("(", ")");
2306                     return false;
2307                 }
2308                 else
2309                 {
2310                     mOut << opName;
2311                     emitArgList("(", ")");
2312                     return false;
2313                 }
2314             }
2315         }
2316     }
2317 }
2318 
emitOpenBrace()2319 void GenMetalTraverser::emitOpenBrace()
2320 {
2321     ASSERT(mIndentLevel >= 0);
2322 
2323     emitIndentation();
2324     mOut << "{\n";
2325     ++mIndentLevel;
2326 }
2327 
emitCloseBrace()2328 void GenMetalTraverser::emitCloseBrace()
2329 {
2330     ASSERT(mIndentLevel >= 1);
2331 
2332     --mIndentLevel;
2333     emitIndentation();
2334     mOut << "}";
2335 }
2336 
RequiresSemicolonTerminator(TIntermNode & node)2337 static bool RequiresSemicolonTerminator(TIntermNode &node)
2338 {
2339     if (node.getAsBlock())
2340     {
2341         return false;
2342     }
2343     if (node.getAsLoopNode())
2344     {
2345         return false;
2346     }
2347     if (node.getAsSwitchNode())
2348     {
2349         return false;
2350     }
2351     if (node.getAsIfElseNode())
2352     {
2353         return false;
2354     }
2355     if (node.getAsFunctionDefinition())
2356     {
2357         return false;
2358     }
2359     if (node.getAsCaseNode())
2360     {
2361         return false;
2362     }
2363 
2364     return true;
2365 }
2366 
NewlinePad(TIntermNode & node)2367 static bool NewlinePad(TIntermNode &node)
2368 {
2369     if (node.getAsFunctionDefinition())
2370     {
2371         return true;
2372     }
2373     if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
2374     {
2375         ASSERT(declNode->getChildCount() == 1);
2376         TIntermNode &childNode = *declNode->getChildNode(0);
2377         if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
2378         {
2379             const TVariable &var = symbolNode->variable();
2380             return var.getType().isStructSpecifier();
2381         }
2382         return false;
2383     }
2384     return false;
2385 }
2386 
visitBlock(Visit,TIntermBlock * blockNode)2387 bool GenMetalTraverser::visitBlock(Visit, TIntermBlock *blockNode)
2388 {
2389     ASSERT(mIndentLevel >= -1);
2390     const bool isGlobalScope  = mIndentLevel == -1;
2391     const bool parentIsSwitch = mParentIsSwitch;
2392     mParentIsSwitch           = false;
2393 
2394     if (isGlobalScope)
2395     {
2396         ++mIndentLevel;
2397     }
2398     else
2399     {
2400         emitOpenBrace();
2401         if (parentIsSwitch)
2402         {
2403             ++mIndentLevel;
2404         }
2405     }
2406 
2407     TIntermNode *prevStmtNode = nullptr;
2408 
2409     const size_t stmtCount = blockNode->getChildCount();
2410     for (size_t i = 0; i < stmtCount; ++i)
2411     {
2412         TIntermNode &stmtNode = *blockNode->getChildNode(i);
2413 
2414         if (isGlobalScope && prevStmtNode && (NewlinePad(*prevStmtNode) || NewlinePad(stmtNode)))
2415         {
2416             mOut << "\n";
2417         }
2418         const bool isCase = stmtNode.getAsCaseNode();
2419         mIndentLevel -= isCase;
2420         emitIndentation();
2421         mIndentLevel += isCase;
2422         stmtNode.traverse(this);
2423         if (RequiresSemicolonTerminator(stmtNode) && !mTemporarilyDisableSemicolon)
2424         {
2425             mOut << ";";
2426         }
2427         mTemporarilyDisableSemicolon = false;
2428         mOut << "\n";
2429 
2430         prevStmtNode = &stmtNode;
2431     }
2432 
2433     if (isGlobalScope)
2434     {
2435         ASSERT(mIndentLevel == 0);
2436         --mIndentLevel;
2437     }
2438     else
2439     {
2440         if (parentIsSwitch)
2441         {
2442             ASSERT(mIndentLevel >= 1);
2443             --mIndentLevel;
2444         }
2445         emitCloseBrace();
2446         mParentIsSwitch = parentIsSwitch;
2447     }
2448 
2449     return false;
2450 }
2451 
visitGlobalQualifierDeclaration(Visit,TIntermGlobalQualifierDeclaration *)2452 bool GenMetalTraverser::visitGlobalQualifierDeclaration(Visit, TIntermGlobalQualifierDeclaration *)
2453 {
2454     return false;
2455 }
2456 
visitDeclaration(Visit,TIntermDeclaration * declNode)2457 bool GenMetalTraverser::visitDeclaration(Visit, TIntermDeclaration *declNode)
2458 {
2459     ASSERT(declNode->getChildCount() == 1);
2460     TIntermNode &node = *declNode->getChildNode(0);
2461 
2462     EmitVariableDeclarationConfig evdConfig;
2463 
2464     if (TIntermSymbol *symbolNode = node.getAsSymbolNode())
2465     {
2466         const TVariable &var = symbolNode->variable();
2467         emitVariableDeclaration(VarDecl(var), evdConfig);
2468     }
2469     else if (TIntermBinary *initNode = node.getAsBinaryNode())
2470     {
2471         ASSERT(initNode->getOp() == TOperator::EOpInitialize);
2472         TIntermSymbol *leftSymbolNode = initNode->getLeft()->getAsSymbolNode();
2473         TIntermTyped *valueNode       = initNode->getRight()->getAsTyped();
2474         ASSERT(leftSymbolNode && valueNode);
2475 
2476         if (getRootNode() == getParentBlock())
2477         {
2478             // DeferGlobalInitializers should have turned non-const global initializers into
2479             // deferred initializers. Note that variables marked as EvqGlobal can be treated as
2480             // EvqConst in some ANGLE code but not actually have their qualifier actually changed to
2481             // EvqConst. Thus just assume all EvqGlobal are actually EvqConst for all code run after
2482             // DeferGlobalInitializers.
2483             mOut << "constant ";
2484         }
2485 
2486         const TVariable &var = leftSymbolNode->variable();
2487         const Name varName(var);
2488 
2489         if (ExpressionContainsName(varName, *valueNode))
2490         {
2491             mRenamedSymbols[&var] = mIdGen.createNewName(varName);
2492         }
2493 
2494         emitVariableDeclaration(VarDecl(var), evdConfig);
2495         mOut << " = ";
2496         groupedTraverse(*valueNode);
2497     }
2498     else
2499     {
2500         UNREACHABLE();
2501     }
2502 
2503     return false;
2504 }
2505 
visitLoop(Visit,TIntermLoop * loopNode)2506 bool GenMetalTraverser::visitLoop(Visit, TIntermLoop *loopNode)
2507 {
2508     const TLoopType loopType = loopNode->getType();
2509 
2510     switch (loopType)
2511     {
2512         case TLoopType::ELoopFor:
2513             return visitForLoop(loopNode);
2514         case TLoopType::ELoopWhile:
2515             return visitWhileLoop(loopNode);
2516         case TLoopType::ELoopDoWhile:
2517             return visitDoWhileLoop(loopNode);
2518     }
2519 }
2520 
visitForLoop(TIntermLoop * loopNode)2521 bool GenMetalTraverser::visitForLoop(TIntermLoop *loopNode)
2522 {
2523     ASSERT(loopNode->getType() == TLoopType::ELoopFor);
2524 
2525     TIntermNode *initNode  = loopNode->getInit();
2526     TIntermTyped *condNode = loopNode->getCondition();
2527     TIntermTyped *exprNode = loopNode->getExpression();
2528     TIntermBlock *bodyNode = loopNode->getBody();
2529     ASSERT(bodyNode);
2530 
2531     mOut << "for (";
2532 
2533     if (initNode)
2534     {
2535         initNode->traverse(this);
2536     }
2537     else
2538     {
2539         mOut << " ";
2540     }
2541 
2542     mOut << "; ";
2543 
2544     if (condNode)
2545     {
2546         condNode->traverse(this);
2547     }
2548 
2549     mOut << "; ";
2550 
2551     if (exprNode)
2552     {
2553         exprNode->traverse(this);
2554     }
2555 
2556     mOut << ")\n";
2557 
2558     bodyNode->traverse(this);
2559 
2560     return false;
2561 }
2562 
visitWhileLoop(TIntermLoop * loopNode)2563 bool GenMetalTraverser::visitWhileLoop(TIntermLoop *loopNode)
2564 {
2565     ASSERT(loopNode->getType() == TLoopType::ELoopWhile);
2566 
2567     TIntermNode *initNode  = loopNode->getInit();
2568     TIntermTyped *condNode = loopNode->getCondition();
2569     TIntermTyped *exprNode = loopNode->getExpression();
2570     TIntermBlock *bodyNode = loopNode->getBody();
2571     ASSERT(condNode && bodyNode);
2572     ASSERT(!initNode && !exprNode);
2573 
2574     emitIndentation();
2575     mOut << "while (";
2576     condNode->traverse(this);
2577     mOut << ")\n";
2578     bodyNode->traverse(this);
2579 
2580     return false;
2581 }
2582 
visitDoWhileLoop(TIntermLoop * loopNode)2583 bool GenMetalTraverser::visitDoWhileLoop(TIntermLoop *loopNode)
2584 {
2585     ASSERT(loopNode->getType() == TLoopType::ELoopDoWhile);
2586 
2587     TIntermNode *initNode  = loopNode->getInit();
2588     TIntermTyped *condNode = loopNode->getCondition();
2589     TIntermTyped *exprNode = loopNode->getExpression();
2590     TIntermBlock *bodyNode = loopNode->getBody();
2591     ASSERT(condNode && bodyNode);
2592     ASSERT(!initNode && !exprNode);
2593 
2594     emitIndentation();
2595     mOut << "do\n";
2596     bodyNode->traverse(this);
2597     mOut << "\n";
2598     emitIndentation();
2599     mOut << "while (";
2600     condNode->traverse(this);
2601     mOut << ");";
2602 
2603     return false;
2604 }
2605 
visitBranch(Visit,TIntermBranch * branchNode)2606 bool GenMetalTraverser::visitBranch(Visit, TIntermBranch *branchNode)
2607 {
2608     const TOperator flowOp = branchNode->getFlowOp();
2609     TIntermTyped *exprNode = branchNode->getExpression();
2610 
2611     emitIndentation();
2612 
2613     switch (flowOp)
2614     {
2615         case TOperator::EOpKill:
2616         {
2617             ASSERT(exprNode == nullptr);
2618             mOut << "metal::discard_fragment()";
2619         }
2620         break;
2621 
2622         case TOperator::EOpReturn:
2623         {
2624             if (isTraversingVertexMain)
2625             {
2626                 mOut << "#if TRANSFORM_FEEDBACK_ENABLED\n";
2627                 emitIndentation();
2628                 mOut << "return;\n";
2629                 emitIndentation();
2630                 mOut << "#else\n";
2631                 emitIndentation();
2632             }
2633             mOut << "return";
2634             if (exprNode)
2635             {
2636                 mOut << " ";
2637                 exprNode->traverse(this);
2638                 mOut << ";";
2639             }
2640             if (isTraversingVertexMain)
2641             {
2642                 mOut << "\n";
2643                 emitIndentation();
2644                 mOut << "#endif\n";
2645                 mTemporarilyDisableSemicolon = true;
2646             }
2647         }
2648         break;
2649 
2650         case TOperator::EOpBreak:
2651         {
2652             ASSERT(exprNode == nullptr);
2653             mOut << "break";
2654         }
2655         break;
2656 
2657         case TOperator::EOpContinue:
2658         {
2659             ASSERT(exprNode == nullptr);
2660             mOut << "continue";
2661         }
2662         break;
2663 
2664         default:
2665         {
2666             UNREACHABLE();
2667         }
2668     }
2669 
2670     return false;
2671 }
2672 
// Number of sh::EmitMetal() calls made so far in this process. EmitMetal() increments it on
// every call, and the GMD_FIXED_EMIT debug override uses it to pick the override file name.
// NOTE(review): this is a plain (non-atomic) counter. If shader translation can run on
// several threads at once, the unsynchronized increment is a data race; consider
// std::atomic<size_t>.
static size_t emitMetalCallCount{0};
2674 
EmitMetal(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const PipelineStructs & pipelineStructs,SymbolEnv & symbolEnv,const ProgramPreludeConfig & ppc,const ShCompileOptions & compileOptions)2675 bool sh::EmitMetal(TCompiler &compiler,
2676                    TIntermBlock &root,
2677                    IdGen &idGen,
2678                    const PipelineStructs &pipelineStructs,
2679                    SymbolEnv &symbolEnv,
2680                    const ProgramPreludeConfig &ppc,
2681                    const ShCompileOptions &compileOptions)
2682 {
2683     TInfoSinkBase &out = compiler.getInfoSink().obj;
2684 
2685     {
2686         ++emitMetalCallCount;
2687         std::string filenameProto = angle::GetEnvironmentVar("GMD_FIXED_EMIT");
2688         if (!filenameProto.empty())
2689         {
2690             if (filenameProto != "/dev/null")
2691             {
2692                 auto tryOpen = [&](char const *ext) {
2693                     auto filename = filenameProto;
2694                     filename += std::to_string(emitMetalCallCount);
2695                     filename += ".";
2696                     filename += ext;
2697                     return fopen(filename.c_str(), "rb");
2698                 };
2699                 FILE *file = tryOpen("metal");
2700                 if (!file)
2701                 {
2702                     file = tryOpen("cpp");
2703                 }
2704                 ASSERT(file);
2705 
2706                 fseek(file, 0, SEEK_END);
2707                 size_t fileSize = ftell(file);
2708                 fseek(file, 0, SEEK_SET);
2709 
2710                 std::vector<char> buff;
2711                 buff.resize(fileSize + 1);
2712                 fread(buff.data(), fileSize, 1, file);
2713                 buff.back() = '\0';
2714 
2715                 fclose(file);
2716 
2717                 out << buff.data();
2718             }
2719 
2720             return true;
2721         }
2722     }
2723 
2724     out << "\n\n";
2725 
2726     if (!EmitProgramPrelude(root, out, ppc))
2727     {
2728         return false;
2729     }
2730 
2731     {
2732 #if defined(ANGLE_ENABLE_ASSERTS)
2733         DebugSink outWrapper(out, angle::GetBoolEnvironmentVar("GMD_STDOUT"));
2734         outWrapper.watch(angle::GetEnvironmentVar("GMD_WATCH_STRING"));
2735 #else
2736         TInfoSinkBase &outWrapper = out;
2737 #endif
2738         GenMetalTraverser gen(compiler, outWrapper, idGen, pipelineStructs, symbolEnv,
2739                               compileOptions);
2740         root.traverse(&gen);
2741     }
2742 
2743     out << "\n";
2744 
2745     return true;
2746 }
2747