1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
#include <cctype>
#include <map>
#include <unordered_map>

#include "common/system_utils.h"
#include "compiler/translator/ImmutableStringBuilder.h"
#include "compiler/translator/SymbolTable.h"
#include "compiler/translator/TranslatorMetalDirect.h"
#include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
#include "compiler/translator/TranslatorMetalDirect/DebugSink.h"
#include "compiler/translator/TranslatorMetalDirect/EmitMetal.h"
#include "compiler/translator/TranslatorMetalDirect/Layout.h"
#include "compiler/translator/TranslatorMetalDirect/Name.h"
#include "compiler/translator/TranslatorMetalDirect/ProgramPrelude.h"
#include "compiler/translator/TranslatorMetalDirect/RewritePipelines.h"
#include "compiler/translator/tree_util/IntermTraverse.h"
22
23 using namespace sh;
24
25 ////////////////////////////////////////////////////////////////////////////////
26
// Output sink type used throughout this file. Builds that define
// ANGLE_ENABLE_ASSERTS write through DebugSink; all other builds write
// directly to TInfoSinkBase.
#if defined(ANGLE_ENABLE_ASSERTS)
using Sink = DebugSink;
#else
using Sink = TInfoSinkBase;
#endif
32
33 ////////////////////////////////////////////////////////////////////////////////
34
35 namespace
36 {
37
38 struct VarDecl
39 {
VarDecl__anon16f098020111::VarDecl40 explicit VarDecl(const TVariable &var) : mVariable(&var), mIsField(false) {}
VarDecl__anon16f098020111::VarDecl41 explicit VarDecl(const TField &field) : mField(&field), mIsField(true) {}
42
variable__anon16f098020111::VarDecl43 ANGLE_INLINE const TVariable &variable() const
44 {
45 ASSERT(isVariable());
46 return *mVariable;
47 }
48
field__anon16f098020111::VarDecl49 ANGLE_INLINE const TField &field() const
50 {
51 ASSERT(isField());
52 return *mField;
53 }
54
isVariable__anon16f098020111::VarDecl55 ANGLE_INLINE bool isVariable() const { return !mIsField; }
56
isField__anon16f098020111::VarDecl57 ANGLE_INLINE bool isField() const { return mIsField; }
58
type__anon16f098020111::VarDecl59 const TType &type() const { return isField() ? *field().type() : variable().getType(); }
60
symbolType__anon16f098020111::VarDecl61 SymbolType symbolType() const
62 {
63 return isField() ? field().symbolType() : variable().symbolType();
64 }
65
66 private:
67 union
68 {
69 const TVariable *mVariable;
70 const TField *mField;
71 };
72 bool mIsField;
73 };
74
75 class GenMetalTraverser : public TIntermTraverser
76 {
77 public:
78 ~GenMetalTraverser() override;
79
80 GenMetalTraverser(const TCompiler &compiler,
81 Sink &out,
82 IdGen &idGen,
83 const PipelineStructs &pipelineStructs,
84 const Invariants &invariants,
85 SymbolEnv &symbolEnv,
86 TSymbolTable *symbolTable);
87
88 void visitSymbol(TIntermSymbol *) override;
89 void visitConstantUnion(TIntermConstantUnion *) override;
90 bool visitSwizzle(Visit, TIntermSwizzle *) override;
91 bool visitBinary(Visit, TIntermBinary *) override;
92 bool visitUnary(Visit, TIntermUnary *) override;
93 bool visitTernary(Visit, TIntermTernary *) override;
94 bool visitIfElse(Visit, TIntermIfElse *) override;
95 bool visitSwitch(Visit, TIntermSwitch *) override;
96 bool visitCase(Visit, TIntermCase *) override;
97 void visitFunctionPrototype(TIntermFunctionPrototype *) override;
98 bool visitFunctionDefinition(Visit, TIntermFunctionDefinition *) override;
99 bool visitAggregate(Visit, TIntermAggregate *) override;
100 bool visitBlock(Visit, TIntermBlock *) override;
101 bool visitGlobalQualifierDeclaration(Visit, TIntermGlobalQualifierDeclaration *) override;
102 bool visitDeclaration(Visit, TIntermDeclaration *) override;
103 bool visitLoop(Visit, TIntermLoop *) override;
104 bool visitForLoop(TIntermLoop *);
105 bool visitWhileLoop(TIntermLoop *);
106 bool visitDoWhileLoop(TIntermLoop *);
107 bool visitBranch(Visit, TIntermBranch *) override;
108
109 private:
110 using FuncToName = std::map<ImmutableString, Name>;
111 static FuncToName BuildFuncToName();
112
113 struct EmitVariableDeclarationConfig
114 {
115 bool isParameter = false;
116 bool isMainParameter = false;
117 bool emitPostQualifier = false;
118 bool isPacked = false;
119 bool disableStructSpecifier = false;
120 bool isUBO = false;
121 const AddressSpace *isPointer = nullptr;
122 const AddressSpace *isReference = nullptr;
123 };
124
125 struct EmitTypeConfig
126 {
127 const EmitVariableDeclarationConfig *evdConfig = nullptr;
128 };
129
130 void emitIndentation();
131 void emitOpeningPointerParen();
132 void emitClosingPointerParen();
133 void emitFunctionSignature(const TFunction &func);
134 void emitFunctionReturn(const TFunction &func);
135 void emitFunctionParameter(const TFunction &func, const TVariable ¶m);
136
137 void emitNameOf(const TField &object);
138 void emitNameOf(const TSymbol &object);
139 void emitNameOf(const VarDecl &object);
140
141 void emitBareTypeName(const TType &type, const EmitTypeConfig &etConfig);
142 void emitType(const TType &type, const EmitTypeConfig &etConfig);
143 void emitPostQualifier(const EmitVariableDeclarationConfig &evdConfig,
144 const VarDecl &decl,
145 const TQualifier qualifier);
146
147 struct FieldAnnotationIndices
148 {
149 size_t attribute = 0;
150 size_t color = 0;
151 };
152
153 void emitFieldDeclaration(const TField &field,
154 const TStructure &parent,
155 FieldAnnotationIndices &annotationIndices);
156 void emitAttributeDeclaration(const TField &field, FieldAnnotationIndices &annotationIndices);
157 void emitUniformBufferDeclaration(const TField &field,
158 FieldAnnotationIndices &annotationIndices);
159 void emitStructDeclaration(const TType &type);
160 void emitOrdinaryVariableDeclaration(const VarDecl &decl,
161 const EmitVariableDeclarationConfig &evdConfig);
162 void emitVariableDeclaration(const VarDecl &decl,
163 const EmitVariableDeclarationConfig &evdConfig);
164
165 void emitOpenBrace();
166 void emitCloseBrace();
167
168 void groupedTraverse(TIntermNode &node);
169
170 const TField &getDirectField(const TFieldListCollection &fieldsNode,
171 const TConstantUnion &index);
172 const TField &getDirectField(const TIntermTyped &fieldsNode, TIntermTyped &indexNode);
173
174 const TConstantUnion *emitConstantUnionArray(const TConstantUnion *const constUnion,
175 const size_t size);
176
177 const TConstantUnion *emitConstantUnion(const TType &type, const TConstantUnion *constUnion);
178
179 void emitSingleConstant(const TConstantUnion *const constUnion);
180
181 private:
182 Sink &mOut;
183 const TCompiler &mCompiler;
184 const PipelineStructs &mPipelineStructs;
185 const Invariants &mInvariants;
186 SymbolEnv &mSymbolEnv;
187 IdGen &mIdGen;
188 int mIndentLevel = -1;
189 int mLastIndentationPos = -1;
190 int mOpenPointerParenCount = 0;
191 bool mParentIsSwitch = false;
192 bool isTraversingVertexMain = false;
193 bool mTemporarilyDisableSemicolon = false;
194 std::unordered_map<const TSymbol *, Name> mRenamedSymbols;
195 const FuncToName mFuncToName = BuildFuncToName();
196 size_t mMainTextureIndex = 0;
197 size_t mMainSamplerIndex = 0;
198 size_t mMainUniformBufferIndex = 0;
199 size_t mDriverUniformsBindingIndex = 0;
200 size_t mUBOArgumentBufferBindingIndex = 0;
201 };
202
203 } // anonymous namespace
204
~GenMetalTraverser()205 GenMetalTraverser::~GenMetalTraverser()
206 {
207 ASSERT(mIndentLevel == -1);
208 ASSERT(!mParentIsSwitch);
209 ASSERT(mOpenPointerParenCount == 0);
210 }
211
// Builds a traverser that visits nodes in pre-order only (the three flags
// passed to TIntermTraverser are preVisit=true, inVisit=false,
// postVisit=false). References are stored as-is; the initial binding indices
// are read from `symbolTable`, which is dereferenced unconditionally and must
// therefore be non-null.
GenMetalTraverser::GenMetalTraverser(const TCompiler &compiler,
                                     Sink &out,
                                     IdGen &idGen,
                                     const PipelineStructs &pipelineStructs,
                                     const Invariants &invariants,
                                     SymbolEnv &symbolEnv,
                                     TSymbolTable *symbolTable)
    : TIntermTraverser(true, false, false),
      mOut(out),
      mCompiler(compiler),
      mPipelineStructs(pipelineStructs),
      mInvariants(invariants),
      mSymbolEnv(symbolEnv),
      mIdGen(idGen),
      mMainUniformBufferIndex(symbolTable->getDefaultUniformsBindingIndex()),
      mDriverUniformsBindingIndex(symbolTable->getDriverUniformsBindingIndex()),
      mUBOArgumentBufferBindingIndex(symbolTable->getUBOArgumentBufferBindingIndex())
{}
230
emitIndentation()231 void GenMetalTraverser::emitIndentation()
232 {
233 ASSERT(mIndentLevel >= 0);
234
235 if (mLastIndentationPos == mOut.size())
236 {
237 return; // Line is already indented.
238 }
239
240 for (int i = 0; i < mIndentLevel; ++i)
241 {
242 mOut << " ";
243 }
244
245 mLastIndentationPos = mOut.size();
246 }
247
emitOpeningPointerParen()248 void GenMetalTraverser::emitOpeningPointerParen()
249 {
250 mOut << "(*";
251 mOpenPointerParenCount++;
252 }
253
emitClosingPointerParen()254 void GenMetalTraverser::emitClosingPointerParen()
255 {
256 if (mOpenPointerParenCount > 0)
257 {
258 mOut << ")";
259 mOpenPointerParenCount--;
260 }
261 }
262
GetOperatorString(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1=nullptr)263 static const char *GetOperatorString(TOperator op,
264 const TType &resultType,
265 const TType *argType0,
266 const TType *argType1 = nullptr)
267 {
268 switch (op)
269 {
270 case TOperator::EOpComma:
271 return ",";
272 case TOperator::EOpAssign:
273 return "=";
274 case TOperator::EOpInitialize:
275 return "=";
276 case TOperator::EOpAddAssign:
277 return "+=";
278 case TOperator::EOpSubAssign:
279 return "-=";
280 case TOperator::EOpMulAssign:
281 return "*=";
282 case TOperator::EOpDivAssign:
283 return "/=";
284 case TOperator::EOpIModAssign:
285 return "%=";
286 case TOperator::EOpBitShiftLeftAssign:
287 return "<<="; // TODO: Check logical vs arithmetic shifting.
288 case TOperator::EOpBitShiftRightAssign:
289 return ">>="; // TODO: Check logical vs arithmetic shifting.
290 case TOperator::EOpBitwiseAndAssign:
291 return "&=";
292 case TOperator::EOpBitwiseXorAssign:
293 return "^=";
294 case TOperator::EOpBitwiseOrAssign:
295 return "|=";
296 case TOperator::EOpAdd:
297 return "+";
298 case TOperator::EOpSub:
299 return "-";
300 case TOperator::EOpMul:
301 return "*";
302 case TOperator::EOpDiv:
303 return "/";
304 case TOperator::EOpIMod:
305 return "%";
306 case TOperator::EOpBitShiftLeft:
307 return "<<"; // TODO: Check logical vs arithmetic shifting.
308 case TOperator::EOpBitShiftRight:
309 return ">>"; // TODO: Check logical vs arithmetic shifting.
310 case TOperator::EOpBitwiseAnd:
311 return "&";
312 case TOperator::EOpBitwiseXor:
313 return "^";
314 case TOperator::EOpBitwiseOr:
315 return "|";
316 case TOperator::EOpLessThan:
317 return "<";
318 case TOperator::EOpGreaterThan:
319 return ">";
320 case TOperator::EOpLessThanEqual:
321 return "<=";
322 case TOperator::EOpGreaterThanEqual:
323 return ">=";
324 case TOperator::EOpLessThanComponentWise:
325 return "<";
326 case TOperator::EOpLessThanEqualComponentWise:
327 return "<=";
328 case TOperator::EOpGreaterThanEqualComponentWise:
329 return ">=";
330 case TOperator::EOpGreaterThanComponentWise:
331 return ">";
332 case TOperator::EOpLogicalOr:
333 return "||";
334 case TOperator::EOpLogicalXor:
335 return "!=/*xor*/"; // XXX: This might need to be handled differently for some obtuse
336 // use case.
337 case TOperator::EOpLogicalAnd:
338 return "&&";
339 case TOperator::EOpNegative:
340 return "-";
341 case TOperator::EOpPositive:
342 if (argType0->isMatrix())
343 {
344 return "";
345 }
346 return "+";
347 case TOperator::EOpLogicalNot:
348 return "!";
349 case TOperator::EOpNotComponentWise:
350 return "!";
351 case TOperator::EOpBitwiseNot:
352 return "~";
353 case TOperator::EOpPostIncrement:
354 return "++";
355 case TOperator::EOpPostDecrement:
356 return "--";
357 case TOperator::EOpPreIncrement:
358 return "++";
359 case TOperator::EOpPreDecrement:
360 return "--";
361 case TOperator::EOpVectorTimesScalarAssign:
362 return "*=";
363 case TOperator::EOpVectorTimesMatrixAssign:
364 return "*=";
365 case TOperator::EOpMatrixTimesScalarAssign:
366 return "*=";
367 case TOperator::EOpMatrixTimesMatrixAssign:
368 return "*=";
369 case TOperator::EOpVectorTimesScalar:
370 return "*";
371 case TOperator::EOpVectorTimesMatrix:
372 return "*";
373 case TOperator::EOpMatrixTimesVector:
374 return "*";
375 case TOperator::EOpMatrixTimesScalar:
376 return "*";
377 case TOperator::EOpMatrixTimesMatrix:
378 return "*";
379 case TOperator::EOpEqualComponentWise:
380 return "==";
381 case TOperator::EOpNotEqualComponentWise:
382 return "!=";
383
384 case TOperator::EOpEqual:
385 if ((argType0->getStruct() && argType1->getStruct()) &&
386 (argType0->isArray() && argType1->isArray()))
387 {
388 return "ANGLE_equalStructArray";
389 }
390
391 if ((argType0->isVector() && argType1->isVector()) ||
392 (argType0->getStruct() && argType1->getStruct()) ||
393 (argType0->isArray() && argType1->isArray()) ||
394 (argType0->isMatrix() && argType1->isMatrix()))
395
396 {
397 return "ANGLE_equal";
398 }
399
400 return "==";
401
402 case TOperator::EOpNotEqual:
403 if ((argType0->getStruct() && argType1->getStruct()) &&
404 (argType0->isArray() && argType1->isArray()))
405 {
406 return "ANGLE_notEqualStructArray";
407 }
408
409 if ((argType0->isVector() && argType1->isVector()) ||
410 (argType0->isArray() && argType1->isArray()) ||
411 (argType0->isMatrix() && argType1->isMatrix()))
412 {
413 return "ANGLE_notEqual";
414 }
415 else if (argType0->getStruct() && argType1->getStruct())
416 {
417 return "ANGLE_notEqualStruct";
418 }
419 return "!=";
420
421 case TOperator::EOpKill:
422 UNIMPLEMENTED();
423 return "kill";
424 case TOperator::EOpReturn:
425 return "return";
426 case TOperator::EOpBreak:
427 return "break";
428 case TOperator::EOpContinue:
429 return "continue";
430
431 case TOperator::EOpRadians:
432 return "ANGLE_radians";
433 case TOperator::EOpDegrees:
434 return "ANGLE_degrees";
435 case TOperator::EOpAtan:
436 return "ANGLE_atan";
437 case TOperator::EOpMod:
438 return "ANGLE_mod"; // differs from metal::mod
439 case TOperator::EOpRefract:
440 return "ANGLE_refract";
441 case TOperator::EOpDistance:
442 return "ANGLE_distance";
443 case TOperator::EOpLength:
444 return "ANGLE_length";
445 case TOperator::EOpDot:
446 return "ANGLE_dot";
447 case TOperator::EOpNormalize:
448 return "ANGLE_normalize";
449 case TOperator::EOpFaceforward:
450 return "ANGLE_faceforward";
451 case TOperator::EOpReflect:
452 return "ANGLE_reflect";
453 case TOperator::EOpMatrixCompMult:
454 return "ANGLE_componentWiseMultiply";
455 case TOperator::EOpOuterProduct:
456 return "ANGLE_outerProduct";
457 case TOperator::EOpSign:
458 return "ANGLE_sign";
459
460 case TOperator::EOpAbs:
461 return "metal::abs";
462 case TOperator::EOpAll:
463 return "metal::all";
464 case TOperator::EOpAny:
465 return "metal::any";
466 case TOperator::EOpSin:
467 return "metal::sin";
468 case TOperator::EOpCos:
469 return "metal::cos";
470 case TOperator::EOpTan:
471 return "metal::tan";
472 case TOperator::EOpAsin:
473 return "metal::asin";
474 case TOperator::EOpAcos:
475 return "metal::acos";
476 case TOperator::EOpSinh:
477 return "metal::sinh";
478 case TOperator::EOpCosh:
479 return "metal::cosh";
480 case TOperator::EOpTanh:
481 return "metal::tanh";
482 case TOperator::EOpAsinh:
483 return "metal::asinh";
484 case TOperator::EOpAcosh:
485 return "metal::acosh";
486 case TOperator::EOpAtanh:
487 return "metal::atanh";
488 case TOperator::EOpFma:
489 return "metal::fma";
490 case TOperator::EOpPow:
491 return "metal::pow";
492 case TOperator::EOpExp:
493 return "metal::exp";
494 case TOperator::EOpExp2:
495 return "metal::exp2";
496 case TOperator::EOpLog:
497 return "metal::log";
498 case TOperator::EOpLog2:
499 return "metal::log2";
500 case TOperator::EOpSqrt:
501 return "metal::sqrt";
502 case TOperator::EOpFloor:
503 return "metal::floor";
504 case TOperator::EOpTrunc:
505 return "metal::trunc";
506 case TOperator::EOpCeil:
507 return "metal::ceil";
508 case TOperator::EOpFract:
509 return "metal::fract";
510 case TOperator::EOpMin:
511 return "metal::min";
512 case TOperator::EOpMax:
513 return "metal::max";
514 case TOperator::EOpRound:
515 return "metal::round";
516 case TOperator::EOpRoundEven:
517 return "metal::rint";
518 case TOperator::EOpClamp:
519 return "metal::clamp"; // TODO fast vs precise namespace
520 case TOperator::EOpMix:
521 return "metal::mix";
522 case TOperator::EOpStep:
523 return "metal::step";
524 case TOperator::EOpSmoothstep:
525 return "metal::smoothstep";
526 case TOperator::EOpModf:
527 return "metal::modf";
528 case TOperator::EOpIsnan:
529 return "metal::isnan";
530 case TOperator::EOpIsinf:
531 return "metal::isinf";
532 case TOperator::EOpLdexp:
533 return "metal::ldexp";
534 case TOperator::EOpFrexp:
535 return "metal::frexp";
536 case TOperator::EOpInversesqrt:
537 return "metal::rsqrt";
538 case TOperator::EOpCross:
539 return "metal::cross";
540 case TOperator::EOpDFdx:
541 return "metal::dfdx";
542 case TOperator::EOpDFdy:
543 return "metal::dfdy";
544 case TOperator::EOpFwidth:
545 return "metal::fwidth";
546 case TOperator::EOpTranspose:
547 return "metal::transpose";
548 case TOperator::EOpDeterminant:
549 return "metal::determinant";
550
551 case TOperator::EOpInverse:
552 return "ANGLE_inverse";
553
554 case TOperator::EOpFloatBitsToInt:
555 case TOperator::EOpFloatBitsToUint:
556 case TOperator::EOpIntBitsToFloat:
557 case TOperator::EOpUintBitsToFloat:
558 {
559 #define RETURN_AS_TYPE(post) \
560 do \
561 switch (resultType.getBasicType()) \
562 { \
563 case TBasicType::EbtInt: \
564 return "as_type<int" post ">"; \
565 case TBasicType::EbtUInt: \
566 return "as_type<uint" post ">"; \
567 case TBasicType::EbtFloat: \
568 return "as_type<float" post ">"; \
569 default: \
570 UNIMPLEMENTED(); \
571 return "TOperator_TODO"; \
572 } \
573 while (false)
574
575 if (resultType.isScalar())
576 {
577 RETURN_AS_TYPE("");
578 }
579 else if (resultType.isVector())
580 {
581 switch (resultType.getNominalSize())
582 {
583 case 2:
584 RETURN_AS_TYPE("2");
585 case 3:
586 RETURN_AS_TYPE("3");
587 case 4:
588 RETURN_AS_TYPE("4");
589 default:
590 UNREACHABLE();
591 return nullptr;
592 }
593 }
594 else
595 {
596 UNIMPLEMENTED();
597 return "TOperator_TODO";
598 }
599
600 #undef RETURN_AS_TYPE
601 }
602
603 case TOperator::EOpPackUnorm2x16:
604 return "metal::pack_float_to_unorm2x16";
605 case TOperator::EOpPackSnorm2x16:
606 return "metal::pack_float_to_snorm2x16";
607
608 case TOperator::EOpPackUnorm4x8:
609 return "metal::pack_float_to_unorm4x8";
610 case TOperator::EOpPackSnorm4x8:
611 return "metal::pack_float_to_snorm4x8";
612
613 case TOperator::EOpUnpackUnorm2x16:
614 return "metal::unpack_unorm2x16_to_float";
615 case TOperator::EOpUnpackSnorm2x16:
616 return "metal::unpack_snorm2x16_to_float";
617
618 case TOperator::EOpUnpackUnorm4x8:
619 return "metal::unpack_unorm4x8_to_float";
620 case TOperator::EOpUnpackSnorm4x8:
621 return "metal::unpack_snorm4x8_to_float";
622
623 case TOperator::EOpPackHalf2x16:
624 return "ANGLE_pack_half_2x16";
625 case TOperator::EOpUnpackHalf2x16:
626 return "ANGLE_unpack_half_2x16";
627
628 case TOperator::EOpBitfieldExtract:
629 case TOperator::EOpBitfieldInsert:
630 case TOperator::EOpBitfieldReverse:
631 case TOperator::EOpBitCount:
632 case TOperator::EOpFindLSB:
633 case TOperator::EOpFindMSB:
634 case TOperator::EOpUaddCarry:
635 case TOperator::EOpUsubBorrow:
636 case TOperator::EOpUmulExtended:
637 case TOperator::EOpImulExtended:
638 case TOperator::EOpBarrier:
639 case TOperator::EOpMemoryBarrier:
640 case TOperator::EOpMemoryBarrierAtomicCounter:
641 case TOperator::EOpMemoryBarrierBuffer:
642 case TOperator::EOpMemoryBarrierImage:
643 case TOperator::EOpMemoryBarrierShared:
644 case TOperator::EOpGroupMemoryBarrier:
645 case TOperator::EOpAtomicAdd:
646 case TOperator::EOpAtomicMin:
647 case TOperator::EOpAtomicMax:
648 case TOperator::EOpAtomicAnd:
649 case TOperator::EOpAtomicOr:
650 case TOperator::EOpAtomicXor:
651 case TOperator::EOpAtomicExchange:
652 case TOperator::EOpAtomicCompSwap:
653 case TOperator::EOpEmitVertex:
654 case TOperator::EOpEndPrimitive:
655 case TOperator::EOpFtransform:
656 case TOperator::EOpPackDouble2x32:
657 case TOperator::EOpUnpackDouble2x32:
658 case TOperator::EOpArrayLength:
659 UNIMPLEMENTED();
660 return "TOperator_TODO";
661
662 case TOperator::EOpNull:
663 case TOperator::EOpConstruct:
664 case TOperator::EOpCallFunctionInAST:
665 case TOperator::EOpCallInternalRawFunction:
666 case TOperator::EOpIndexDirect:
667 case TOperator::EOpIndexIndirect:
668 case TOperator::EOpIndexDirectStruct:
669 case TOperator::EOpIndexDirectInterfaceBlock:
670 UNREACHABLE();
671 return nullptr;
672 default:
673 // Any other built-in function.
674 return nullptr;
675 }
676 }
677
IsSymbolicOperator(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1=nullptr)678 static bool IsSymbolicOperator(TOperator op,
679 const TType &resultType,
680 const TType *argType0,
681 const TType *argType1 = nullptr)
682 {
683 const char *operatorString = GetOperatorString(op, resultType, argType0, argType1);
684 if (operatorString == nullptr)
685 {
686 return false;
687 }
688 return !std::isalnum(operatorString[0]);
689 }
690
AsSpecificBinaryNode(TIntermNode & node,TOperator op)691 static TIntermBinary *AsSpecificBinaryNode(TIntermNode &node, TOperator op)
692 {
693 TIntermBinary *binaryNode = node.getAsBinaryNode();
694 if (binaryNode)
695 {
696 return binaryNode->getOp() == op ? binaryNode : nullptr;
697 }
698 return nullptr;
699 }
700
Parenthesize(TIntermNode & node)701 static bool Parenthesize(TIntermNode &node)
702 {
703 if (node.getAsSymbolNode())
704 {
705 return false;
706 }
707 if (node.getAsConstantUnion())
708 {
709 return false;
710 }
711 if (node.getAsAggregate())
712 {
713 return false;
714 }
715 if (node.getAsSwizzleNode())
716 {
717 return false;
718 }
719
720 if (TIntermUnary *unaryNode = node.getAsUnaryNode())
721 {
722 // TODO: Use a precedence and associativity rules instead of this ad-hoc impl.
723 const TType &resultType = unaryNode->getType();
724 const TType &argType = unaryNode->getOperand()->getType();
725 return IsSymbolicOperator(unaryNode->getOp(), resultType, &argType);
726 }
727
728 if (TIntermBinary *binaryNode = node.getAsBinaryNode())
729 {
730 // TODO: Use a precedence and associativity rules instead of this ad-hoc impl.
731 const TOperator op = binaryNode->getOp();
732 switch (op)
733 {
734 case TOperator::EOpIndexDirectStruct:
735 case TOperator::EOpIndexDirectInterfaceBlock:
736 case TOperator::EOpIndexDirect:
737 case TOperator::EOpIndexIndirect:
738 return Parenthesize(*binaryNode->getLeft());
739
740 case TOperator::EOpAssign:
741 case TOperator::EOpInitialize:
742 return AsSpecificBinaryNode(*binaryNode->getRight(), TOperator::EOpComma);
743
744 default:
745 {
746 const TType &resultType = binaryNode->getType();
747 const TType &leftType = binaryNode->getLeft()->getType();
748 const TType &rightType = binaryNode->getRight()->getType();
749 return IsSymbolicOperator(binaryNode->getOp(), resultType, &leftType, &rightType);
750 }
751 }
752 }
753
754 return true;
755 }
756
groupedTraverse(TIntermNode & node)757 void GenMetalTraverser::groupedTraverse(TIntermNode &node)
758 {
759 const bool emitParens = Parenthesize(node);
760
761 if (emitParens)
762 {
763 mOut << "(";
764 }
765
766 node.traverse(this);
767
768 if (emitParens)
769 {
770 mOut << ")";
771 }
772 }
773
emitPostQualifier(const EmitVariableDeclarationConfig & evdConfig,const VarDecl & decl,const TQualifier qualifier)774 void GenMetalTraverser::emitPostQualifier(const EmitVariableDeclarationConfig &evdConfig,
775 const VarDecl &decl,
776 const TQualifier qualifier)
777 {
778 switch (qualifier)
779 {
780 case TQualifier::EvqPosition:
781 case TQualifier::EvqFragCoord:
782 mOut << " [[position]]";
783 break;
784
785 case TQualifier::EvqPointSize:
786 mOut << " [[point_size]]";
787 break;
788
789 case TQualifier::EvqVertexID:
790 if (evdConfig.isMainParameter)
791 {
792 mOut << " [[vertex_id]]";
793 }
794 break;
795
796 case TQualifier::EvqPointCoord:
797 if (evdConfig.isMainParameter)
798 {
799 mOut << " [[point_coord]]";
800 }
801 break;
802
803 case TQualifier::EvqFrontFacing:
804 if (evdConfig.isMainParameter)
805 {
806 mOut << " [[front_facing]]";
807 }
808 break;
809
810 default:
811 break;
812 }
813
814 const bool isInvariant =
815 (decl.isField() ? mInvariants.contains(decl.field())
816 : mInvariants.contains(decl.variable())) &&
817 (qualifier == TQualifier::EvqPosition || qualifier == TQualifier::EvqFragCoord);
818
819 if (isInvariant)
820 {
821 mOut << " [[invariant]]";
822 }
823 }
824
EmitName(Sink & out,const Name & name)825 static void EmitName(Sink &out, const Name &name)
826 {
827 #if defined(ANGLE_ENABLE_ASSERTS)
828 DebugSink::EscapedSink escapedOut(out.escape());
829 #else
830 TInfoSinkBase &escapedOut = out;
831 #endif
832 name.emit(escapedOut);
833 }
834
emitNameOf(const TField & object)835 void GenMetalTraverser::emitNameOf(const TField &object)
836 {
837 EmitName(mOut, Name(object));
838 }
839
emitNameOf(const TSymbol & object)840 void GenMetalTraverser::emitNameOf(const TSymbol &object)
841 {
842 auto it = mRenamedSymbols.find(&object);
843 if (it == mRenamedSymbols.end())
844 {
845 EmitName(mOut, Name(object));
846 }
847 else
848 {
849 EmitName(mOut, it->second);
850 }
851 }
852
emitNameOf(const VarDecl & object)853 void GenMetalTraverser::emitNameOf(const VarDecl &object)
854 {
855 if (object.isField())
856 {
857 emitNameOf(object.field());
858 }
859 else
860 {
861 emitNameOf(object.variable());
862 }
863 }
864
emitBareTypeName(const TType & type,const EmitTypeConfig & etConfig)865 void GenMetalTraverser::emitBareTypeName(const TType &type, const EmitTypeConfig &etConfig)
866 {
867 const TBasicType basicType = type.getBasicType();
868
869 switch (basicType)
870 {
871 case TBasicType::EbtVoid:
872 case TBasicType::EbtBool:
873 case TBasicType::EbtFloat:
874 case TBasicType::EbtInt:
875 case TBasicType::EbtUInt:
876 {
877 mOut << type.getBasicString();
878 }
879 break;
880
881 case TBasicType::EbtStruct:
882 {
883 const TStructure &structure = *type.getStruct();
884 emitNameOf(structure);
885 }
886 break;
887
888 case TBasicType::EbtInterfaceBlock:
889 {
890 const TInterfaceBlock &interfaceBlock = *type.getInterfaceBlock();
891 emitNameOf(interfaceBlock);
892 }
893 break;
894
895 default:
896 {
897 if (IsSampler(basicType))
898 {
899 if (etConfig.evdConfig && etConfig.evdConfig->isMainParameter)
900 {
901 EmitName(mOut, GetTextureTypeName(basicType));
902 }
903 else
904 {
905 const TStructure &env = mSymbolEnv.getTextureEnv(basicType);
906 emitNameOf(env);
907 }
908 }
909 else
910 {
911 UNIMPLEMENTED();
912 }
913 }
914 }
915 }
916
emitType(const TType & type,const EmitTypeConfig & etConfig)917 void GenMetalTraverser::emitType(const TType &type, const EmitTypeConfig &etConfig)
918 {
919 const bool isUBO = etConfig.evdConfig ? etConfig.evdConfig->isUBO : false;
920 if (etConfig.evdConfig)
921 {
922 const auto &evdConfig = *etConfig.evdConfig;
923 if (isUBO)
924 {
925 if (type.isArray())
926 {
927 mOut << "ANGLE_tensor<";
928 }
929 }
930 if (evdConfig.isPointer)
931 {
932 mOut << toString(*evdConfig.isPointer);
933 mOut << " ";
934 }
935 else if (evdConfig.isReference)
936 {
937 mOut << toString(*evdConfig.isReference);
938 mOut << " ";
939 }
940 }
941
942 if (!isUBO)
943 {
944 if (type.isArray())
945 {
946 mOut << "ANGLE_tensor<";
947 }
948 }
949
950 if (type.isVector() || type.isMatrix())
951 {
952 mOut << "metal::";
953 }
954
955 if (etConfig.evdConfig && etConfig.evdConfig->isPacked)
956 {
957 mOut << "packed_";
958 }
959
960 emitBareTypeName(type, etConfig);
961
962 if (type.isVector())
963 {
964 mOut << type.getNominalSize();
965 }
966 else if (type.isMatrix())
967 {
968 mOut << type.getCols() << "x" << type.getRows();
969 }
970
971 if (!isUBO)
972 {
973 if (type.isArray())
974 {
975 for (auto size : type.getArraySizes())
976 {
977 mOut << ", " << size;
978 }
979 mOut << ">";
980 }
981 }
982
983 if (etConfig.evdConfig)
984 {
985 const auto &evdConfig = *etConfig.evdConfig;
986 if (evdConfig.isPointer)
987 {
988 mOut << " *";
989 }
990 else if (evdConfig.isReference)
991 {
992 mOut << " &";
993 }
994 if (isUBO)
995 {
996 if (type.isArray())
997 {
998 for (auto size : type.getArraySizes())
999 {
1000 mOut << ", " << size;
1001 }
1002 mOut << ">";
1003 }
1004 }
1005 }
1006 }
1007
emitFieldDeclaration(const TField & field,const TStructure & parent,FieldAnnotationIndices & annotationIndices)1008 void GenMetalTraverser::emitFieldDeclaration(const TField &field,
1009 const TStructure &parent,
1010 FieldAnnotationIndices &annotationIndices)
1011 {
1012 const TType &type = *field.type();
1013 const TBasicType basic = type.getBasicType();
1014
1015 EmitVariableDeclarationConfig evdConfig;
1016 evdConfig.emitPostQualifier = true;
1017 evdConfig.disableStructSpecifier = true;
1018 evdConfig.isPacked = mSymbolEnv.isPacked(field);
1019 evdConfig.isUBO = mSymbolEnv.isUBO(field);
1020 evdConfig.isPointer = mSymbolEnv.isPointer(field);
1021 evdConfig.isReference = mSymbolEnv.isReference(field);
1022 emitVariableDeclaration(VarDecl(field), evdConfig);
1023
1024 const TQualifier qual = type.getQualifier();
1025 switch (qual)
1026 {
1027 case TQualifier::EvqFlatIn:
1028 if (mPipelineStructs.fragmentIn.external == &parent)
1029 {
1030 mOut << " [[flat]]";
1031 TranslatorMetalReflection *reflection =
1032 ((sh::TranslatorMetalDirect *)&mCompiler)->getTranslatorMetalReflection();
1033 reflection->hasFlatInput = true;
1034 }
1035 break;
1036
1037 case TQualifier::EvqFragmentOut:
1038 case TQualifier::EvqFragData:
1039 if (mPipelineStructs.fragmentOut.external == &parent)
1040 {
1041 if ((type.isVector() &&
1042 (basic == TBasicType::EbtInt || basic == TBasicType::EbtUInt ||
1043 basic == TBasicType::EbtFloat)) ||
1044 type.getQualifier() == EvqFragData)
1045 {
1046 // TODO:
1047 // This is not correct in general and needs a reimplementation.
1048 // In GLSL 3.0, We can't always assume this is going to be a safe index.
1049 // See 4.3.8.2
1050 // https://www.khronos.org/registry/OpenGL/specs/es/3.0/GLSL_ES_Specification_3.00.pdf
1051 size_t index = annotationIndices.color++;
1052 mOut << " [[color(" << index << ")]]";
1053 }
1054 }
1055 break;
1056
1057 case TQualifier::EvqFragDepth:
1058 mOut << " [[depth(any)]]";
1059 break;
1060
1061 case TQualifier::EvqSampleMask:
1062 mOut << " [[sample_mask, function_constant(" << sh::mtl::kCoverageMaskEnabledConstName
1063 << ")]]";
1064 break;
1065
1066 default:
1067 break;
1068 }
1069 }
1070
BuildExternalAttributeIndexMap(const TCompiler & compiler,const PipelineScoped<TStructure> & structure)1071 static std::map<Name, size_t> BuildExternalAttributeIndexMap(
1072 const TCompiler &compiler,
1073 const PipelineScoped<TStructure> &structure)
1074 {
1075 ASSERT(structure.isTotallyFull());
1076
1077 const auto &shaderVars = compiler.getAttributes();
1078 const size_t shaderVarSize = shaderVars.size();
1079 size_t shaderVarIndex = 0;
1080
1081 const auto &externalFields = structure.external->fields();
1082 const size_t externalSize = externalFields.size();
1083 size_t externalIndex = 0;
1084
1085 const auto &internalFields = structure.internal->fields();
1086 const size_t internalSize = internalFields.size();
1087 size_t internalIndex = 0;
1088
1089 // Internal fields are never split. External fields are sometimes split.
1090 ASSERT(externalSize >= internalSize);
1091
1092 // Structures do not contain any inactive fields.
1093 ASSERT(shaderVarSize >= internalSize);
1094
1095 std::map<Name, size_t> externalNameToAttributeIndex;
1096 size_t attributeIndex = 0;
1097
1098 while (internalIndex < internalSize)
1099 {
1100 const TField &internalField = *internalFields[internalIndex];
1101 const Name internalName = Name(internalField);
1102 const TType &internalType = *internalField.type();
1103 while (internalName.rawName() != shaderVars[shaderVarIndex].name &&
1104 internalName.rawName() != shaderVars[shaderVarIndex].mappedName)
1105 {
1106 // This case represents an inactive field.
1107
1108 ++shaderVarIndex;
1109 ASSERT(shaderVarIndex < shaderVarSize);
1110
1111 ++attributeIndex; // TODO: Might need to increment more if shader var type is a matrix.
1112 }
1113
1114 const size_t cols = internalType.isMatrix() ? internalType.getCols() : 1;
1115
1116 for (size_t c = 0; c < cols; ++c)
1117 {
1118 const TField &externalField = *externalFields[externalIndex];
1119 const Name externalName = Name(externalField);
1120 ASSERT(!externalField.type()->isMatrix());
1121
1122 externalNameToAttributeIndex[externalName] = attributeIndex;
1123
1124 ++externalIndex;
1125 ++attributeIndex;
1126 }
1127
1128 ++shaderVarIndex;
1129 ++internalIndex;
1130 }
1131
1132 ASSERT(shaderVarIndex <= shaderVarSize);
1133 ASSERT(externalIndex <= externalSize); // less than if padding was introduced
1134 ASSERT(internalIndex == internalSize);
1135
1136 return externalNameToAttributeIndex;
1137 }
1138
emitAttributeDeclaration(const TField & field,FieldAnnotationIndices & annotationIndices)1139 void GenMetalTraverser::emitAttributeDeclaration(const TField &field,
1140 FieldAnnotationIndices &annotationIndices)
1141 {
1142 EmitVariableDeclarationConfig evdConfig;
1143 evdConfig.disableStructSpecifier = true;
1144 emitVariableDeclaration(VarDecl(field), evdConfig);
1145 mOut << sh::kUnassignedAttributeString;
1146 }
1147
emitUniformBufferDeclaration(const TField & field,FieldAnnotationIndices & annotationIndices)1148 void GenMetalTraverser::emitUniformBufferDeclaration(const TField &field,
1149 FieldAnnotationIndices &annotationIndices)
1150 {
1151 EmitVariableDeclarationConfig evdConfig;
1152 evdConfig.disableStructSpecifier = true;
1153 evdConfig.isUBO = mSymbolEnv.isUBO(field);
1154 evdConfig.isPointer = mSymbolEnv.isPointer(field);
1155 emitVariableDeclaration(VarDecl(field), evdConfig);
1156 mOut << "[[id(" << annotationIndices.attribute << ")]]";
1157
1158 const TType &type = *field.type();
1159 const int arraySize = type.isArray() ? type.getArraySizeProduct() : 1;
1160
1161 TranslatorMetalReflection *reflection =
1162 ((sh::TranslatorMetalDirect *)&mCompiler)->getTranslatorMetalReflection();
1163 ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1164 const TStructure *structure = type.getStruct();
1165 const std::string originalName = reflection->getOriginalName(structure->uniqueId().get());
1166 reflection->addUniformBufferBinding(
1167 originalName,
1168 {.bindIndex = annotationIndices.attribute, .arraySize = static_cast<size_t>(arraySize)});
1169
1170 annotationIndices.attribute += arraySize;
1171 }
1172
emitStructDeclaration(const TType & type)1173 void GenMetalTraverser::emitStructDeclaration(const TType &type)
1174 {
1175 ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1176 ASSERT(type.isStructSpecifier());
1177
1178 mOut << "struct ";
1179 emitBareTypeName(type, {});
1180
1181 mOut << "\n";
1182 emitOpenBrace();
1183
1184 const TStructure &structure = *type.getStruct();
1185 std::map<Name, size_t> fieldToAttributeIndex;
1186 const bool hasAttributeIndices = mPipelineStructs.vertexIn.external == &structure;
1187 const bool hasUniformBufferIndicies = mPipelineStructs.uniformBuffers.external == &structure;
1188 const bool reclaimUnusedAttributeIndices = mCompiler.getShaderVersion() < 300;
1189
1190 if (hasAttributeIndices)
1191 {
1192 fieldToAttributeIndex =
1193 BuildExternalAttributeIndexMap(mCompiler, mPipelineStructs.vertexIn);
1194 }
1195
1196 FieldAnnotationIndices annotationIndices;
1197
1198 for (const TField *field : structure.fields())
1199 {
1200 emitIndentation();
1201 if (hasAttributeIndices)
1202 {
1203 const auto it = fieldToAttributeIndex.find(Name(*field));
1204 if (it == fieldToAttributeIndex.end())
1205 {
1206 ASSERT(field->symbolType() == SymbolType::AngleInternal);
1207 ASSERT(field->name().beginsWith("_"));
1208 ASSERT(angle::EndsWith(field->name().data(), "_pad"));
1209 emitFieldDeclaration(*field, structure, annotationIndices);
1210 }
1211 else
1212 {
1213 ASSERT(field->symbolType() != SymbolType::AngleInternal ||
1214 !field->name().beginsWith("_") ||
1215 !angle::EndsWith(field->name().data(), "_pad"));
1216 if (!reclaimUnusedAttributeIndices)
1217 {
1218 annotationIndices.attribute = it->second;
1219 }
1220 emitAttributeDeclaration(*field, annotationIndices);
1221 }
1222 }
1223 else if (hasUniformBufferIndicies)
1224 {
1225 emitUniformBufferDeclaration(*field, annotationIndices);
1226 }
1227 else
1228 {
1229 emitFieldDeclaration(*field, structure, annotationIndices);
1230 }
1231 mOut << ";\n";
1232 }
1233
1234 if (!mPipelineStructs.matches(structure, true, true))
1235 {
1236 MetalLayoutOfConfig layoutConfig;
1237 layoutConfig.treatSamplersAsTextureEnv = true;
1238 Layout layout = MetalLayoutOf(type, layoutConfig);
1239 size_t pad = (kDefaultStructAlignmentSize - layout.sizeOf) % kDefaultStructAlignmentSize;
1240 if (pad != 0)
1241 {
1242 emitIndentation();
1243 mOut << "char ";
1244 EmitName(mOut, mIdGen.createNewName("pad"));
1245 mOut << "[" << pad << "];\n";
1246 }
1247 }
1248
1249 emitCloseBrace();
1250 }
1251
emitOrdinaryVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1252 void GenMetalTraverser::emitOrdinaryVariableDeclaration(
1253 const VarDecl &decl,
1254 const EmitVariableDeclarationConfig &evdConfig)
1255 {
1256 EmitTypeConfig etConfig;
1257 etConfig.evdConfig = &evdConfig;
1258
1259 const TType &type = decl.type();
1260 emitType(type, etConfig);
1261 if (decl.symbolType() != SymbolType::Empty)
1262 {
1263 mOut << " ";
1264 emitNameOf(decl);
1265 }
1266 }
1267
emitVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1268 void GenMetalTraverser::emitVariableDeclaration(const VarDecl &decl,
1269 const EmitVariableDeclarationConfig &evdConfig)
1270 {
1271 const SymbolType symbolType = decl.symbolType();
1272 const TType &type = decl.type();
1273 const TBasicType basicType = type.getBasicType();
1274
1275 switch (basicType)
1276 {
1277 case TBasicType::EbtStruct:
1278 {
1279 if (type.isStructSpecifier() && !evdConfig.disableStructSpecifier)
1280 {
1281 ASSERT(!evdConfig.isParameter);
1282 emitStructDeclaration(type);
1283 if (symbolType != SymbolType::Empty)
1284 {
1285 mOut << " ";
1286 emitNameOf(decl);
1287 }
1288 }
1289 else
1290 {
1291 emitOrdinaryVariableDeclaration(decl, evdConfig);
1292 }
1293 }
1294 break;
1295
1296 default:
1297 {
1298 ASSERT(symbolType != SymbolType::Empty || evdConfig.isParameter);
1299 emitOrdinaryVariableDeclaration(decl, evdConfig);
1300 }
1301 }
1302
1303 if (evdConfig.emitPostQualifier)
1304 {
1305 emitPostQualifier(evdConfig, decl, type.getQualifier());
1306 }
1307 }
1308
visitSymbol(TIntermSymbol * symbolNode)1309 void GenMetalTraverser::visitSymbol(TIntermSymbol *symbolNode)
1310 {
1311 const TVariable &var = symbolNode->variable();
1312 const TType &type = var.getType();
1313 ASSERT(var.symbolType() != SymbolType::Empty);
1314
1315 if (type.getBasicType() == TBasicType::EbtVoid)
1316 {
1317 mOut << "/*";
1318 emitNameOf(var);
1319 mOut << "*/";
1320 }
1321 else
1322 {
1323 emitNameOf(var);
1324 }
1325 }
1326
emitSingleConstant(const TConstantUnion * const constUnion)1327 void GenMetalTraverser::emitSingleConstant(const TConstantUnion *const constUnion)
1328 {
1329 switch (constUnion->getType())
1330 {
1331 case TBasicType::EbtBool:
1332 {
1333 mOut << (constUnion->getBConst() ? "true" : "false");
1334 }
1335 break;
1336
1337 case TBasicType::EbtFloat:
1338 {
1339 mOut << constUnion->getFConst() << "f";
1340 }
1341 break;
1342
1343 case TBasicType::EbtInt:
1344 {
1345 mOut << constUnion->getIConst();
1346 }
1347 break;
1348
1349 case TBasicType::EbtUInt:
1350 {
1351 mOut << constUnion->getUConst() << "u";
1352 }
1353 break;
1354
1355 default:
1356 {
1357 UNIMPLEMENTED();
1358 }
1359 }
1360 }
1361
// Emits `size` consecutive constants starting at `constUnion`, separated by ", ".
// Returns a pointer just past the last constant emitted.
const TConstantUnion *GenMetalTraverser::emitConstantUnionArray(
    const TConstantUnion *const constUnion,
    const size_t size)
{
    for (size_t offset = 0; offset < size; ++offset)
    {
        if (offset > 0)
        {
            mOut << ", ";
        }
        emitSingleConstant(constUnion + offset);
    }
    return constUnion + size;
}
1378
// Emits the constant value of `type` stored at `constUnionBegin`.
//
// Struct types are written as `Type{field0, field1, ...}` with each field emitted recursively.
// Other types are written as `Type(c0, c1, ...)`, or as a bare literal when the object size is
// not greater than one. Returns a pointer just past the constants that were consumed.
const TConstantUnion *GenMetalTraverser::emitConstantUnion(const TType &type,
                                                           const TConstantUnion *constUnionBegin)
{
    const TConstantUnion *cursor = constUnionBegin;

    if (const TStructure *structure = type.getStruct())
    {
        const EmitTypeConfig typeConfig = EmitTypeConfig{nullptr};
        emitType(type, typeConfig);
        mOut << "{";

        bool needsSeparator = false;
        for (const TField *field : structure->fields())
        {
            if (needsSeparator)
            {
                mOut << ", ";
            }
            needsSeparator = true;
            cursor         = emitConstantUnion(*field->type(), cursor);
        }

        mOut << "}";
        return cursor;
    }

    const size_t componentCount = type.getObjectSize();
    if (componentCount > 1)
    {
        const EmitTypeConfig typeConfig = EmitTypeConfig{nullptr};
        emitType(type, typeConfig);
        mOut << "(";
        cursor = emitConstantUnionArray(cursor, componentCount);
        mOut << ")";
    }
    else
    {
        cursor = emitConstantUnionArray(cursor, componentCount);
    }
    return cursor;
}
1419
visitConstantUnion(TIntermConstantUnion * constValueNode)1420 void GenMetalTraverser::visitConstantUnion(TIntermConstantUnion *constValueNode)
1421 {
1422 emitConstantUnion(constValueNode->getType(), constValueNode->getConstantValue());
1423 }
1424
visitSwizzle(Visit,TIntermSwizzle * swizzleNode)1425 bool GenMetalTraverser::visitSwizzle(Visit, TIntermSwizzle *swizzleNode)
1426 {
1427 groupedTraverse(*swizzleNode->getOperand());
1428 mOut << ".";
1429
1430 {
1431 #if defined(ANGLE_ENABLE_ASSERTS)
1432 DebugSink::EscapedSink escapedOut(mOut.escape());
1433 TInfoSinkBase &out = escapedOut.get();
1434 #else
1435 TInfoSinkBase &out = mOut;
1436 #endif
1437 swizzleNode->writeOffsetsAsXYZW(&out);
1438 }
1439
1440 return false;
1441 }
1442
getDirectField(const TFieldListCollection & fieldListCollection,const TConstantUnion & index)1443 const TField &GenMetalTraverser::getDirectField(const TFieldListCollection &fieldListCollection,
1444 const TConstantUnion &index)
1445 {
1446 ASSERT(index.getType() == TBasicType::EbtInt);
1447
1448 const TFieldList &fieldList = fieldListCollection.fields();
1449 const int indexVal = index.getIConst();
1450 const TField &field = *fieldList[indexVal];
1451
1452 return field;
1453 }
1454
getDirectField(const TIntermTyped & fieldsNode,TIntermTyped & indexNode)1455 const TField &GenMetalTraverser::getDirectField(const TIntermTyped &fieldsNode,
1456 TIntermTyped &indexNode)
1457 {
1458 const TType &fieldsType = fieldsNode.getType();
1459
1460 const TFieldListCollection *fieldListCollection = fieldsType.getStruct();
1461 if (fieldListCollection == nullptr)
1462 {
1463 fieldListCollection = fieldsType.getInterfaceBlock();
1464 }
1465 ASSERT(fieldListCollection);
1466
1467 const TIntermConstantUnion *indexNode_ = indexNode.getAsConstantUnion();
1468 ASSERT(indexNode_);
1469 const TConstantUnion &index = *indexNode_->getConstantValue();
1470
1471 return getDirectField(*fieldListCollection, index);
1472 }
1473
visitBinary(Visit,TIntermBinary * binaryNode)1474 bool GenMetalTraverser::visitBinary(Visit, TIntermBinary *binaryNode)
1475 {
1476 const TOperator op = binaryNode->getOp();
1477 TIntermTyped &leftNode = *binaryNode->getLeft();
1478 TIntermTyped &rightNode = *binaryNode->getRight();
1479
1480 switch (op)
1481 {
1482 case TOperator::EOpIndexDirectStruct:
1483 case TOperator::EOpIndexDirectInterfaceBlock:
1484 {
1485 const TField &field = getDirectField(leftNode, rightNode);
1486 if (mSymbolEnv.isPointer(field) && mSymbolEnv.isUBO(field))
1487 {
1488 emitOpeningPointerParen();
1489 }
1490 groupedTraverse(leftNode);
1491 if (!mSymbolEnv.isPointer(field))
1492 {
1493 emitClosingPointerParen();
1494 }
1495 mOut << ".";
1496 emitNameOf(field);
1497 }
1498 break;
1499
1500 case TOperator::EOpIndexDirect:
1501 case TOperator::EOpIndexIndirect:
1502 {
1503 TType leftType = leftNode.getType();
1504 groupedTraverse(leftNode);
1505 mOut << "[";
1506 {
1507 mOut << "ANGLE_int_clamp(";
1508 groupedTraverse(rightNode);
1509 mOut << ", 0, ";
1510 if (leftType.isUnsizedArray())
1511 {
1512 groupedTraverse(leftNode);
1513 mOut << ".size()";
1514 }
1515 else
1516 {
1517 int maxSize;
1518 if (leftType.isArray())
1519 {
1520 maxSize = static_cast<int>(leftType.getOutermostArraySize()) - 1;
1521 }
1522 else
1523 {
1524 maxSize = leftType.getNominalSize() - 1;
1525 }
1526 mOut << maxSize;
1527 }
1528 mOut << ")";
1529 }
1530 mOut << "]";
1531 }
1532 break;
1533
1534 default:
1535 {
1536 const TType &resultType = binaryNode->getType();
1537 const TType &leftType = leftNode.getType();
1538 const TType &rightType = rightNode.getType();
1539
1540 if (IsSymbolicOperator(op, resultType, &leftType, &rightType))
1541 {
1542 groupedTraverse(leftNode);
1543 if (op != TOperator::EOpComma)
1544 {
1545 mOut << " ";
1546 }
1547 else
1548 {
1549 emitClosingPointerParen();
1550 }
1551 mOut << GetOperatorString(op, resultType, &leftType, &rightType) << " ";
1552 groupedTraverse(rightNode);
1553 }
1554 else
1555 {
1556 emitClosingPointerParen();
1557 mOut << GetOperatorString(op, resultType, &leftType, &rightType) << "(";
1558 leftNode.traverse(this);
1559 mOut << ", ";
1560 rightNode.traverse(this);
1561 mOut << ")";
1562 }
1563 }
1564 }
1565
1566 return false;
1567 }
1568
IsPostfix(TOperator op)1569 static bool IsPostfix(TOperator op)
1570 {
1571 switch (op)
1572 {
1573 case TOperator::EOpPostIncrement:
1574 case TOperator::EOpPostDecrement:
1575 return true;
1576
1577 default:
1578 return false;
1579 }
1580 }
1581
visitUnary(Visit,TIntermUnary * unaryNode)1582 bool GenMetalTraverser::visitUnary(Visit, TIntermUnary *unaryNode)
1583 {
1584 const TOperator op = unaryNode->getOp();
1585 const TType &resultType = unaryNode->getType();
1586
1587 TIntermTyped &arg = *unaryNode->getOperand();
1588 const TType &argType = arg.getType();
1589
1590 const char *name = GetOperatorString(op, resultType, &argType);
1591
1592 if (IsSymbolicOperator(op, resultType, &argType))
1593 {
1594 const bool postfix = IsPostfix(op);
1595 if (!postfix)
1596 {
1597 mOut << name;
1598 }
1599 groupedTraverse(arg);
1600 if (postfix)
1601 {
1602 mOut << name;
1603 }
1604 }
1605 else
1606 {
1607 mOut << name << "(";
1608 arg.traverse(this);
1609 mOut << ")";
1610 }
1611
1612 return false;
1613 }
1614
visitTernary(Visit,TIntermTernary * conditionalNode)1615 bool GenMetalTraverser::visitTernary(Visit, TIntermTernary *conditionalNode)
1616 {
1617 groupedTraverse(*conditionalNode->getCondition());
1618 mOut << " ? ";
1619 groupedTraverse(*conditionalNode->getTrueExpression());
1620 mOut << " : ";
1621 groupedTraverse(*conditionalNode->getFalseExpression());
1622
1623 return false;
1624 }
1625
visitIfElse(Visit,TIntermIfElse * ifThenElseNode)1626 bool GenMetalTraverser::visitIfElse(Visit, TIntermIfElse *ifThenElseNode)
1627 {
1628 TIntermTyped &condNode = *ifThenElseNode->getCondition();
1629 TIntermBlock *thenNode = ifThenElseNode->getTrueBlock();
1630 TIntermBlock *elseNode = ifThenElseNode->getFalseBlock();
1631
1632 emitIndentation();
1633 mOut << "if (";
1634 condNode.traverse(this);
1635 mOut << ")";
1636
1637 if (thenNode)
1638 {
1639 mOut << "\n";
1640 thenNode->traverse(this);
1641 }
1642 else
1643 {
1644 mOut << " {}";
1645 }
1646
1647 if (elseNode)
1648 {
1649 mOut << "\n";
1650 emitIndentation();
1651 mOut << "else\n";
1652 elseNode->traverse(this);
1653 }
1654 else
1655 {
1656 // Always emit "else" even when empty block to avoid nested if-stmt issues.
1657 mOut << " else {}";
1658 }
1659
1660 return false;
1661 }
1662
visitSwitch(Visit,TIntermSwitch * switchNode)1663 bool GenMetalTraverser::visitSwitch(Visit, TIntermSwitch *switchNode)
1664 {
1665 emitIndentation();
1666 mOut << "switch (";
1667 switchNode->getInit()->traverse(this);
1668 mOut << ")\n";
1669
1670 ASSERT(!mParentIsSwitch);
1671 mParentIsSwitch = true;
1672 switchNode->getStatementList()->traverse(this);
1673 mParentIsSwitch = false;
1674
1675 return false;
1676 }
1677
visitCase(Visit,TIntermCase * caseNode)1678 bool GenMetalTraverser::visitCase(Visit, TIntermCase *caseNode)
1679 {
1680 emitIndentation();
1681
1682 if (caseNode->hasCondition())
1683 {
1684 TIntermTyped *condExpr = caseNode->getCondition();
1685 mOut << "case ";
1686 condExpr->traverse(this);
1687 mOut << ":";
1688 }
1689 else
1690 {
1691 mOut << "default:\n";
1692 }
1693
1694 return false;
1695 }
1696
emitFunctionSignature(const TFunction & func)1697 void GenMetalTraverser::emitFunctionSignature(const TFunction &func)
1698 {
1699 const bool isMain = func.isMain();
1700
1701 emitFunctionReturn(func);
1702
1703 mOut << " ";
1704 emitNameOf(func);
1705 if (isMain)
1706 {
1707 mOut << "0";
1708 }
1709 mOut << "(";
1710
1711 bool emitComma = false;
1712 const size_t paramCount = func.getParamCount();
1713 for (size_t i = 0; i < paramCount; ++i)
1714 {
1715 if (emitComma)
1716 {
1717 mOut << ", ";
1718 }
1719 emitComma = true;
1720
1721 const TVariable ¶m = *func.getParam(i);
1722 emitFunctionParameter(func, param);
1723 }
1724 // TODO(anglebug.com/5505): reimplement transform feedback support.
1725 // if (isTraversingVertexMain)
1726 // {
1727 // mOut << " @@XFB-Bindings@@ ";
1728 // }
1729
1730 mOut << ")";
1731 }
1732
emitFunctionReturn(const TFunction & func)1733 void GenMetalTraverser::emitFunctionReturn(const TFunction &func)
1734 {
1735 const bool isMain = func.isMain();
1736 bool isVertexMain = false;
1737 const TType &returnType = func.getReturnType();
1738 if (isMain)
1739 {
1740 const TStructure *structure = returnType.getStruct();
1741 ASSERT(structure != nullptr);
1742 if (mPipelineStructs.fragmentOut.matches(*structure))
1743 {
1744 mOut << "fragment ";
1745 }
1746 else if (mPipelineStructs.vertexOut.matches(*structure))
1747 {
1748 mOut << "vertex __VERTEX_OUT(";
1749 isVertexMain = true;
1750 }
1751 else
1752 {
1753 UNIMPLEMENTED();
1754 }
1755 }
1756 emitType(returnType, EmitTypeConfig());
1757 if (isVertexMain)
1758 mOut << ") ";
1759 }
1760
emitFunctionParameter(const TFunction & func,const TVariable & param)1761 void GenMetalTraverser::emitFunctionParameter(const TFunction &func, const TVariable ¶m)
1762 {
1763 const bool isMain = func.isMain();
1764
1765 const TType &type = param.getType();
1766 const TStructure *structure = type.getStruct();
1767
1768 EmitVariableDeclarationConfig evdConfig;
1769 evdConfig.isParameter = true;
1770 evdConfig.isMainParameter = isMain;
1771 evdConfig.emitPostQualifier = isMain;
1772 evdConfig.isUBO = mSymbolEnv.isUBO(param);
1773 evdConfig.isPointer = mSymbolEnv.isPointer(param);
1774 evdConfig.isReference = mSymbolEnv.isReference(param);
1775 emitVariableDeclaration(VarDecl(param), evdConfig);
1776
1777 if (isMain)
1778 {
1779 TranslatorMetalReflection *reflection =
1780 ((sh::TranslatorMetalDirect *)&mCompiler)->getTranslatorMetalReflection();
1781 if (structure)
1782 {
1783 if (mPipelineStructs.fragmentIn.matches(*structure) ||
1784 mPipelineStructs.vertexIn.matches(*structure))
1785 {
1786 mOut << " [[stage_in]]";
1787 }
1788 else if (mPipelineStructs.angleUniforms.matches(*structure))
1789 {
1790 mOut << " [[buffer(" << mDriverUniformsBindingIndex << ")]]";
1791 }
1792 else if (mPipelineStructs.uniformBuffers.matches(*structure))
1793 {
1794 mOut << " [[buffer(" << mUBOArgumentBufferBindingIndex << ")]]";
1795 reflection->hasUBOs = true;
1796 }
1797 else if (mPipelineStructs.userUniforms.matches(*structure))
1798 {
1799 mOut << " [[buffer(" << mMainUniformBufferIndex << ")]]";
1800 reflection->addUserUniformBufferBinding(param.name().data(),
1801 mMainUniformBufferIndex);
1802 mMainUniformBufferIndex += type.getArraySizeProduct();
1803 }
1804 else if (structure->name() == "metal::sampler")
1805 {
1806 mOut << " [[sampler(" << (mMainSamplerIndex) << ")]]";
1807 const std::string originalName =
1808 reflection->getOriginalName(param.uniqueId().get());
1809 reflection->addSamplerBinding(originalName, mMainSamplerIndex);
1810 mMainSamplerIndex += type.getArraySizeProduct();
1811 }
1812 }
1813 else if (IsSampler(type.getBasicType()))
1814 {
1815 mOut << " [[texture(" << (mMainTextureIndex) << ")]]";
1816 const std::string originalName = reflection->getOriginalName(param.uniqueId().get());
1817 reflection->addTextureBinding(originalName, mMainSamplerIndex);
1818 mMainTextureIndex += type.getArraySizeProduct();
1819 }
1820 else if (Name(param) == Pipeline{Pipeline::Type::InstanceId, nullptr}.getStructInstanceName(
1821 Pipeline::Variant::Modified))
1822 {
1823 mOut << " [[instance_id]]";
1824 }
1825 }
1826 }
1827
visitFunctionPrototype(TIntermFunctionPrototype * funcProtoNode)1828 void GenMetalTraverser::visitFunctionPrototype(TIntermFunctionPrototype *funcProtoNode)
1829 {
1830 const TFunction &func = *funcProtoNode->getFunction();
1831
1832 emitIndentation();
1833 emitFunctionSignature(func);
1834 }
1835
visitFunctionDefinition(Visit,TIntermFunctionDefinition * funcDefNode)1836 bool GenMetalTraverser::visitFunctionDefinition(Visit, TIntermFunctionDefinition *funcDefNode)
1837 {
1838 const TFunction &func = *funcDefNode->getFunction();
1839 TIntermBlock &body = *funcDefNode->getBody();
1840 if (func.isMain())
1841 {
1842 const TType &returnType = func.getReturnType();
1843 const TStructure *structure = returnType.getStruct();
1844 isTraversingVertexMain = (mPipelineStructs.vertexOut.matches(*structure));
1845 }
1846 emitIndentation();
1847 emitFunctionSignature(func);
1848 mOut << "\n";
1849 body.traverse(this);
1850 if (isTraversingVertexMain)
1851 {
1852 isTraversingVertexMain = false;
1853 }
1854 return false;
1855 }
1856
BuildFuncToName()1857 GenMetalTraverser::FuncToName GenMetalTraverser::BuildFuncToName()
1858 {
1859 FuncToName map;
1860
1861 auto putAngle = [&](const char *nameStr) {
1862 const ImmutableString name(nameStr);
1863 ASSERT(map.find(name) == map.end());
1864 map[name] = Name(nameStr, SymbolType::AngleInternal);
1865 };
1866
1867 putAngle("texelFetch");
1868 putAngle("texelFetchOffset");
1869 putAngle("texture");
1870 putAngle("texture1D");
1871 putAngle("texture1DLod");
1872 putAngle("texture1DProjLod");
1873 putAngle("texture2D");
1874 putAngle("texture2DLod");
1875 putAngle("texture2DProj");
1876 putAngle("texture2DRect");
1877 putAngle("texture2DProjLod");
1878 putAngle("texture2DRectProj");
1879 putAngle("texture3D");
1880 putAngle("texture3DLod");
1881 putAngle("texture3DProjLod");
1882 putAngle("textureCube");
1883 putAngle("textureCubeLod");
1884 putAngle("textureCubeProjLod");
1885 putAngle("textureGrad");
1886 putAngle("textureGradOffset");
1887 putAngle("textureLod");
1888 putAngle("textureLodOffset");
1889 putAngle("textureOffset");
1890 putAngle("textureProj");
1891 putAngle("textureProjGrad");
1892 putAngle("textureProjGradOffset");
1893 putAngle("textureProjLod");
1894 putAngle("textureProjLodOffset");
1895 putAngle("textureProjOffset");
1896 putAngle("textureSize");
1897
1898 return map;
1899 }
1900
visitAggregate(Visit,TIntermAggregate * aggregateNode)1901 bool GenMetalTraverser::visitAggregate(Visit, TIntermAggregate *aggregateNode)
1902 {
1903 const TIntermSequence &args = *aggregateNode->getSequence();
1904
1905 auto emitArgList = [&](const char *open, const char *close) {
1906 mOut << open;
1907
1908 bool emitComma = false;
1909 for (TIntermNode *arg : args)
1910 {
1911 if (emitComma)
1912 {
1913 emitClosingPointerParen();
1914 mOut << ", ";
1915 }
1916 emitComma = true;
1917 arg->traverse(this);
1918 }
1919
1920 mOut << close;
1921 };
1922
1923 const TType &retType = aggregateNode->getType();
1924
1925 if (aggregateNode->isConstructor())
1926 {
1927 const bool isStandalone = getParentNode()->getAsBlock();
1928 if (isStandalone)
1929 {
1930 // Prevent constructor from being interpreted as a declaration by wrapping in parens.
1931 // This can happen if given something like:
1932 // int(symbol); // <- This will be treated like `int symbol;`... don't want that.
1933 // So instead emit:
1934 // (int(symbol));
1935 mOut << "(";
1936 }
1937
1938 const EmitTypeConfig etConfig;
1939
1940 if (retType.isArray())
1941 {
1942 emitType(retType, etConfig);
1943 emitArgList("{", "}");
1944 }
1945 else if (retType.getStruct())
1946 {
1947 emitType(retType, etConfig);
1948 emitArgList("{", "}");
1949 }
1950 else
1951 {
1952 emitType(retType, etConfig);
1953 emitArgList("(", ")");
1954 }
1955
1956 if (isStandalone)
1957 {
1958 mOut << ")";
1959 }
1960
1961 return false;
1962 }
1963 else
1964 {
1965 const TOperator op = aggregateNode->getOp();
1966 switch (op)
1967 {
1968 case TOperator::EOpCallFunctionInAST:
1969 case TOperator::EOpCallInternalRawFunction:
1970 {
1971 const TFunction &func = *aggregateNode->getFunction();
1972 emitNameOf(func);
1973 //'@' symbol in name specifices a macro substitution marker.
1974 if (!func.name().contains("@"))
1975 {
1976 emitArgList("(", ")");
1977 }
1978 else
1979 {
1980 mTemporarilyDisableSemicolon =
1981 true; // Disable semicolon for macro substitution.
1982 }
1983 return false;
1984 }
1985
1986 default:
1987 {
1988 auto getArgType = [&](size_t index) -> const TType * {
1989 if (index < args.size())
1990 {
1991 TIntermTyped *arg = args[index]->getAsTyped();
1992 ASSERT(arg);
1993 return &arg->getType();
1994 }
1995 return nullptr;
1996 };
1997
1998 ASSERT(!args.empty());
1999 const TType *argType0 = getArgType(0);
2000 const TType *argType1 = getArgType(1);
2001 ASSERT(argType0);
2002
2003 const char *opName = GetOperatorString(op, retType, argType0, argType1);
2004
2005 if (IsSymbolicOperator(op, retType, argType0, argType1))
2006 {
2007 switch (args.size())
2008 {
2009 case 1:
2010 {
2011 TIntermNode &operandNode = *aggregateNode->getChildNode(0);
2012 if (IsPostfix(op))
2013 {
2014 mOut << opName;
2015 groupedTraverse(operandNode);
2016 return false;
2017 }
2018 else
2019 {
2020 groupedTraverse(operandNode);
2021 mOut << opName;
2022 return false;
2023 }
2024 }
2025 break;
2026
2027 case 2:
2028 {
2029 TIntermNode &leftNode = *aggregateNode->getChildNode(0);
2030 TIntermNode &rightNode = *aggregateNode->getChildNode(1);
2031 groupedTraverse(leftNode);
2032 mOut << " " << opName << " ";
2033 groupedTraverse(rightNode);
2034 return false;
2035 }
2036 break;
2037
2038 default:
2039 UNREACHABLE();
2040 return false;
2041 }
2042 }
2043 else if (opName == nullptr)
2044 {
2045 const TFunction &func = *aggregateNode->getFunction();
2046 auto it = mFuncToName.find(func.name());
2047 ASSERT(it != mFuncToName.end());
2048 EmitName(mOut, it->second);
2049 emitArgList("(", ")");
2050 return false;
2051 }
2052 else
2053 {
2054 mOut << opName;
2055 emitArgList("(", ")");
2056 return false;
2057 }
2058 }
2059 }
2060 }
2061 }
2062
emitOpenBrace()2063 void GenMetalTraverser::emitOpenBrace()
2064 {
2065 ASSERT(mIndentLevel >= 0);
2066
2067 emitIndentation();
2068 mOut << "{\n";
2069 ++mIndentLevel;
2070 }
2071
emitCloseBrace()2072 void GenMetalTraverser::emitCloseBrace()
2073 {
2074 ASSERT(mIndentLevel >= 1);
2075
2076 --mIndentLevel;
2077 emitIndentation();
2078 mOut << "}";
2079 }
2080
RequiresSemicolonTerminator(TIntermNode & node)2081 static bool RequiresSemicolonTerminator(TIntermNode &node)
2082 {
2083 if (node.getAsBlock())
2084 {
2085 return false;
2086 }
2087 if (node.getAsLoopNode())
2088 {
2089 return false;
2090 }
2091 if (node.getAsSwitchNode())
2092 {
2093 return false;
2094 }
2095 if (node.getAsIfElseNode())
2096 {
2097 return false;
2098 }
2099 if (node.getAsFunctionDefinition())
2100 {
2101 return false;
2102 }
2103 if (node.getAsCaseNode())
2104 {
2105 return false;
2106 }
2107
2108 return true;
2109 }
2110
NewlinePad(TIntermNode & node)2111 static bool NewlinePad(TIntermNode &node)
2112 {
2113 if (node.getAsFunctionDefinition())
2114 {
2115 return true;
2116 }
2117 if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
2118 {
2119 ASSERT(declNode->getChildCount() == 1);
2120 TIntermNode &childNode = *declNode->getChildNode(0);
2121 if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
2122 {
2123 const TVariable &var = symbolNode->variable();
2124 return var.getType().isStructSpecifier();
2125 }
2126 return false;
2127 }
2128 return false;
2129 }
2130
visitBlock(Visit,TIntermBlock * blockNode)2131 bool GenMetalTraverser::visitBlock(Visit, TIntermBlock *blockNode)
2132 {
2133 ASSERT(mIndentLevel >= -1);
2134 const bool isGlobalScope = mIndentLevel == -1;
2135 const bool parentIsSwitch = mParentIsSwitch;
2136 mParentIsSwitch = false;
2137
2138 if (isGlobalScope)
2139 {
2140 ++mIndentLevel;
2141 }
2142 else
2143 {
2144 emitOpenBrace();
2145 if (parentIsSwitch)
2146 {
2147 ++mIndentLevel;
2148 }
2149 }
2150
2151 TIntermNode *prevStmtNode = nullptr;
2152
2153 const size_t stmtCount = blockNode->getChildCount();
2154 for (size_t i = 0; i < stmtCount; ++i)
2155 {
2156 TIntermNode &stmtNode = *blockNode->getChildNode(i);
2157
2158 if (isGlobalScope && prevStmtNode && (NewlinePad(*prevStmtNode) || NewlinePad(stmtNode)))
2159 {
2160 mOut << "\n";
2161 }
2162 const bool isCase = stmtNode.getAsCaseNode();
2163 mIndentLevel -= isCase;
2164 emitIndentation();
2165 mIndentLevel += isCase;
2166 stmtNode.traverse(this);
2167 if (RequiresSemicolonTerminator(stmtNode) && !mTemporarilyDisableSemicolon)
2168 {
2169 mOut << ";";
2170 }
2171 mTemporarilyDisableSemicolon = false;
2172 mOut << "\n";
2173
2174 prevStmtNode = &stmtNode;
2175 }
2176
2177 if (isGlobalScope)
2178 {
2179 ASSERT(mIndentLevel == 0);
2180 --mIndentLevel;
2181 }
2182 else
2183 {
2184 if (parentIsSwitch)
2185 {
2186 ASSERT(mIndentLevel >= 1);
2187 --mIndentLevel;
2188 }
2189 emitCloseBrace();
2190 mParentIsSwitch = parentIsSwitch;
2191 }
2192
2193 return false;
2194 }
2195
visitGlobalQualifierDeclaration(Visit,TIntermGlobalQualifierDeclaration *)2196 bool GenMetalTraverser::visitGlobalQualifierDeclaration(Visit, TIntermGlobalQualifierDeclaration *)
2197 {
2198 UNREACHABLE(); // RewriteGlobalQualifierDecls should have been called before this.
2199 return false;
2200 }
2201
visitDeclaration(Visit,TIntermDeclaration * declNode)2202 bool GenMetalTraverser::visitDeclaration(Visit, TIntermDeclaration *declNode)
2203 {
2204 ASSERT(declNode->getChildCount() == 1);
2205 TIntermNode &node = *declNode->getChildNode(0);
2206
2207 EmitVariableDeclarationConfig evdConfig;
2208
2209 if (TIntermSymbol *symbolNode = node.getAsSymbolNode())
2210 {
2211 const TVariable &var = symbolNode->variable();
2212 emitVariableDeclaration(VarDecl(var), evdConfig);
2213 }
2214 else if (TIntermBinary *initNode = node.getAsBinaryNode())
2215 {
2216 ASSERT(initNode->getOp() == TOperator::EOpInitialize);
2217 TIntermSymbol *leftSymbolNode = initNode->getLeft()->getAsSymbolNode();
2218 TIntermTyped *valueNode = initNode->getRight()->getAsTyped();
2219 ASSERT(leftSymbolNode && valueNode);
2220
2221 if (getRootNode() == getParentBlock())
2222 {
2223 // DeferGlobalInitializers should have turned non-const global initializers into
2224 // deferred initializers. Note that variables marked as EvqGlobal can be treated as
2225 // EvqConst in some ANGLE code but not actually have their qualifier actually changed to
2226 // EvqConst. Thus just assume all EvqGlobal are actually EvqConst for all code run after
2227 // DeferGlobalInitializers.
2228 mOut << "constant ";
2229 }
2230
2231 const TVariable &var = leftSymbolNode->variable();
2232 const Name varName(var);
2233
2234 if (ExpressionContainsName(varName, *valueNode))
2235 {
2236 mRenamedSymbols[&var] = mIdGen.createNewName(varName);
2237 }
2238
2239 emitVariableDeclaration(VarDecl(var), evdConfig);
2240 mOut << " = ";
2241 groupedTraverse(*valueNode);
2242 }
2243 else
2244 {
2245 UNREACHABLE();
2246 }
2247
2248 return false;
2249 }
2250
visitLoop(Visit,TIntermLoop * loopNode)2251 bool GenMetalTraverser::visitLoop(Visit, TIntermLoop *loopNode)
2252 {
2253 const TLoopType loopType = loopNode->getType();
2254
2255 switch (loopType)
2256 {
2257 case TLoopType::ELoopFor:
2258 return visitForLoop(loopNode);
2259 case TLoopType::ELoopWhile:
2260 return visitWhileLoop(loopNode);
2261 case TLoopType::ELoopDoWhile:
2262 return visitDoWhileLoop(loopNode);
2263 }
2264 }
2265
visitForLoop(TIntermLoop * loopNode)2266 bool GenMetalTraverser::visitForLoop(TIntermLoop *loopNode)
2267 {
2268 ASSERT(loopNode->getType() == TLoopType::ELoopFor);
2269
2270 TIntermNode *initNode = loopNode->getInit();
2271 TIntermTyped *condNode = loopNode->getCondition();
2272 TIntermTyped *exprNode = loopNode->getExpression();
2273 TIntermBlock *bodyNode = loopNode->getBody();
2274 ASSERT(bodyNode);
2275
2276 mOut << "for (";
2277
2278 if (initNode)
2279 {
2280 initNode->traverse(this);
2281 }
2282 else
2283 {
2284 mOut << " ";
2285 }
2286
2287 mOut << "; ";
2288
2289 if (condNode)
2290 {
2291 condNode->traverse(this);
2292 }
2293
2294 mOut << "; ";
2295
2296 if (exprNode)
2297 {
2298 exprNode->traverse(this);
2299 }
2300
2301 mOut << ")\n";
2302
2303 bodyNode->traverse(this);
2304
2305 return false;
2306 }
2307
visitWhileLoop(TIntermLoop * loopNode)2308 bool GenMetalTraverser::visitWhileLoop(TIntermLoop *loopNode)
2309 {
2310 ASSERT(loopNode->getType() == TLoopType::ELoopWhile);
2311
2312 TIntermNode *initNode = loopNode->getInit();
2313 TIntermTyped *condNode = loopNode->getCondition();
2314 TIntermTyped *exprNode = loopNode->getExpression();
2315 TIntermBlock *bodyNode = loopNode->getBody();
2316 ASSERT(condNode && bodyNode);
2317 ASSERT(!initNode && !exprNode);
2318
2319 emitIndentation();
2320 mOut << "while (";
2321 condNode->traverse(this);
2322 mOut << ")\n";
2323 bodyNode->traverse(this);
2324
2325 return false;
2326 }
2327
visitDoWhileLoop(TIntermLoop * loopNode)2328 bool GenMetalTraverser::visitDoWhileLoop(TIntermLoop *loopNode)
2329 {
2330 ASSERT(loopNode->getType() == TLoopType::ELoopDoWhile);
2331
2332 TIntermNode *initNode = loopNode->getInit();
2333 TIntermTyped *condNode = loopNode->getCondition();
2334 TIntermTyped *exprNode = loopNode->getExpression();
2335 TIntermBlock *bodyNode = loopNode->getBody();
2336 ASSERT(condNode && bodyNode);
2337 ASSERT(!initNode && !exprNode);
2338
2339 emitIndentation();
2340 mOut << "do\n";
2341 bodyNode->traverse(this);
2342 mOut << "\n";
2343 emitIndentation();
2344 mOut << "while (";
2345 condNode->traverse(this);
2346 mOut << ");";
2347
2348 return false;
2349 }
2350
visitBranch(Visit,TIntermBranch * branchNode)2351 bool GenMetalTraverser::visitBranch(Visit, TIntermBranch *branchNode)
2352 {
2353 const TOperator flowOp = branchNode->getFlowOp();
2354 TIntermTyped *exprNode = branchNode->getExpression();
2355
2356 emitIndentation();
2357
2358 switch (flowOp)
2359 {
2360 case TOperator::EOpKill:
2361 {
2362 ASSERT(exprNode == nullptr);
2363 mOut << "metal::discard_fragment()";
2364 }
2365 break;
2366
2367 case TOperator::EOpReturn:
2368 {
2369 if (isTraversingVertexMain)
2370 {
2371 mOut << "#if TRANSFORM_FEEDBACK_ENABLED\n";
2372 emitIndentation();
2373 mOut << "return;\n";
2374 emitIndentation();
2375 mOut << "#else\n";
2376 emitIndentation();
2377 }
2378 mOut << "return";
2379 if (exprNode)
2380 {
2381 mOut << " ";
2382 exprNode->traverse(this);
2383 mOut << ";";
2384 }
2385 if (isTraversingVertexMain)
2386 {
2387 mOut << "\n";
2388 emitIndentation();
2389 mOut << "#endif\n";
2390 mTemporarilyDisableSemicolon = true;
2391 }
2392 }
2393 break;
2394
2395 case TOperator::EOpBreak:
2396 {
2397 ASSERT(exprNode == nullptr);
2398 mOut << "break";
2399 }
2400 break;
2401
2402 case TOperator::EOpContinue:
2403 {
2404 ASSERT(exprNode == nullptr);
2405 mOut << "continue";
2406 }
2407 break;
2408
2409 default:
2410 {
2411 UNREACHABLE();
2412 }
2413 }
2414
2415 return false;
2416 }
2417
2418 static size_t emitMetalCallCount = 0;
2419
EmitMetal(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const PipelineStructs & pipelineStructs,const Invariants & invariants,SymbolEnv & symbolEnv,const ProgramPreludeConfig & ppc,TSymbolTable * symbolTable)2420 bool sh::EmitMetal(TCompiler &compiler,
2421 TIntermBlock &root,
2422 IdGen &idGen,
2423 const PipelineStructs &pipelineStructs,
2424 const Invariants &invariants,
2425 SymbolEnv &symbolEnv,
2426 const ProgramPreludeConfig &ppc,
2427 TSymbolTable *symbolTable)
2428 {
2429 TInfoSinkBase &out = compiler.getInfoSink().obj;
2430
2431 {
2432 ++emitMetalCallCount;
2433 std::string filenameProto = angle::GetEnvironmentVar("GMD_FIXED_EMIT");
2434 if (!filenameProto.empty())
2435 {
2436 if (filenameProto != "/dev/null")
2437 {
2438 auto tryOpen = [&](char const *ext) {
2439 auto filename = filenameProto;
2440 filename += std::to_string(emitMetalCallCount);
2441 filename += ".";
2442 filename += ext;
2443 return fopen(filename.c_str(), "rb");
2444 };
2445 FILE *file = tryOpen("metal");
2446 if (!file)
2447 {
2448 file = tryOpen("cpp");
2449 }
2450 ASSERT(file);
2451
2452 fseek(file, 0, SEEK_END);
2453 size_t fileSize = ftell(file);
2454 fseek(file, 0, SEEK_SET);
2455
2456 std::vector<char> buff;
2457 buff.resize(fileSize + 1);
2458 fread(buff.data(), fileSize, 1, file);
2459 buff.back() = '\0';
2460
2461 fclose(file);
2462
2463 out << buff.data();
2464 }
2465
2466 return true;
2467 }
2468 }
2469
2470 out << "\n\n";
2471
2472 if (!EmitProgramPrelude(root, out, ppc))
2473 {
2474 return false;
2475 }
2476
2477 {
2478 #if defined(ANGLE_ENABLE_ASSERTS)
2479 DebugSink outWrapper(out, angle::GetBoolEnvironmentVar("GMD_STDOUT"));
2480 outWrapper.watch(angle::GetEnvironmentVar("GMD_WATCH_STRING"));
2481 #else
2482 TInfoSinkBase &outWrapper = out;
2483 #endif
2484 GenMetalTraverser gen(compiler, outWrapper, idGen, pipelineStructs, invariants, symbolEnv,
2485 symbolTable);
2486 root.traverse(&gen);
2487 }
2488
2489 out << "\n";
2490
2491 return true;
2492 }
2493