1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include <cctype>
8 #include <map>
9
10 #include "common/system_utils.h"
11 #include "compiler/translator/BaseTypes.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/SymbolTable.h"
14 #include "compiler/translator/TranslatorMetalDirect.h"
15 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
16 #include "compiler/translator/TranslatorMetalDirect/DebugSink.h"
17 #include "compiler/translator/TranslatorMetalDirect/EmitMetal.h"
18 #include "compiler/translator/TranslatorMetalDirect/Layout.h"
19 #include "compiler/translator/TranslatorMetalDirect/Name.h"
20 #include "compiler/translator/TranslatorMetalDirect/ProgramPrelude.h"
21 #include "compiler/translator/TranslatorMetalDirect/RewritePipelines.h"
22 #include "compiler/translator/tree_util/IntermTraverse.h"
23
24 using namespace sh;
25
26 ////////////////////////////////////////////////////////////////////////////////
27
// Output sink type used by the emitter in this file. Builds with
// ANGLE_ENABLE_ASSERTS write through DebugSink; all other builds write
// directly to TInfoSinkBase.
#if defined(ANGLE_ENABLE_ASSERTS)
using Sink = DebugSink;
#else
using Sink = TInfoSinkBase;
#endif
33
34 ////////////////////////////////////////////////////////////////////////////////
35
36 namespace
37 {
38
39 struct VarDecl
40 {
VarDecl__anoncc8a75750111::VarDecl41 explicit VarDecl(const TVariable &var) : mVariable(&var), mIsField(false) {}
VarDecl__anoncc8a75750111::VarDecl42 explicit VarDecl(const TField &field) : mField(&field), mIsField(true) {}
43
variable__anoncc8a75750111::VarDecl44 ANGLE_INLINE const TVariable &variable() const
45 {
46 ASSERT(isVariable());
47 return *mVariable;
48 }
49
field__anoncc8a75750111::VarDecl50 ANGLE_INLINE const TField &field() const
51 {
52 ASSERT(isField());
53 return *mField;
54 }
55
isVariable__anoncc8a75750111::VarDecl56 ANGLE_INLINE bool isVariable() const { return !mIsField; }
57
isField__anoncc8a75750111::VarDecl58 ANGLE_INLINE bool isField() const { return mIsField; }
59
type__anoncc8a75750111::VarDecl60 const TType &type() const { return isField() ? *field().type() : variable().getType(); }
61
symbolType__anoncc8a75750111::VarDecl62 SymbolType symbolType() const
63 {
64 return isField() ? field().symbolType() : variable().symbolType();
65 }
66
67 private:
68 union
69 {
70 const TVariable *mVariable;
71 const TField *mField;
72 };
73 bool mIsField;
74 };
75
76 class GenMetalTraverser : public TIntermTraverser
77 {
78 public:
79 ~GenMetalTraverser() override;
80
81 GenMetalTraverser(const TCompiler &compiler,
82 Sink &out,
83 IdGen &idGen,
84 const PipelineStructs &pipelineStructs,
85 SymbolEnv &symbolEnv,
86 TSymbolTable *symbolTable);
87
88 void visitSymbol(TIntermSymbol *) override;
89 void visitConstantUnion(TIntermConstantUnion *) override;
90 bool visitSwizzle(Visit, TIntermSwizzle *) override;
91 bool visitBinary(Visit, TIntermBinary *) override;
92 bool visitUnary(Visit, TIntermUnary *) override;
93 bool visitTernary(Visit, TIntermTernary *) override;
94 bool visitIfElse(Visit, TIntermIfElse *) override;
95 bool visitSwitch(Visit, TIntermSwitch *) override;
96 bool visitCase(Visit, TIntermCase *) override;
97 void visitFunctionPrototype(TIntermFunctionPrototype *) override;
98 bool visitFunctionDefinition(Visit, TIntermFunctionDefinition *) override;
99 bool visitAggregate(Visit, TIntermAggregate *) override;
100 bool visitBlock(Visit, TIntermBlock *) override;
101 bool visitGlobalQualifierDeclaration(Visit, TIntermGlobalQualifierDeclaration *) override;
102 bool visitDeclaration(Visit, TIntermDeclaration *) override;
103 bool visitLoop(Visit, TIntermLoop *) override;
104 bool visitForLoop(TIntermLoop *);
105 bool visitWhileLoop(TIntermLoop *);
106 bool visitDoWhileLoop(TIntermLoop *);
107 bool visitBranch(Visit, TIntermBranch *) override;
108
109 private:
110 using FuncToName = std::map<ImmutableString, Name>;
111 static FuncToName BuildFuncToName();
112
113 struct EmitVariableDeclarationConfig
114 {
115 bool isParameter = false;
116 bool isMainParameter = false;
117 bool emitPostQualifier = false;
118 bool isPacked = false;
119 bool disableStructSpecifier = false;
120 bool isUBO = false;
121 const AddressSpace *isPointer = nullptr;
122 const AddressSpace *isReference = nullptr;
123 };
124
125 struct EmitTypeConfig
126 {
127 const EmitVariableDeclarationConfig *evdConfig = nullptr;
128 };
129
130 void emitIndentation();
131 void emitOpeningPointerParen();
132 void emitClosingPointerParen();
133 void emitFunctionSignature(const TFunction &func);
134 void emitFunctionReturn(const TFunction &func);
135 void emitFunctionParameter(const TFunction &func, const TVariable ¶m);
136
137 void emitNameOf(const TField &object);
138 void emitNameOf(const TSymbol &object);
139 void emitNameOf(const VarDecl &object);
140
141 void emitBareTypeName(const TType &type, const EmitTypeConfig &etConfig);
142 void emitType(const TType &type, const EmitTypeConfig &etConfig);
143 void emitPostQualifier(const EmitVariableDeclarationConfig &evdConfig,
144 const VarDecl &decl,
145 const TQualifier qualifier);
146
147 struct FieldAnnotationIndices
148 {
149 size_t attribute = 0;
150 size_t color = 0;
151 };
152
153 void emitFieldDeclaration(const TField &field,
154 const TStructure &parent,
155 FieldAnnotationIndices &annotationIndices);
156 void emitAttributeDeclaration(const TField &field, FieldAnnotationIndices &annotationIndices);
157 void emitUniformBufferDeclaration(const TField &field,
158 FieldAnnotationIndices &annotationIndices);
159 void emitStructDeclaration(const TType &type);
160 void emitOrdinaryVariableDeclaration(const VarDecl &decl,
161 const EmitVariableDeclarationConfig &evdConfig);
162 void emitVariableDeclaration(const VarDecl &decl,
163 const EmitVariableDeclarationConfig &evdConfig);
164
165 void emitOpenBrace();
166 void emitCloseBrace();
167
168 void groupedTraverse(TIntermNode &node);
169
170 const TField &getDirectField(const TFieldListCollection &fieldsNode,
171 const TConstantUnion &index);
172 const TField &getDirectField(const TIntermTyped &fieldsNode, TIntermTyped &indexNode);
173
174 const TConstantUnion *emitConstantUnionArray(const TConstantUnion *const constUnion,
175 const size_t size);
176
177 const TConstantUnion *emitConstantUnion(const TType &type, const TConstantUnion *constUnion);
178
179 void emitSingleConstant(const TConstantUnion *const constUnion);
180
181 private:
182 Sink &mOut;
183 const TCompiler &mCompiler;
184 const PipelineStructs &mPipelineStructs;
185 SymbolEnv &mSymbolEnv;
186 IdGen &mIdGen;
187 int mIndentLevel = -1;
188 int mLastIndentationPos = -1;
189 int mOpenPointerParenCount = 0;
190 bool mParentIsSwitch = false;
191 bool isTraversingVertexMain = false;
192 bool mTemporarilyDisableSemicolon = false;
193 std::unordered_map<const TSymbol *, Name> mRenamedSymbols;
194 const FuncToName mFuncToName = BuildFuncToName();
195 size_t mMainTextureIndex = 0;
196 size_t mMainSamplerIndex = 0;
197 size_t mMainUniformBufferIndex = 0;
198 size_t mDriverUniformsBindingIndex = 0;
199 size_t mUBOArgumentBufferBindingIndex = 0;
200 };
201 } // anonymous namespace
202
~GenMetalTraverser()203 GenMetalTraverser::~GenMetalTraverser()
204 {
205 ASSERT(mIndentLevel == -1);
206 ASSERT(!mParentIsSwitch);
207 ASSERT(mOpenPointerParenCount == 0);
208 }
209
GenMetalTraverser(const TCompiler & compiler,Sink & out,IdGen & idGen,const PipelineStructs & pipelineStructs,SymbolEnv & symbolEnv,TSymbolTable * symbolTable)210 GenMetalTraverser::GenMetalTraverser(const TCompiler &compiler,
211 Sink &out,
212 IdGen &idGen,
213 const PipelineStructs &pipelineStructs,
214 SymbolEnv &symbolEnv,
215 TSymbolTable *symbolTable)
216 : TIntermTraverser(true, false, false),
217 mOut(out),
218 mCompiler(compiler),
219 mPipelineStructs(pipelineStructs),
220 mSymbolEnv(symbolEnv),
221 mIdGen(idGen),
222 mMainUniformBufferIndex(symbolTable->getDefaultUniformsBindingIndex()),
223 mDriverUniformsBindingIndex(symbolTable->getDriverUniformsBindingIndex()),
224 mUBOArgumentBufferBindingIndex(symbolTable->getUBOArgumentBufferBindingIndex())
225 {}
226
emitIndentation()227 void GenMetalTraverser::emitIndentation()
228 {
229 ASSERT(mIndentLevel >= 0);
230
231 if (mLastIndentationPos == mOut.size())
232 {
233 return; // Line is already indented.
234 }
235
236 for (int i = 0; i < mIndentLevel; ++i)
237 {
238 mOut << " ";
239 }
240
241 mLastIndentationPos = mOut.size();
242 }
243
emitOpeningPointerParen()244 void GenMetalTraverser::emitOpeningPointerParen()
245 {
246 mOut << "(*";
247 mOpenPointerParenCount++;
248 }
249
emitClosingPointerParen()250 void GenMetalTraverser::emitClosingPointerParen()
251 {
252 if (mOpenPointerParenCount > 0)
253 {
254 mOut << ")";
255 mOpenPointerParenCount--;
256 }
257 }
258
GetOperatorString(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1,const TType * argType2)259 static const char *GetOperatorString(TOperator op,
260 const TType &resultType,
261 const TType *argType0,
262 const TType *argType1,
263 const TType *argType2)
264 {
265 switch (op)
266 {
267 case TOperator::EOpComma:
268 return ",";
269 case TOperator::EOpAssign:
270 return "=";
271 case TOperator::EOpInitialize:
272 return "=";
273 case TOperator::EOpAddAssign:
274 return "+=";
275 case TOperator::EOpSubAssign:
276 return "-=";
277 case TOperator::EOpMulAssign:
278 return "*=";
279 case TOperator::EOpDivAssign:
280 return "/=";
281 case TOperator::EOpIModAssign:
282 return "%=";
283 case TOperator::EOpBitShiftLeftAssign:
284 return "<<="; // TODO: Check logical vs arithmetic shifting.
285 case TOperator::EOpBitShiftRightAssign:
286 return ">>="; // TODO: Check logical vs arithmetic shifting.
287 case TOperator::EOpBitwiseAndAssign:
288 return "&=";
289 case TOperator::EOpBitwiseXorAssign:
290 return "^=";
291 case TOperator::EOpBitwiseOrAssign:
292 return "|=";
293 case TOperator::EOpAdd:
294 return "+";
295 case TOperator::EOpSub:
296 return "-";
297 case TOperator::EOpMul:
298 return "*";
299 case TOperator::EOpDiv:
300 return "/";
301 case TOperator::EOpIMod:
302 return "%";
303 case TOperator::EOpBitShiftLeft:
304 return "<<"; // TODO: Check logical vs arithmetic shifting.
305 case TOperator::EOpBitShiftRight:
306 return ">>"; // TODO: Check logical vs arithmetic shifting.
307 case TOperator::EOpBitwiseAnd:
308 return "&";
309 case TOperator::EOpBitwiseXor:
310 return "^";
311 case TOperator::EOpBitwiseOr:
312 return "|";
313 case TOperator::EOpLessThan:
314 return "<";
315 case TOperator::EOpGreaterThan:
316 return ">";
317 case TOperator::EOpLessThanEqual:
318 return "<=";
319 case TOperator::EOpGreaterThanEqual:
320 return ">=";
321 case TOperator::EOpLessThanComponentWise:
322 return "<";
323 case TOperator::EOpLessThanEqualComponentWise:
324 return "<=";
325 case TOperator::EOpGreaterThanEqualComponentWise:
326 return ">=";
327 case TOperator::EOpGreaterThanComponentWise:
328 return ">";
329 case TOperator::EOpLogicalOr:
330 return "||";
331 case TOperator::EOpLogicalXor:
332 return "!=/*xor*/"; // XXX: This might need to be handled differently for some obtuse
333 // use case.
334 case TOperator::EOpLogicalAnd:
335 return "&&";
336 case TOperator::EOpNegative:
337 return "-";
338 case TOperator::EOpPositive:
339 if (argType0->isMatrix())
340 {
341 return "";
342 }
343 return "+";
344 case TOperator::EOpLogicalNot:
345 return "!";
346 case TOperator::EOpNotComponentWise:
347 return "!";
348 case TOperator::EOpBitwiseNot:
349 return "~";
350 case TOperator::EOpPostIncrement:
351 return "++";
352 case TOperator::EOpPostDecrement:
353 return "--";
354 case TOperator::EOpPreIncrement:
355 return "++";
356 case TOperator::EOpPreDecrement:
357 return "--";
358 case TOperator::EOpVectorTimesScalarAssign:
359 return "*=";
360 case TOperator::EOpVectorTimesMatrixAssign:
361 return "*=";
362 case TOperator::EOpMatrixTimesScalarAssign:
363 return "*=";
364 case TOperator::EOpMatrixTimesMatrixAssign:
365 return "*=";
366 case TOperator::EOpVectorTimesScalar:
367 return "*";
368 case TOperator::EOpVectorTimesMatrix:
369 return "*";
370 case TOperator::EOpMatrixTimesVector:
371 return "*";
372 case TOperator::EOpMatrixTimesScalar:
373 return "*";
374 case TOperator::EOpMatrixTimesMatrix:
375 return "*";
376 case TOperator::EOpEqualComponentWise:
377 return "==";
378 case TOperator::EOpNotEqualComponentWise:
379 return "!=";
380
381 case TOperator::EOpEqual:
382 if ((argType0->getStruct() && argType1->getStruct()) &&
383 (argType0->isArray() && argType1->isArray()))
384 {
385 return "ANGLE_equalStructArray";
386 }
387
388 if ((argType0->isVector() && argType1->isVector()) ||
389 (argType0->getStruct() && argType1->getStruct()) ||
390 (argType0->isArray() && argType1->isArray()) ||
391 (argType0->isMatrix() && argType1->isMatrix()))
392
393 {
394 return "ANGLE_equal";
395 }
396
397 return "==";
398
399 case TOperator::EOpNotEqual:
400 if ((argType0->getStruct() && argType1->getStruct()) &&
401 (argType0->isArray() && argType1->isArray()))
402 {
403 return "ANGLE_notEqualStructArray";
404 }
405
406 if ((argType0->isVector() && argType1->isVector()) ||
407 (argType0->isArray() && argType1->isArray()) ||
408 (argType0->isMatrix() && argType1->isMatrix()))
409 {
410 return "ANGLE_notEqual";
411 }
412 else if (argType0->getStruct() && argType1->getStruct())
413 {
414 return "ANGLE_notEqualStruct";
415 }
416 return "!=";
417
418 case TOperator::EOpKill:
419 UNIMPLEMENTED();
420 return "kill";
421 case TOperator::EOpReturn:
422 return "return";
423 case TOperator::EOpBreak:
424 return "break";
425 case TOperator::EOpContinue:
426 return "continue";
427
428 case TOperator::EOpRadians:
429 return "ANGLE_radians";
430 case TOperator::EOpDegrees:
431 return "ANGLE_degrees";
432 case TOperator::EOpAtan:
433 return "ANGLE_atan";
434 case TOperator::EOpMod:
435 return "ANGLE_mod"; // differs from metal::mod
436 case TOperator::EOpRefract:
437 return "ANGLE_refract";
438 case TOperator::EOpDistance:
439 return "ANGLE_distance";
440 case TOperator::EOpLength:
441 return "ANGLE_length";
442 case TOperator::EOpDot:
443 return "ANGLE_dot";
444 case TOperator::EOpNormalize:
445 return "ANGLE_normalize";
446 case TOperator::EOpFaceforward:
447 return "ANGLE_faceforward";
448 case TOperator::EOpReflect:
449 return "ANGLE_reflect";
450 case TOperator::EOpMatrixCompMult:
451 return "ANGLE_componentWiseMultiply";
452 case TOperator::EOpOuterProduct:
453 return "ANGLE_outerProduct";
454 case TOperator::EOpSign:
455 return "ANGLE_sign";
456
457 case TOperator::EOpAbs:
458 return "metal::abs";
459 case TOperator::EOpAll:
460 return "metal::all";
461 case TOperator::EOpAny:
462 return "metal::any";
463 case TOperator::EOpSin:
464 return "metal::sin";
465 case TOperator::EOpCos:
466 return "metal::cos";
467 case TOperator::EOpTan:
468 return "metal::tan";
469 case TOperator::EOpAsin:
470 return "metal::asin";
471 case TOperator::EOpAcos:
472 return "metal::acos";
473 case TOperator::EOpSinh:
474 return "metal::sinh";
475 case TOperator::EOpCosh:
476 return "metal::cosh";
477 case TOperator::EOpTanh:
478 return "metal::tanh";
479 case TOperator::EOpAsinh:
480 return "metal::asinh";
481 case TOperator::EOpAcosh:
482 return "metal::acosh";
483 case TOperator::EOpAtanh:
484 return "metal::atanh";
485 case TOperator::EOpFma:
486 return "metal::fma";
487 case TOperator::EOpPow:
488 return "metal::pow";
489 case TOperator::EOpExp:
490 return "metal::exp";
491 case TOperator::EOpExp2:
492 return "metal::exp2";
493 case TOperator::EOpLog:
494 return "metal::log";
495 case TOperator::EOpLog2:
496 return "metal::log2";
497 case TOperator::EOpSqrt:
498 return "metal::sqrt";
499 case TOperator::EOpFloor:
500 return "metal::floor";
501 case TOperator::EOpTrunc:
502 return "metal::trunc";
503 case TOperator::EOpCeil:
504 return "metal::ceil";
505 case TOperator::EOpFract:
506 return "metal::fract";
507 case TOperator::EOpMin:
508 return "metal::min";
509 case TOperator::EOpMax:
510 return "metal::max";
511 case TOperator::EOpRound:
512 return "metal::round";
513 case TOperator::EOpRoundEven:
514 return "metal::rint";
515 case TOperator::EOpClamp:
516 return "metal::clamp"; // TODO fast vs precise namespace
517 case TOperator::EOpMix:
518 if (argType2 && argType2->getBasicType() == EbtBool)
519 return "ANGLE_mix_bool";
520 return "metal::mix";
521 case TOperator::EOpStep:
522 return "metal::step";
523 case TOperator::EOpSmoothstep:
524 return "metal::smoothstep";
525 case TOperator::EOpModf:
526 return "metal::modf";
527 case TOperator::EOpIsnan:
528 return "metal::isnan";
529 case TOperator::EOpIsinf:
530 return "metal::isinf";
531 case TOperator::EOpLdexp:
532 return "metal::ldexp";
533 case TOperator::EOpFrexp:
534 return "metal::frexp";
535 case TOperator::EOpInversesqrt:
536 return "metal::rsqrt";
537 case TOperator::EOpCross:
538 return "metal::cross";
539 case TOperator::EOpDFdx:
540 return "metal::dfdx";
541 case TOperator::EOpDFdy:
542 return "metal::dfdy";
543 case TOperator::EOpFwidth:
544 return "metal::fwidth";
545 case TOperator::EOpTranspose:
546 return "metal::transpose";
547 case TOperator::EOpDeterminant:
548 return "metal::determinant";
549
550 case TOperator::EOpInverse:
551 return "ANGLE_inverse";
552
553 case TOperator::EOpFloatBitsToInt:
554 case TOperator::EOpFloatBitsToUint:
555 case TOperator::EOpIntBitsToFloat:
556 case TOperator::EOpUintBitsToFloat:
557 {
558 #define RETURN_AS_TYPE(post) \
559 do \
560 switch (resultType.getBasicType()) \
561 { \
562 case TBasicType::EbtInt: \
563 return "as_type<int" post ">"; \
564 case TBasicType::EbtUInt: \
565 return "as_type<uint" post ">"; \
566 case TBasicType::EbtFloat: \
567 return "as_type<float" post ">"; \
568 default: \
569 UNIMPLEMENTED(); \
570 return "TOperator_TODO"; \
571 } \
572 while (false)
573
574 if (resultType.isScalar())
575 {
576 RETURN_AS_TYPE("");
577 }
578 else if (resultType.isVector())
579 {
580 switch (resultType.getNominalSize())
581 {
582 case 2:
583 RETURN_AS_TYPE("2");
584 case 3:
585 RETURN_AS_TYPE("3");
586 case 4:
587 RETURN_AS_TYPE("4");
588 default:
589 UNREACHABLE();
590 return nullptr;
591 }
592 }
593 else
594 {
595 UNIMPLEMENTED();
596 return "TOperator_TODO";
597 }
598
599 #undef RETURN_AS_TYPE
600 }
601
602 case TOperator::EOpPackUnorm2x16:
603 return "metal::pack_float_to_unorm2x16";
604 case TOperator::EOpPackSnorm2x16:
605 return "metal::pack_float_to_snorm2x16";
606
607 case TOperator::EOpPackUnorm4x8:
608 return "metal::pack_float_to_unorm4x8";
609 case TOperator::EOpPackSnorm4x8:
610 return "metal::pack_float_to_snorm4x8";
611
612 case TOperator::EOpUnpackUnorm2x16:
613 return "metal::unpack_unorm2x16_to_float";
614 case TOperator::EOpUnpackSnorm2x16:
615 return "metal::unpack_snorm2x16_to_float";
616
617 case TOperator::EOpUnpackUnorm4x8:
618 return "metal::unpack_unorm4x8_to_float";
619 case TOperator::EOpUnpackSnorm4x8:
620 return "metal::unpack_snorm4x8_to_float";
621
622 case TOperator::EOpPackHalf2x16:
623 return "ANGLE_pack_half_2x16";
624 case TOperator::EOpUnpackHalf2x16:
625 return "ANGLE_unpack_half_2x16";
626
627 case TOperator::EOpBitfieldExtract:
628 case TOperator::EOpBitfieldInsert:
629 case TOperator::EOpBitfieldReverse:
630 case TOperator::EOpBitCount:
631 case TOperator::EOpFindLSB:
632 case TOperator::EOpFindMSB:
633 case TOperator::EOpUaddCarry:
634 case TOperator::EOpUsubBorrow:
635 case TOperator::EOpUmulExtended:
636 case TOperator::EOpImulExtended:
637 case TOperator::EOpBarrier:
638 case TOperator::EOpMemoryBarrier:
639 case TOperator::EOpMemoryBarrierAtomicCounter:
640 case TOperator::EOpMemoryBarrierBuffer:
641 case TOperator::EOpMemoryBarrierImage:
642 case TOperator::EOpMemoryBarrierShared:
643 case TOperator::EOpGroupMemoryBarrier:
644 case TOperator::EOpAtomicAdd:
645 case TOperator::EOpAtomicMin:
646 case TOperator::EOpAtomicMax:
647 case TOperator::EOpAtomicAnd:
648 case TOperator::EOpAtomicOr:
649 case TOperator::EOpAtomicXor:
650 case TOperator::EOpAtomicExchange:
651 case TOperator::EOpAtomicCompSwap:
652 case TOperator::EOpEmitVertex:
653 case TOperator::EOpEndPrimitive:
654 case TOperator::EOpFtransform:
655 case TOperator::EOpPackDouble2x32:
656 case TOperator::EOpUnpackDouble2x32:
657 case TOperator::EOpArrayLength:
658 UNIMPLEMENTED();
659 return "TOperator_TODO";
660
661 case TOperator::EOpNull:
662 case TOperator::EOpConstruct:
663 case TOperator::EOpCallFunctionInAST:
664 case TOperator::EOpCallInternalRawFunction:
665 case TOperator::EOpIndexDirect:
666 case TOperator::EOpIndexIndirect:
667 case TOperator::EOpIndexDirectStruct:
668 case TOperator::EOpIndexDirectInterfaceBlock:
669 UNREACHABLE();
670 return nullptr;
671 default:
672 // Any other built-in function.
673 return nullptr;
674 }
675 }
676
IsSymbolicOperator(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1)677 static bool IsSymbolicOperator(TOperator op,
678 const TType &resultType,
679 const TType *argType0,
680 const TType *argType1)
681 {
682 const char *operatorString = GetOperatorString(op, resultType, argType0, argType1, nullptr);
683 if (operatorString == nullptr)
684 {
685 return false;
686 }
687 return !std::isalnum(operatorString[0]);
688 }
689
AsSpecificBinaryNode(TIntermNode & node,TOperator op)690 static TIntermBinary *AsSpecificBinaryNode(TIntermNode &node, TOperator op)
691 {
692 TIntermBinary *binaryNode = node.getAsBinaryNode();
693 if (binaryNode)
694 {
695 return binaryNode->getOp() == op ? binaryNode : nullptr;
696 }
697 return nullptr;
698 }
699
Parenthesize(TIntermNode & node)700 static bool Parenthesize(TIntermNode &node)
701 {
702 if (node.getAsSymbolNode())
703 {
704 return false;
705 }
706 if (node.getAsConstantUnion())
707 {
708 return false;
709 }
710 if (node.getAsAggregate())
711 {
712 return false;
713 }
714 if (node.getAsSwizzleNode())
715 {
716 return false;
717 }
718
719 if (TIntermUnary *unaryNode = node.getAsUnaryNode())
720 {
721 // TODO: Use a precedence and associativity rules instead of this ad-hoc impl.
722 const TType &resultType = unaryNode->getType();
723 const TType &argType = unaryNode->getOperand()->getType();
724 return IsSymbolicOperator(unaryNode->getOp(), resultType, &argType, nullptr);
725 }
726
727 if (TIntermBinary *binaryNode = node.getAsBinaryNode())
728 {
729 // TODO: Use a precedence and associativity rules instead of this ad-hoc impl.
730 const TOperator op = binaryNode->getOp();
731 switch (op)
732 {
733 case TOperator::EOpIndexDirectStruct:
734 case TOperator::EOpIndexDirectInterfaceBlock:
735 case TOperator::EOpIndexDirect:
736 case TOperator::EOpIndexIndirect:
737 return Parenthesize(*binaryNode->getLeft());
738
739 case TOperator::EOpAssign:
740 case TOperator::EOpInitialize:
741 return AsSpecificBinaryNode(*binaryNode->getRight(), TOperator::EOpComma);
742
743 default:
744 {
745 const TType &resultType = binaryNode->getType();
746 const TType &leftType = binaryNode->getLeft()->getType();
747 const TType &rightType = binaryNode->getRight()->getType();
748 return IsSymbolicOperator(binaryNode->getOp(), resultType, &leftType, &rightType);
749 }
750 }
751 }
752
753 return true;
754 }
755
groupedTraverse(TIntermNode & node)756 void GenMetalTraverser::groupedTraverse(TIntermNode &node)
757 {
758 const bool emitParens = Parenthesize(node);
759
760 if (emitParens)
761 {
762 mOut << "(";
763 }
764
765 node.traverse(this);
766
767 if (emitParens)
768 {
769 mOut << ")";
770 }
771 }
772
emitPostQualifier(const EmitVariableDeclarationConfig & evdConfig,const VarDecl & decl,const TQualifier qualifier)773 void GenMetalTraverser::emitPostQualifier(const EmitVariableDeclarationConfig &evdConfig,
774 const VarDecl &decl,
775 const TQualifier qualifier)
776 {
777 bool isInvariant = false;
778 switch (qualifier)
779 {
780 case TQualifier::EvqPosition:
781 isInvariant = decl.type().isInvariant();
782 ANGLE_FALLTHROUGH;
783 case TQualifier::EvqFragCoord:
784 mOut << " [[position]]";
785 break;
786
787 case TQualifier::EvqPointSize:
788 mOut << " [[point_size]]";
789 break;
790
791 case TQualifier::EvqVertexID:
792 if (evdConfig.isMainParameter)
793 {
794 mOut << " [[vertex_id]]";
795 }
796 break;
797
798 case TQualifier::EvqPointCoord:
799 if (evdConfig.isMainParameter)
800 {
801 mOut << " [[point_coord]]";
802 }
803 break;
804
805 case TQualifier::EvqFrontFacing:
806 if (evdConfig.isMainParameter)
807 {
808 mOut << " [[front_facing]]";
809 }
810 break;
811
812 default:
813 break;
814 }
815
816 if (isInvariant)
817 {
818 mOut << " [[invariant]]";
819
820 TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
821 reflection->hasInvariance = true;
822 }
823 }
824
EmitName(Sink & out,const Name & name)825 static void EmitName(Sink &out, const Name &name)
826 {
827 #if defined(ANGLE_ENABLE_ASSERTS)
828 DebugSink::EscapedSink escapedOut(out.escape());
829 #else
830 TInfoSinkBase &escapedOut = out;
831 #endif
832 name.emit(escapedOut);
833 }
834
emitNameOf(const TField & object)835 void GenMetalTraverser::emitNameOf(const TField &object)
836 {
837 EmitName(mOut, Name(object));
838 }
839
emitNameOf(const TSymbol & object)840 void GenMetalTraverser::emitNameOf(const TSymbol &object)
841 {
842 auto it = mRenamedSymbols.find(&object);
843 if (it == mRenamedSymbols.end())
844 {
845 EmitName(mOut, Name(object));
846 }
847 else
848 {
849 EmitName(mOut, it->second);
850 }
851 }
852
emitNameOf(const VarDecl & object)853 void GenMetalTraverser::emitNameOf(const VarDecl &object)
854 {
855 if (object.isField())
856 {
857 emitNameOf(object.field());
858 }
859 else
860 {
861 emitNameOf(object.variable());
862 }
863 }
864
emitBareTypeName(const TType & type,const EmitTypeConfig & etConfig)865 void GenMetalTraverser::emitBareTypeName(const TType &type, const EmitTypeConfig &etConfig)
866 {
867 const TBasicType basicType = type.getBasicType();
868
869 switch (basicType)
870 {
871 case TBasicType::EbtVoid:
872 case TBasicType::EbtBool:
873 case TBasicType::EbtFloat:
874 case TBasicType::EbtInt:
875 case TBasicType::EbtUInt:
876 {
877 mOut << type.getBasicString();
878 }
879 break;
880
881 case TBasicType::EbtStruct:
882 {
883 const TStructure &structure = *type.getStruct();
884 emitNameOf(structure);
885 }
886 break;
887
888 case TBasicType::EbtInterfaceBlock:
889 {
890 const TInterfaceBlock &interfaceBlock = *type.getInterfaceBlock();
891 emitNameOf(interfaceBlock);
892 }
893 break;
894
895 default:
896 {
897 if (IsSampler(basicType))
898 {
899 if (etConfig.evdConfig && etConfig.evdConfig->isMainParameter)
900 {
901 EmitName(mOut, GetTextureTypeName(basicType));
902 }
903 else
904 {
905 const TStructure &env = mSymbolEnv.getTextureEnv(basicType);
906 emitNameOf(env);
907 }
908 }
909 else
910 {
911 UNIMPLEMENTED();
912 }
913 }
914 }
915 }
916
emitType(const TType & type,const EmitTypeConfig & etConfig)917 void GenMetalTraverser::emitType(const TType &type, const EmitTypeConfig &etConfig)
918 {
919 const bool isUBO = etConfig.evdConfig ? etConfig.evdConfig->isUBO : false;
920 if (etConfig.evdConfig)
921 {
922 const auto &evdConfig = *etConfig.evdConfig;
923 if (isUBO)
924 {
925 if (type.isArray())
926 {
927 mOut << "ANGLE_tensor<";
928 }
929 }
930 if (evdConfig.isPointer)
931 {
932 mOut << toString(*evdConfig.isPointer);
933 mOut << " ";
934 }
935 else if (evdConfig.isReference)
936 {
937 mOut << toString(*evdConfig.isReference);
938 mOut << " ";
939 }
940 }
941
942 if (!isUBO)
943 {
944 if (type.isArray())
945 {
946 mOut << "ANGLE_tensor<";
947 }
948 }
949
950 if (type.isVector() || type.isMatrix())
951 {
952 mOut << "metal::";
953 }
954
955 if (etConfig.evdConfig && etConfig.evdConfig->isPacked)
956 {
957 mOut << "packed_";
958 }
959
960 emitBareTypeName(type, etConfig);
961
962 if (type.isVector())
963 {
964 mOut << type.getNominalSize();
965 }
966 else if (type.isMatrix())
967 {
968 mOut << type.getCols() << "x" << type.getRows();
969 }
970
971 if (!isUBO)
972 {
973 if (type.isArray())
974 {
975 for (auto size : type.getArraySizes())
976 {
977 mOut << ", " << size;
978 }
979 mOut << ">";
980 }
981 }
982
983 if (etConfig.evdConfig)
984 {
985 const auto &evdConfig = *etConfig.evdConfig;
986 if (evdConfig.isPointer)
987 {
988 mOut << " *";
989 }
990 else if (evdConfig.isReference)
991 {
992 mOut << " &";
993 }
994 if (isUBO)
995 {
996 if (type.isArray())
997 {
998 for (auto size : type.getArraySizes())
999 {
1000 mOut << ", " << size;
1001 }
1002 mOut << ">";
1003 }
1004 }
1005 }
1006 }
1007
emitFieldDeclaration(const TField & field,const TStructure & parent,FieldAnnotationIndices & annotationIndices)1008 void GenMetalTraverser::emitFieldDeclaration(const TField &field,
1009 const TStructure &parent,
1010 FieldAnnotationIndices &annotationIndices)
1011 {
1012 const TType &type = *field.type();
1013 const TBasicType basic = type.getBasicType();
1014
1015 EmitVariableDeclarationConfig evdConfig;
1016 evdConfig.emitPostQualifier = true;
1017 evdConfig.disableStructSpecifier = true;
1018 evdConfig.isPacked = mSymbolEnv.isPacked(field);
1019 evdConfig.isUBO = mSymbolEnv.isUBO(field);
1020 evdConfig.isPointer = mSymbolEnv.isPointer(field);
1021 evdConfig.isReference = mSymbolEnv.isReference(field);
1022 emitVariableDeclaration(VarDecl(field), evdConfig);
1023
1024 const TQualifier qual = type.getQualifier();
1025 switch (qual)
1026 {
1027 case TQualifier::EvqFlatIn:
1028 if (mPipelineStructs.fragmentIn.external == &parent)
1029 {
1030 mOut << " [[flat]]";
1031 TranslatorMetalReflection *reflection =
1032 mtl::getTranslatorMetalReflection(&mCompiler);
1033 reflection->hasFlatInput = true;
1034 }
1035 break;
1036
1037 case TQualifier::EvqFragmentOut:
1038 case TQualifier::EvqFragData:
1039 if (mPipelineStructs.fragmentOut.external == &parent)
1040 {
1041 if ((type.isVector() &&
1042 (basic == TBasicType::EbtInt || basic == TBasicType::EbtUInt ||
1043 basic == TBasicType::EbtFloat)) ||
1044 type.getQualifier() == EvqFragData)
1045 {
1046 // The OpenGL ES 3.0 spec says locations must be specified
1047 // unless there is only a single output, in which case the
1048 // location is 0. So, when we get to this point the shader
1049 // will have been rejected if locations are not specified
1050 // and there is more than one output.
1051 const TLayoutQualifier &layoutQualifier = type.getLayoutQualifier();
1052 size_t index = layoutQualifier.locationsSpecified ? layoutQualifier.location
1053 : annotationIndices.color++;
1054 mOut << " [[color(" << index << ")]]";
1055 }
1056 }
1057 break;
1058
1059 case TQualifier::EvqFragDepth:
1060 mOut << " [[depth(any)]]";
1061 break;
1062
1063 case TQualifier::EvqSampleMask:
1064 mOut << " [[sample_mask, function_constant(" << sh::mtl::kCoverageMaskEnabledConstName
1065 << ")]]";
1066 break;
1067
1068 default:
1069 break;
1070 }
1071 }
1072
BuildExternalAttributeIndexMap(const TCompiler & compiler,const PipelineScoped<TStructure> & structure)1073 static std::map<Name, size_t> BuildExternalAttributeIndexMap(
1074 const TCompiler &compiler,
1075 const PipelineScoped<TStructure> &structure)
1076 {
1077 ASSERT(structure.isTotallyFull());
1078
1079 const auto &shaderVars = compiler.getAttributes();
1080 const size_t shaderVarSize = shaderVars.size();
1081 size_t shaderVarIndex = 0;
1082
1083 const auto &externalFields = structure.external->fields();
1084 const size_t externalSize = externalFields.size();
1085 size_t externalIndex = 0;
1086
1087 const auto &internalFields = structure.internal->fields();
1088 const size_t internalSize = internalFields.size();
1089 size_t internalIndex = 0;
1090
1091 // Internal fields are never split. External fields are sometimes split.
1092 ASSERT(externalSize >= internalSize);
1093
1094 // Structures do not contain any inactive fields.
1095 ASSERT(shaderVarSize >= internalSize);
1096
1097 std::map<Name, size_t> externalNameToAttributeIndex;
1098 size_t attributeIndex = 0;
1099
1100 while (internalIndex < internalSize)
1101 {
1102 const TField &internalField = *internalFields[internalIndex];
1103 const Name internalName = Name(internalField);
1104 const TType &internalType = *internalField.type();
1105 while (internalName.rawName() != shaderVars[shaderVarIndex].name &&
1106 internalName.rawName() != shaderVars[shaderVarIndex].mappedName)
1107 {
1108 // This case represents an inactive field.
1109
1110 ++shaderVarIndex;
1111 ASSERT(shaderVarIndex < shaderVarSize);
1112
1113 ++attributeIndex; // TODO: Might need to increment more if shader var type is a matrix.
1114 }
1115
1116 const size_t cols = internalType.isMatrix() ? internalType.getCols() : 1;
1117
1118 for (size_t c = 0; c < cols; ++c)
1119 {
1120 const TField &externalField = *externalFields[externalIndex];
1121 const Name externalName = Name(externalField);
1122 ASSERT(!externalField.type()->isMatrix());
1123
1124 externalNameToAttributeIndex[externalName] = attributeIndex;
1125
1126 ++externalIndex;
1127 ++attributeIndex;
1128 }
1129
1130 ++shaderVarIndex;
1131 ++internalIndex;
1132 }
1133
1134 ASSERT(shaderVarIndex <= shaderVarSize);
1135 ASSERT(externalIndex <= externalSize); // less than if padding was introduced
1136 ASSERT(internalIndex == internalSize);
1137
1138 return externalNameToAttributeIndex;
1139 }
1140
emitAttributeDeclaration(const TField & field,FieldAnnotationIndices & annotationIndices)1141 void GenMetalTraverser::emitAttributeDeclaration(const TField &field,
1142 FieldAnnotationIndices &annotationIndices)
1143 {
1144 EmitVariableDeclarationConfig evdConfig;
1145 evdConfig.disableStructSpecifier = true;
1146 emitVariableDeclaration(VarDecl(field), evdConfig);
1147 mOut << sh::kUnassignedAttributeString;
1148 }
1149
emitUniformBufferDeclaration(const TField & field,FieldAnnotationIndices & annotationIndices)1150 void GenMetalTraverser::emitUniformBufferDeclaration(const TField &field,
1151 FieldAnnotationIndices &annotationIndices)
1152 {
1153 EmitVariableDeclarationConfig evdConfig;
1154 evdConfig.disableStructSpecifier = true;
1155 evdConfig.isUBO = mSymbolEnv.isUBO(field);
1156 evdConfig.isPointer = mSymbolEnv.isPointer(field);
1157 emitVariableDeclaration(VarDecl(field), evdConfig);
1158 mOut << "[[id(" << annotationIndices.attribute << ")]]";
1159
1160 const TType &type = *field.type();
1161 const int arraySize = type.isArray() ? type.getArraySizeProduct() : 1;
1162
1163 TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
1164 ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1165 const TStructure *structure = type.getStruct();
1166 const std::string originalName = reflection->getOriginalName(structure->uniqueId().get());
1167 reflection->addUniformBufferBinding(
1168 originalName,
1169 {.bindIndex = annotationIndices.attribute, .arraySize = static_cast<size_t>(arraySize)});
1170
1171 annotationIndices.attribute += arraySize;
1172 }
1173
emitStructDeclaration(const TType & type)1174 void GenMetalTraverser::emitStructDeclaration(const TType &type)
1175 {
1176 ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1177 ASSERT(type.isStructSpecifier());
1178
1179 mOut << "struct ";
1180 emitBareTypeName(type, {});
1181
1182 mOut << "\n";
1183 emitOpenBrace();
1184
1185 const TStructure &structure = *type.getStruct();
1186 std::map<Name, size_t> fieldToAttributeIndex;
1187 const bool hasAttributeIndices = mPipelineStructs.vertexIn.external == &structure;
1188 const bool hasUniformBufferIndicies = mPipelineStructs.uniformBuffers.external == &structure;
1189 const bool reclaimUnusedAttributeIndices = mCompiler.getShaderVersion() < 300;
1190
1191 if (hasAttributeIndices)
1192 {
1193 fieldToAttributeIndex =
1194 BuildExternalAttributeIndexMap(mCompiler, mPipelineStructs.vertexIn);
1195 }
1196
1197 FieldAnnotationIndices annotationIndices;
1198
1199 for (const TField *field : structure.fields())
1200 {
1201 emitIndentation();
1202 if (hasAttributeIndices)
1203 {
1204 const auto it = fieldToAttributeIndex.find(Name(*field));
1205 if (it == fieldToAttributeIndex.end())
1206 {
1207 ASSERT(field->symbolType() == SymbolType::AngleInternal);
1208 ASSERT(field->name().beginsWith("_"));
1209 ASSERT(angle::EndsWith(field->name().data(), "_pad"));
1210 emitFieldDeclaration(*field, structure, annotationIndices);
1211 }
1212 else
1213 {
1214 ASSERT(field->symbolType() != SymbolType::AngleInternal ||
1215 !field->name().beginsWith("_") ||
1216 !angle::EndsWith(field->name().data(), "_pad"));
1217 if (!reclaimUnusedAttributeIndices)
1218 {
1219 annotationIndices.attribute = it->second;
1220 }
1221 emitAttributeDeclaration(*field, annotationIndices);
1222 }
1223 }
1224 else if (hasUniformBufferIndicies)
1225 {
1226 emitUniformBufferDeclaration(*field, annotationIndices);
1227 }
1228 else
1229 {
1230 emitFieldDeclaration(*field, structure, annotationIndices);
1231 }
1232 mOut << ";\n";
1233 }
1234
1235 if (!mPipelineStructs.matches(structure, true, true))
1236 {
1237 MetalLayoutOfConfig layoutConfig;
1238 layoutConfig.treatSamplersAsTextureEnv = true;
1239 Layout layout = MetalLayoutOf(type, layoutConfig);
1240 size_t pad = (kDefaultStructAlignmentSize - layout.sizeOf) % kDefaultStructAlignmentSize;
1241 if (pad != 0)
1242 {
1243 emitIndentation();
1244 mOut << "char ";
1245 EmitName(mOut, mIdGen.createNewName("pad"));
1246 mOut << "[" << pad << "];\n";
1247 }
1248 }
1249
1250 emitCloseBrace();
1251 }
1252
emitOrdinaryVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1253 void GenMetalTraverser::emitOrdinaryVariableDeclaration(
1254 const VarDecl &decl,
1255 const EmitVariableDeclarationConfig &evdConfig)
1256 {
1257 EmitTypeConfig etConfig;
1258 etConfig.evdConfig = &evdConfig;
1259
1260 const TType &type = decl.type();
1261 emitType(type, etConfig);
1262 if (decl.symbolType() != SymbolType::Empty)
1263 {
1264 mOut << " ";
1265 emitNameOf(decl);
1266 }
1267 }
1268
emitVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1269 void GenMetalTraverser::emitVariableDeclaration(const VarDecl &decl,
1270 const EmitVariableDeclarationConfig &evdConfig)
1271 {
1272 const SymbolType symbolType = decl.symbolType();
1273 const TType &type = decl.type();
1274 const TBasicType basicType = type.getBasicType();
1275
1276 switch (basicType)
1277 {
1278 case TBasicType::EbtStruct:
1279 {
1280 if (type.isStructSpecifier() && !evdConfig.disableStructSpecifier)
1281 {
1282 ASSERT(!evdConfig.isParameter);
1283 emitStructDeclaration(type);
1284 if (symbolType != SymbolType::Empty)
1285 {
1286 mOut << " ";
1287 emitNameOf(decl);
1288 }
1289 }
1290 else
1291 {
1292 emitOrdinaryVariableDeclaration(decl, evdConfig);
1293 }
1294 }
1295 break;
1296
1297 default:
1298 {
1299 ASSERT(symbolType != SymbolType::Empty || evdConfig.isParameter);
1300 emitOrdinaryVariableDeclaration(decl, evdConfig);
1301 }
1302 }
1303
1304 if (evdConfig.emitPostQualifier)
1305 {
1306 emitPostQualifier(evdConfig, decl, type.getQualifier());
1307 }
1308 }
1309
visitSymbol(TIntermSymbol * symbolNode)1310 void GenMetalTraverser::visitSymbol(TIntermSymbol *symbolNode)
1311 {
1312 const TVariable &var = symbolNode->variable();
1313 const TType &type = var.getType();
1314 ASSERT(var.symbolType() != SymbolType::Empty);
1315
1316 if (type.getBasicType() == TBasicType::EbtVoid)
1317 {
1318 mOut << "/*";
1319 emitNameOf(var);
1320 mOut << "*/";
1321 }
1322 else
1323 {
1324 emitNameOf(var);
1325 }
1326 }
1327
emitSingleConstant(const TConstantUnion * const constUnion)1328 void GenMetalTraverser::emitSingleConstant(const TConstantUnion *const constUnion)
1329 {
1330 switch (constUnion->getType())
1331 {
1332 case TBasicType::EbtBool:
1333 {
1334 mOut << (constUnion->getBConst() ? "true" : "false");
1335 }
1336 break;
1337
1338 case TBasicType::EbtFloat:
1339 {
1340 float value = constUnion->getFConst();
1341 if (std::isnan(value))
1342 {
1343 mOut << "NAN";
1344 }
1345 else if (std::isinf(value))
1346 {
1347 if (value < 0)
1348 {
1349 mOut << "-";
1350 }
1351 mOut << "INFINITY";
1352 }
1353 else
1354 {
1355 mOut << value << "f";
1356 }
1357 }
1358 break;
1359
1360 case TBasicType::EbtInt:
1361 {
1362 mOut << constUnion->getIConst();
1363 }
1364 break;
1365
1366 case TBasicType::EbtUInt:
1367 {
1368 mOut << constUnion->getUConst() << "u";
1369 }
1370 break;
1371
1372 default:
1373 {
1374 UNIMPLEMENTED();
1375 }
1376 }
1377 }
1378
// Emits `size` consecutive constants starting at `constUnion`, separated by ", ".
// Returns a pointer one past the last constant emitted.
const TConstantUnion *GenMetalTraverser::emitConstantUnionArray(
    const TConstantUnion *const constUnion,
    const size_t size)
{
    const TConstantUnion *cursor = constUnion;
    for (size_t emitted = 0; emitted < size; ++emitted)
    {
        if (emitted != 0)
        {
            mOut << ", ";
        }
        emitSingleConstant(cursor);
        ++cursor;
    }
    return cursor;
}
1395
// Emits the constant value of `type` read from `constUnionBegin`. Structs are written as
// "<type>{field, ...}" by recursing over their fields; other types are written as
// "<type>(c0, c1, ...)", or as the bare constant when the object size is not greater than one.
// Returns a pointer past the constants consumed.
const TConstantUnion *GenMetalTraverser::emitConstantUnion(const TType &type,
                                                           const TConstantUnion *constUnionBegin)
{
    const TConstantUnion *cursor = constUnionBegin;

    if (const TStructure *structure = type.getStruct())
    {
        emitType(type, EmitTypeConfig{nullptr});
        mOut << "{";

        const TFieldList &fields = structure->fields();
        const size_t fieldCount  = fields.size();
        for (size_t fieldIndex = 0; fieldIndex < fieldCount; ++fieldIndex)
        {
            if (fieldIndex != 0)
            {
                mOut << ", ";
            }
            cursor = emitConstantUnion(*fields[fieldIndex]->type(), cursor);
        }

        mOut << "}";
        return cursor;
    }

    const size_t componentCount = type.getObjectSize();
    if (componentCount > 1)
    {
        // Multi-component values are wrapped in a constructor call of their type.
        emitType(type, EmitTypeConfig{nullptr});
        mOut << "(";
        cursor = emitConstantUnionArray(cursor, componentCount);
        mOut << ")";
    }
    else
    {
        cursor = emitConstantUnionArray(cursor, componentCount);
    }
    return cursor;
}
1436
visitConstantUnion(TIntermConstantUnion * constValueNode)1437 void GenMetalTraverser::visitConstantUnion(TIntermConstantUnion *constValueNode)
1438 {
1439 emitConstantUnion(constValueNode->getType(), constValueNode->getConstantValue());
1440 }
1441
visitSwizzle(Visit,TIntermSwizzle * swizzleNode)1442 bool GenMetalTraverser::visitSwizzle(Visit, TIntermSwizzle *swizzleNode)
1443 {
1444 groupedTraverse(*swizzleNode->getOperand());
1445 mOut << ".";
1446
1447 {
1448 #if defined(ANGLE_ENABLE_ASSERTS)
1449 DebugSink::EscapedSink escapedOut(mOut.escape());
1450 TInfoSinkBase &out = escapedOut.get();
1451 #else
1452 TInfoSinkBase &out = mOut;
1453 #endif
1454 swizzleNode->writeOffsetsAsXYZW(&out);
1455 }
1456
1457 return false;
1458 }
1459
getDirectField(const TFieldListCollection & fieldListCollection,const TConstantUnion & index)1460 const TField &GenMetalTraverser::getDirectField(const TFieldListCollection &fieldListCollection,
1461 const TConstantUnion &index)
1462 {
1463 ASSERT(index.getType() == TBasicType::EbtInt);
1464
1465 const TFieldList &fieldList = fieldListCollection.fields();
1466 const int indexVal = index.getIConst();
1467 const TField &field = *fieldList[indexVal];
1468
1469 return field;
1470 }
1471
getDirectField(const TIntermTyped & fieldsNode,TIntermTyped & indexNode)1472 const TField &GenMetalTraverser::getDirectField(const TIntermTyped &fieldsNode,
1473 TIntermTyped &indexNode)
1474 {
1475 const TType &fieldsType = fieldsNode.getType();
1476
1477 const TFieldListCollection *fieldListCollection = fieldsType.getStruct();
1478 if (fieldListCollection == nullptr)
1479 {
1480 fieldListCollection = fieldsType.getInterfaceBlock();
1481 }
1482 ASSERT(fieldListCollection);
1483
1484 const TIntermConstantUnion *indexNode_ = indexNode.getAsConstantUnion();
1485 ASSERT(indexNode_);
1486 const TConstantUnion &index = *indexNode_->getConstantValue();
1487
1488 return getDirectField(*fieldListCollection, index);
1489 }
1490
visitBinary(Visit,TIntermBinary * binaryNode)1491 bool GenMetalTraverser::visitBinary(Visit, TIntermBinary *binaryNode)
1492 {
1493 const TOperator op = binaryNode->getOp();
1494 TIntermTyped &leftNode = *binaryNode->getLeft();
1495 TIntermTyped &rightNode = *binaryNode->getRight();
1496
1497 switch (op)
1498 {
1499 case TOperator::EOpIndexDirectStruct:
1500 case TOperator::EOpIndexDirectInterfaceBlock:
1501 {
1502 const TField &field = getDirectField(leftNode, rightNode);
1503 if (mSymbolEnv.isPointer(field) && mSymbolEnv.isUBO(field))
1504 {
1505 emitOpeningPointerParen();
1506 }
1507 groupedTraverse(leftNode);
1508 if (!mSymbolEnv.isPointer(field))
1509 {
1510 emitClosingPointerParen();
1511 }
1512 mOut << ".";
1513 emitNameOf(field);
1514 }
1515 break;
1516
1517 case TOperator::EOpIndexDirect:
1518 case TOperator::EOpIndexIndirect:
1519 {
1520 TType leftType = leftNode.getType();
1521 groupedTraverse(leftNode);
1522 mOut << "[";
1523 {
1524 mOut << "ANGLE_int_clamp(";
1525 groupedTraverse(rightNode);
1526 mOut << ", 0, ";
1527 if (leftType.isUnsizedArray())
1528 {
1529 groupedTraverse(leftNode);
1530 mOut << ".size()";
1531 }
1532 else
1533 {
1534 int maxSize;
1535 if (leftType.isArray())
1536 {
1537 maxSize = static_cast<int>(leftType.getOutermostArraySize()) - 1;
1538 }
1539 else
1540 {
1541 maxSize = leftType.getNominalSize() - 1;
1542 }
1543 mOut << maxSize;
1544 }
1545 mOut << ")";
1546 }
1547 mOut << "]";
1548 }
1549 break;
1550
1551 default:
1552 {
1553 const TType &resultType = binaryNode->getType();
1554 const TType &leftType = leftNode.getType();
1555 const TType &rightType = rightNode.getType();
1556
1557 if (IsSymbolicOperator(op, resultType, &leftType, &rightType))
1558 {
1559 groupedTraverse(leftNode);
1560 if (op != TOperator::EOpComma)
1561 {
1562 mOut << " ";
1563 }
1564 else
1565 {
1566 emitClosingPointerParen();
1567 }
1568 mOut << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << " ";
1569 groupedTraverse(rightNode);
1570 }
1571 else
1572 {
1573 emitClosingPointerParen();
1574 mOut << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << "(";
1575 leftNode.traverse(this);
1576 mOut << ", ";
1577 rightNode.traverse(this);
1578 mOut << ")";
1579 }
1580 }
1581 }
1582
1583 return false;
1584 }
1585
IsPostfix(TOperator op)1586 static bool IsPostfix(TOperator op)
1587 {
1588 switch (op)
1589 {
1590 case TOperator::EOpPostIncrement:
1591 case TOperator::EOpPostDecrement:
1592 return true;
1593
1594 default:
1595 return false;
1596 }
1597 }
1598
visitUnary(Visit,TIntermUnary * unaryNode)1599 bool GenMetalTraverser::visitUnary(Visit, TIntermUnary *unaryNode)
1600 {
1601 const TOperator op = unaryNode->getOp();
1602 const TType &resultType = unaryNode->getType();
1603
1604 TIntermTyped &arg = *unaryNode->getOperand();
1605 const TType &argType = arg.getType();
1606
1607 const char *name = GetOperatorString(op, resultType, &argType, nullptr, nullptr);
1608
1609 if (IsSymbolicOperator(op, resultType, &argType, nullptr))
1610 {
1611 const bool postfix = IsPostfix(op);
1612 if (!postfix)
1613 {
1614 mOut << name;
1615 }
1616 groupedTraverse(arg);
1617 if (postfix)
1618 {
1619 mOut << name;
1620 }
1621 }
1622 else
1623 {
1624 mOut << name << "(";
1625 arg.traverse(this);
1626 mOut << ")";
1627 }
1628
1629 return false;
1630 }
1631
visitTernary(Visit,TIntermTernary * conditionalNode)1632 bool GenMetalTraverser::visitTernary(Visit, TIntermTernary *conditionalNode)
1633 {
1634 groupedTraverse(*conditionalNode->getCondition());
1635 mOut << " ? ";
1636 groupedTraverse(*conditionalNode->getTrueExpression());
1637 mOut << " : ";
1638 groupedTraverse(*conditionalNode->getFalseExpression());
1639
1640 return false;
1641 }
1642
visitIfElse(Visit,TIntermIfElse * ifThenElseNode)1643 bool GenMetalTraverser::visitIfElse(Visit, TIntermIfElse *ifThenElseNode)
1644 {
1645 TIntermTyped &condNode = *ifThenElseNode->getCondition();
1646 TIntermBlock *thenNode = ifThenElseNode->getTrueBlock();
1647 TIntermBlock *elseNode = ifThenElseNode->getFalseBlock();
1648
1649 emitIndentation();
1650 mOut << "if (";
1651 condNode.traverse(this);
1652 mOut << ")";
1653
1654 if (thenNode)
1655 {
1656 mOut << "\n";
1657 thenNode->traverse(this);
1658 }
1659 else
1660 {
1661 mOut << " {}";
1662 }
1663
1664 if (elseNode)
1665 {
1666 mOut << "\n";
1667 emitIndentation();
1668 mOut << "else\n";
1669 elseNode->traverse(this);
1670 }
1671 else
1672 {
1673 // Always emit "else" even when empty block to avoid nested if-stmt issues.
1674 mOut << " else {}";
1675 }
1676
1677 return false;
1678 }
1679
visitSwitch(Visit,TIntermSwitch * switchNode)1680 bool GenMetalTraverser::visitSwitch(Visit, TIntermSwitch *switchNode)
1681 {
1682 emitIndentation();
1683 mOut << "switch (";
1684 switchNode->getInit()->traverse(this);
1685 mOut << ")\n";
1686
1687 ASSERT(!mParentIsSwitch);
1688 mParentIsSwitch = true;
1689 switchNode->getStatementList()->traverse(this);
1690 mParentIsSwitch = false;
1691
1692 return false;
1693 }
1694
visitCase(Visit,TIntermCase * caseNode)1695 bool GenMetalTraverser::visitCase(Visit, TIntermCase *caseNode)
1696 {
1697 emitIndentation();
1698
1699 if (caseNode->hasCondition())
1700 {
1701 TIntermTyped *condExpr = caseNode->getCondition();
1702 mOut << "case ";
1703 condExpr->traverse(this);
1704 mOut << ":";
1705 }
1706 else
1707 {
1708 mOut << "default:\n";
1709 }
1710
1711 return false;
1712 }
1713
emitFunctionSignature(const TFunction & func)1714 void GenMetalTraverser::emitFunctionSignature(const TFunction &func)
1715 {
1716 const bool isMain = func.isMain();
1717
1718 emitFunctionReturn(func);
1719
1720 mOut << " ";
1721 emitNameOf(func);
1722 if (isMain)
1723 {
1724 mOut << "0";
1725 }
1726 mOut << "(";
1727
1728 bool emitComma = false;
1729 const size_t paramCount = func.getParamCount();
1730 for (size_t i = 0; i < paramCount; ++i)
1731 {
1732 if (emitComma)
1733 {
1734 mOut << ", ";
1735 }
1736 emitComma = true;
1737
1738 const TVariable ¶m = *func.getParam(i);
1739 emitFunctionParameter(func, param);
1740 }
1741
1742 if (isTraversingVertexMain)
1743 {
1744 mOut << " @@XFB-Bindings@@ ";
1745 }
1746
1747 mOut << ")";
1748 }
1749
emitFunctionReturn(const TFunction & func)1750 void GenMetalTraverser::emitFunctionReturn(const TFunction &func)
1751 {
1752 const bool isMain = func.isMain();
1753 bool isVertexMain = false;
1754 const TType &returnType = func.getReturnType();
1755 if (isMain)
1756 {
1757 const TStructure *structure = returnType.getStruct();
1758 ASSERT(structure != nullptr);
1759 if (mPipelineStructs.fragmentOut.matches(*structure))
1760 {
1761 mOut << "fragment ";
1762 }
1763 else if (mPipelineStructs.vertexOut.matches(*structure))
1764 {
1765 mOut << "vertex __VERTEX_OUT(";
1766 isVertexMain = true;
1767 }
1768 else
1769 {
1770 UNIMPLEMENTED();
1771 }
1772 }
1773 emitType(returnType, EmitTypeConfig());
1774 if (isVertexMain)
1775 mOut << ") ";
1776 }
1777
emitFunctionParameter(const TFunction & func,const TVariable & param)1778 void GenMetalTraverser::emitFunctionParameter(const TFunction &func, const TVariable ¶m)
1779 {
1780 const bool isMain = func.isMain();
1781
1782 const TType &type = param.getType();
1783 const TStructure *structure = type.getStruct();
1784
1785 EmitVariableDeclarationConfig evdConfig;
1786 evdConfig.isParameter = true;
1787 evdConfig.isMainParameter = isMain;
1788 evdConfig.emitPostQualifier = isMain;
1789 evdConfig.isUBO = mSymbolEnv.isUBO(param);
1790 evdConfig.isPointer = mSymbolEnv.isPointer(param);
1791 evdConfig.isReference = mSymbolEnv.isReference(param);
1792 emitVariableDeclaration(VarDecl(param), evdConfig);
1793
1794 if (isMain)
1795 {
1796 TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
1797 if (structure)
1798 {
1799 if (mPipelineStructs.fragmentIn.matches(*structure) ||
1800 mPipelineStructs.vertexIn.matches(*structure))
1801 {
1802 mOut << " [[stage_in]]";
1803 }
1804 else if (mPipelineStructs.angleUniforms.matches(*structure))
1805 {
1806 mOut << " [[buffer(" << mDriverUniformsBindingIndex << ")]]";
1807 }
1808 else if (mPipelineStructs.uniformBuffers.matches(*structure))
1809 {
1810 mOut << " [[buffer(" << mUBOArgumentBufferBindingIndex << ")]]";
1811 reflection->hasUBOs = true;
1812 }
1813 else if (mPipelineStructs.userUniforms.matches(*structure))
1814 {
1815 mOut << " [[buffer(" << mMainUniformBufferIndex << ")]]";
1816 reflection->addUserUniformBufferBinding(param.name().data(),
1817 mMainUniformBufferIndex);
1818 mMainUniformBufferIndex += type.getArraySizeProduct();
1819 }
1820 else if (structure->name() == "metal::sampler")
1821 {
1822 mOut << " [[sampler(" << (mMainSamplerIndex) << ")]]";
1823 const std::string originalName =
1824 reflection->getOriginalName(param.uniqueId().get());
1825 reflection->addSamplerBinding(originalName, mMainSamplerIndex);
1826 mMainSamplerIndex += type.getArraySizeProduct();
1827 }
1828 }
1829 else if (IsSampler(type.getBasicType()))
1830 {
1831 mOut << " [[texture(" << (mMainTextureIndex) << ")]]";
1832 const std::string originalName = reflection->getOriginalName(param.uniqueId().get());
1833 reflection->addTextureBinding(originalName, mMainSamplerIndex);
1834 mMainTextureIndex += type.getArraySizeProduct();
1835 }
1836 else if (Name(param) == Pipeline{Pipeline::Type::InstanceId, nullptr}.getStructInstanceName(
1837 Pipeline::Variant::Modified))
1838 {
1839 mOut << " [[instance_id]]";
1840 }
1841 }
1842 }
1843
visitFunctionPrototype(TIntermFunctionPrototype * funcProtoNode)1844 void GenMetalTraverser::visitFunctionPrototype(TIntermFunctionPrototype *funcProtoNode)
1845 {
1846 const TFunction &func = *funcProtoNode->getFunction();
1847
1848 emitIndentation();
1849 emitFunctionSignature(func);
1850 }
1851
visitFunctionDefinition(Visit,TIntermFunctionDefinition * funcDefNode)1852 bool GenMetalTraverser::visitFunctionDefinition(Visit, TIntermFunctionDefinition *funcDefNode)
1853 {
1854 const TFunction &func = *funcDefNode->getFunction();
1855 TIntermBlock &body = *funcDefNode->getBody();
1856 if (func.isMain())
1857 {
1858 const TType &returnType = func.getReturnType();
1859 const TStructure *structure = returnType.getStruct();
1860 isTraversingVertexMain = (mPipelineStructs.vertexOut.matches(*structure));
1861 }
1862 emitIndentation();
1863 emitFunctionSignature(func);
1864 mOut << "\n";
1865 body.traverse(this);
1866 if (isTraversingVertexMain)
1867 {
1868 isTraversingVertexMain = false;
1869 }
1870 return false;
1871 }
1872
BuildFuncToName()1873 GenMetalTraverser::FuncToName GenMetalTraverser::BuildFuncToName()
1874 {
1875 FuncToName map;
1876
1877 auto putAngle = [&](const char *nameStr) {
1878 const ImmutableString name(nameStr);
1879 ASSERT(map.find(name) == map.end());
1880 map[name] = Name(nameStr, SymbolType::AngleInternal);
1881 };
1882
1883 putAngle("texelFetch");
1884 putAngle("texelFetchOffset");
1885 putAngle("texture");
1886 putAngle("texture1D");
1887 putAngle("texture1DLod");
1888 putAngle("texture1DProjLod");
1889 putAngle("texture2D");
1890 putAngle("texture2DLod");
1891 putAngle("texture2DProj");
1892 putAngle("texture2DRect");
1893 putAngle("texture2DProjLod");
1894 putAngle("texture2DRectProj");
1895 putAngle("texture3D");
1896 putAngle("texture3DLod");
1897 putAngle("texture3DProjLod");
1898 putAngle("textureCube");
1899 putAngle("textureCubeLod");
1900 putAngle("textureCubeProjLod");
1901 putAngle("textureGrad");
1902 putAngle("textureGradOffset");
1903 putAngle("textureLod");
1904 putAngle("textureLodOffset");
1905 putAngle("textureOffset");
1906 putAngle("textureProj");
1907 putAngle("textureProjGrad");
1908 putAngle("textureProjGradOffset");
1909 putAngle("textureProjLod");
1910 putAngle("textureProjLodOffset");
1911 putAngle("textureProjOffset");
1912 putAngle("textureSize");
1913
1914 return map;
1915 }
1916
visitAggregate(Visit,TIntermAggregate * aggregateNode)1917 bool GenMetalTraverser::visitAggregate(Visit, TIntermAggregate *aggregateNode)
1918 {
1919 const TIntermSequence &args = *aggregateNode->getSequence();
1920
1921 auto emitArgList = [&](const char *open, const char *close) {
1922 mOut << open;
1923
1924 bool emitComma = false;
1925 for (TIntermNode *arg : args)
1926 {
1927 if (emitComma)
1928 {
1929 emitClosingPointerParen();
1930 mOut << ", ";
1931 }
1932 emitComma = true;
1933 arg->traverse(this);
1934 }
1935
1936 mOut << close;
1937 };
1938
1939 const TType &retType = aggregateNode->getType();
1940
1941 if (aggregateNode->isConstructor())
1942 {
1943 const bool isStandalone = getParentNode()->getAsBlock();
1944 if (isStandalone)
1945 {
1946 // Prevent constructor from being interpreted as a declaration by wrapping in parens.
1947 // This can happen if given something like:
1948 // int(symbol); // <- This will be treated like `int symbol;`... don't want that.
1949 // So instead emit:
1950 // (int(symbol));
1951 mOut << "(";
1952 }
1953
1954 const EmitTypeConfig etConfig;
1955
1956 if (retType.isArray())
1957 {
1958 emitType(retType, etConfig);
1959 emitArgList("{", "}");
1960 }
1961 else if (retType.getStruct())
1962 {
1963 emitType(retType, etConfig);
1964 emitArgList("{", "}");
1965 }
1966 else
1967 {
1968 emitType(retType, etConfig);
1969 emitArgList("(", ")");
1970 }
1971
1972 if (isStandalone)
1973 {
1974 mOut << ")";
1975 }
1976
1977 return false;
1978 }
1979 else
1980 {
1981 const TOperator op = aggregateNode->getOp();
1982 if (op == EOpAtan)
1983 {
1984 TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
1985 reflection->hasAtan = true;
1986 }
1987 switch (op)
1988 {
1989 case TOperator::EOpCallFunctionInAST:
1990 case TOperator::EOpCallInternalRawFunction:
1991 {
1992 const TFunction &func = *aggregateNode->getFunction();
1993 emitNameOf(func);
1994 //'@' symbol in name specifices a macro substitution marker.
1995 if (!func.name().contains("@"))
1996 {
1997 emitArgList("(", ")");
1998 }
1999 else
2000 {
2001 mTemporarilyDisableSemicolon =
2002 true; // Disable semicolon for macro substitution.
2003 }
2004 return false;
2005 }
2006
2007 default:
2008 {
2009 auto getArgType = [&](size_t index) -> const TType * {
2010 if (index < args.size())
2011 {
2012 TIntermTyped *arg = args[index]->getAsTyped();
2013 ASSERT(arg);
2014 return &arg->getType();
2015 }
2016 return nullptr;
2017 };
2018
2019 ASSERT(!args.empty());
2020 const TType *argType0 = getArgType(0);
2021 const TType *argType1 = getArgType(1);
2022 const TType *argType2 = getArgType(2);
2023 ASSERT(argType0);
2024
2025 const char *opName = GetOperatorString(op, retType, argType0, argType1, argType2);
2026
2027 if (IsSymbolicOperator(op, retType, argType0, argType1))
2028 {
2029 switch (args.size())
2030 {
2031 case 1:
2032 {
2033 TIntermNode &operandNode = *aggregateNode->getChildNode(0);
2034 if (IsPostfix(op))
2035 {
2036 mOut << opName;
2037 groupedTraverse(operandNode);
2038 }
2039 else
2040 {
2041 groupedTraverse(operandNode);
2042 mOut << opName;
2043 }
2044 return false;
2045 }
2046
2047 case 2:
2048 {
2049 TIntermNode &leftNode = *aggregateNode->getChildNode(0);
2050 TIntermNode &rightNode = *aggregateNode->getChildNode(1);
2051 groupedTraverse(leftNode);
2052 mOut << " " << opName << " ";
2053 groupedTraverse(rightNode);
2054 return false;
2055 }
2056
2057 default:
2058 UNREACHABLE();
2059 return false;
2060 }
2061 }
2062 else if (opName == nullptr)
2063 {
2064 const TFunction &func = *aggregateNode->getFunction();
2065 auto it = mFuncToName.find(func.name());
2066 ASSERT(it != mFuncToName.end());
2067 EmitName(mOut, it->second);
2068 emitArgList("(", ")");
2069 return false;
2070 }
2071 else
2072 {
2073 mOut << opName;
2074 emitArgList("(", ")");
2075 return false;
2076 }
2077 }
2078 }
2079 }
2080 }
2081
emitOpenBrace()2082 void GenMetalTraverser::emitOpenBrace()
2083 {
2084 ASSERT(mIndentLevel >= 0);
2085
2086 emitIndentation();
2087 mOut << "{\n";
2088 ++mIndentLevel;
2089 }
2090
emitCloseBrace()2091 void GenMetalTraverser::emitCloseBrace()
2092 {
2093 ASSERT(mIndentLevel >= 1);
2094
2095 --mIndentLevel;
2096 emitIndentation();
2097 mOut << "}";
2098 }
2099
RequiresSemicolonTerminator(TIntermNode & node)2100 static bool RequiresSemicolonTerminator(TIntermNode &node)
2101 {
2102 if (node.getAsBlock())
2103 {
2104 return false;
2105 }
2106 if (node.getAsLoopNode())
2107 {
2108 return false;
2109 }
2110 if (node.getAsSwitchNode())
2111 {
2112 return false;
2113 }
2114 if (node.getAsIfElseNode())
2115 {
2116 return false;
2117 }
2118 if (node.getAsFunctionDefinition())
2119 {
2120 return false;
2121 }
2122 if (node.getAsCaseNode())
2123 {
2124 return false;
2125 }
2126
2127 return true;
2128 }
2129
NewlinePad(TIntermNode & node)2130 static bool NewlinePad(TIntermNode &node)
2131 {
2132 if (node.getAsFunctionDefinition())
2133 {
2134 return true;
2135 }
2136 if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
2137 {
2138 ASSERT(declNode->getChildCount() == 1);
2139 TIntermNode &childNode = *declNode->getChildNode(0);
2140 if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
2141 {
2142 const TVariable &var = symbolNode->variable();
2143 return var.getType().isStructSpecifier();
2144 }
2145 return false;
2146 }
2147 return false;
2148 }
2149
visitBlock(Visit,TIntermBlock * blockNode)2150 bool GenMetalTraverser::visitBlock(Visit, TIntermBlock *blockNode)
2151 {
2152 ASSERT(mIndentLevel >= -1);
2153 const bool isGlobalScope = mIndentLevel == -1;
2154 const bool parentIsSwitch = mParentIsSwitch;
2155 mParentIsSwitch = false;
2156
2157 if (isGlobalScope)
2158 {
2159 ++mIndentLevel;
2160 }
2161 else
2162 {
2163 emitOpenBrace();
2164 if (parentIsSwitch)
2165 {
2166 ++mIndentLevel;
2167 }
2168 }
2169
2170 TIntermNode *prevStmtNode = nullptr;
2171
2172 const size_t stmtCount = blockNode->getChildCount();
2173 for (size_t i = 0; i < stmtCount; ++i)
2174 {
2175 TIntermNode &stmtNode = *blockNode->getChildNode(i);
2176
2177 if (isGlobalScope && prevStmtNode && (NewlinePad(*prevStmtNode) || NewlinePad(stmtNode)))
2178 {
2179 mOut << "\n";
2180 }
2181 const bool isCase = stmtNode.getAsCaseNode();
2182 mIndentLevel -= isCase;
2183 emitIndentation();
2184 mIndentLevel += isCase;
2185 stmtNode.traverse(this);
2186 if (RequiresSemicolonTerminator(stmtNode) && !mTemporarilyDisableSemicolon)
2187 {
2188 mOut << ";";
2189 }
2190 mTemporarilyDisableSemicolon = false;
2191 mOut << "\n";
2192
2193 prevStmtNode = &stmtNode;
2194 }
2195
2196 if (isGlobalScope)
2197 {
2198 ASSERT(mIndentLevel == 0);
2199 --mIndentLevel;
2200 }
2201 else
2202 {
2203 if (parentIsSwitch)
2204 {
2205 ASSERT(mIndentLevel >= 1);
2206 --mIndentLevel;
2207 }
2208 emitCloseBrace();
2209 mParentIsSwitch = parentIsSwitch;
2210 }
2211
2212 return false;
2213 }
2214
visitGlobalQualifierDeclaration(Visit,TIntermGlobalQualifierDeclaration *)2215 bool GenMetalTraverser::visitGlobalQualifierDeclaration(Visit, TIntermGlobalQualifierDeclaration *)
2216 {
2217 return false;
2218 }
2219
visitDeclaration(Visit,TIntermDeclaration * declNode)2220 bool GenMetalTraverser::visitDeclaration(Visit, TIntermDeclaration *declNode)
2221 {
2222 ASSERT(declNode->getChildCount() == 1);
2223 TIntermNode &node = *declNode->getChildNode(0);
2224
2225 EmitVariableDeclarationConfig evdConfig;
2226
2227 if (TIntermSymbol *symbolNode = node.getAsSymbolNode())
2228 {
2229 const TVariable &var = symbolNode->variable();
2230 emitVariableDeclaration(VarDecl(var), evdConfig);
2231 }
2232 else if (TIntermBinary *initNode = node.getAsBinaryNode())
2233 {
2234 ASSERT(initNode->getOp() == TOperator::EOpInitialize);
2235 TIntermSymbol *leftSymbolNode = initNode->getLeft()->getAsSymbolNode();
2236 TIntermTyped *valueNode = initNode->getRight()->getAsTyped();
2237 ASSERT(leftSymbolNode && valueNode);
2238
2239 if (getRootNode() == getParentBlock())
2240 {
2241 // DeferGlobalInitializers should have turned non-const global initializers into
2242 // deferred initializers. Note that variables marked as EvqGlobal can be treated as
2243 // EvqConst in some ANGLE code but not actually have their qualifier actually changed to
2244 // EvqConst. Thus just assume all EvqGlobal are actually EvqConst for all code run after
2245 // DeferGlobalInitializers.
2246 mOut << "constant ";
2247 }
2248
2249 const TVariable &var = leftSymbolNode->variable();
2250 const Name varName(var);
2251
2252 if (ExpressionContainsName(varName, *valueNode))
2253 {
2254 mRenamedSymbols[&var] = mIdGen.createNewName(varName);
2255 }
2256
2257 emitVariableDeclaration(VarDecl(var), evdConfig);
2258 mOut << " = ";
2259 groupedTraverse(*valueNode);
2260 }
2261 else
2262 {
2263 UNREACHABLE();
2264 }
2265
2266 return false;
2267 }
2268
visitLoop(Visit,TIntermLoop * loopNode)2269 bool GenMetalTraverser::visitLoop(Visit, TIntermLoop *loopNode)
2270 {
2271 const TLoopType loopType = loopNode->getType();
2272
2273 switch (loopType)
2274 {
2275 case TLoopType::ELoopFor:
2276 return visitForLoop(loopNode);
2277 case TLoopType::ELoopWhile:
2278 return visitWhileLoop(loopNode);
2279 case TLoopType::ELoopDoWhile:
2280 return visitDoWhileLoop(loopNode);
2281 }
2282 }
2283
visitForLoop(TIntermLoop * loopNode)2284 bool GenMetalTraverser::visitForLoop(TIntermLoop *loopNode)
2285 {
2286 ASSERT(loopNode->getType() == TLoopType::ELoopFor);
2287
2288 TIntermNode *initNode = loopNode->getInit();
2289 TIntermTyped *condNode = loopNode->getCondition();
2290 TIntermTyped *exprNode = loopNode->getExpression();
2291 TIntermBlock *bodyNode = loopNode->getBody();
2292 ASSERT(bodyNode);
2293
2294 mOut << "for (";
2295
2296 if (initNode)
2297 {
2298 initNode->traverse(this);
2299 }
2300 else
2301 {
2302 mOut << " ";
2303 }
2304
2305 mOut << "; ";
2306
2307 if (condNode)
2308 {
2309 condNode->traverse(this);
2310 }
2311
2312 mOut << "; ";
2313
2314 if (exprNode)
2315 {
2316 exprNode->traverse(this);
2317 }
2318
2319 mOut << ")\n";
2320
2321 bodyNode->traverse(this);
2322
2323 return false;
2324 }
2325
visitWhileLoop(TIntermLoop * loopNode)2326 bool GenMetalTraverser::visitWhileLoop(TIntermLoop *loopNode)
2327 {
2328 ASSERT(loopNode->getType() == TLoopType::ELoopWhile);
2329
2330 TIntermNode *initNode = loopNode->getInit();
2331 TIntermTyped *condNode = loopNode->getCondition();
2332 TIntermTyped *exprNode = loopNode->getExpression();
2333 TIntermBlock *bodyNode = loopNode->getBody();
2334 ASSERT(condNode && bodyNode);
2335 ASSERT(!initNode && !exprNode);
2336
2337 emitIndentation();
2338 mOut << "while (";
2339 condNode->traverse(this);
2340 mOut << ")\n";
2341 bodyNode->traverse(this);
2342
2343 return false;
2344 }
2345
visitDoWhileLoop(TIntermLoop * loopNode)2346 bool GenMetalTraverser::visitDoWhileLoop(TIntermLoop *loopNode)
2347 {
2348 ASSERT(loopNode->getType() == TLoopType::ELoopDoWhile);
2349
2350 TIntermNode *initNode = loopNode->getInit();
2351 TIntermTyped *condNode = loopNode->getCondition();
2352 TIntermTyped *exprNode = loopNode->getExpression();
2353 TIntermBlock *bodyNode = loopNode->getBody();
2354 ASSERT(condNode && bodyNode);
2355 ASSERT(!initNode && !exprNode);
2356
2357 emitIndentation();
2358 mOut << "do\n";
2359 bodyNode->traverse(this);
2360 mOut << "\n";
2361 emitIndentation();
2362 mOut << "while (";
2363 condNode->traverse(this);
2364 mOut << ");";
2365
2366 return false;
2367 }
2368
visitBranch(Visit,TIntermBranch * branchNode)2369 bool GenMetalTraverser::visitBranch(Visit, TIntermBranch *branchNode)
2370 {
2371 const TOperator flowOp = branchNode->getFlowOp();
2372 TIntermTyped *exprNode = branchNode->getExpression();
2373
2374 emitIndentation();
2375
2376 switch (flowOp)
2377 {
2378 case TOperator::EOpKill:
2379 {
2380 ASSERT(exprNode == nullptr);
2381 mOut << "metal::discard_fragment()";
2382 }
2383 break;
2384
2385 case TOperator::EOpReturn:
2386 {
2387 if (isTraversingVertexMain)
2388 {
2389 mOut << "#if TRANSFORM_FEEDBACK_ENABLED\n";
2390 emitIndentation();
2391 mOut << "return;\n";
2392 emitIndentation();
2393 mOut << "#else\n";
2394 emitIndentation();
2395 }
2396 mOut << "return";
2397 if (exprNode)
2398 {
2399 mOut << " ";
2400 exprNode->traverse(this);
2401 mOut << ";";
2402 }
2403 if (isTraversingVertexMain)
2404 {
2405 mOut << "\n";
2406 emitIndentation();
2407 mOut << "#endif\n";
2408 mTemporarilyDisableSemicolon = true;
2409 }
2410 }
2411 break;
2412
2413 case TOperator::EOpBreak:
2414 {
2415 ASSERT(exprNode == nullptr);
2416 mOut << "break";
2417 }
2418 break;
2419
2420 case TOperator::EOpContinue:
2421 {
2422 ASSERT(exprNode == nullptr);
2423 mOut << "continue";
2424 }
2425 break;
2426
2427 default:
2428 {
2429 UNREACHABLE();
2430 }
2431 }
2432
2433 return false;
2434 }
2435
2436 static size_t emitMetalCallCount = 0;
2437
EmitMetal(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const PipelineStructs & pipelineStructs,SymbolEnv & symbolEnv,const ProgramPreludeConfig & ppc,TSymbolTable * symbolTable)2438 bool sh::EmitMetal(TCompiler &compiler,
2439 TIntermBlock &root,
2440 IdGen &idGen,
2441 const PipelineStructs &pipelineStructs,
2442 SymbolEnv &symbolEnv,
2443 const ProgramPreludeConfig &ppc,
2444 TSymbolTable *symbolTable)
2445 {
2446 TInfoSinkBase &out = compiler.getInfoSink().obj;
2447
2448 {
2449 ++emitMetalCallCount;
2450 std::string filenameProto = angle::GetEnvironmentVar("GMD_FIXED_EMIT");
2451 if (!filenameProto.empty())
2452 {
2453 if (filenameProto != "/dev/null")
2454 {
2455 auto tryOpen = [&](char const *ext) {
2456 auto filename = filenameProto;
2457 filename += std::to_string(emitMetalCallCount);
2458 filename += ".";
2459 filename += ext;
2460 return fopen(filename.c_str(), "rb");
2461 };
2462 FILE *file = tryOpen("metal");
2463 if (!file)
2464 {
2465 file = tryOpen("cpp");
2466 }
2467 ASSERT(file);
2468
2469 fseek(file, 0, SEEK_END);
2470 size_t fileSize = ftell(file);
2471 fseek(file, 0, SEEK_SET);
2472
2473 std::vector<char> buff;
2474 buff.resize(fileSize + 1);
2475 fread(buff.data(), fileSize, 1, file);
2476 buff.back() = '\0';
2477
2478 fclose(file);
2479
2480 out << buff.data();
2481 }
2482
2483 return true;
2484 }
2485 }
2486
2487 out << "\n\n";
2488
2489 if (!EmitProgramPrelude(root, out, ppc))
2490 {
2491 return false;
2492 }
2493
2494 {
2495 #if defined(ANGLE_ENABLE_ASSERTS)
2496 DebugSink outWrapper(out, angle::GetBoolEnvironmentVar("GMD_STDOUT"));
2497 outWrapper.watch(angle::GetEnvironmentVar("GMD_WATCH_STRING"));
2498 #else
2499 TInfoSinkBase &outWrapper = out;
2500 #endif
2501 GenMetalTraverser gen(compiler, outWrapper, idGen, pipelineStructs, symbolEnv, symbolTable);
2502 root.traverse(&gen);
2503 }
2504
2505 out << "\n";
2506
2507 return true;
2508 }
2509