1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
#include <cctype>
#include <map>
#include <unordered_map>
9
10 #include "common/system_utils.h"
11 #include "compiler/translator/BaseTypes.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/SymbolTable.h"
14 #include "compiler/translator/msl/AstHelpers.h"
15 #include "compiler/translator/msl/DebugSink.h"
16 #include "compiler/translator/msl/EmitMetal.h"
17 #include "compiler/translator/msl/Layout.h"
18 #include "compiler/translator/msl/Name.h"
19 #include "compiler/translator/msl/ProgramPrelude.h"
20 #include "compiler/translator/msl/RewritePipelines.h"
21 #include "compiler/translator/msl/TranslatorMSL.h"
22 #include "compiler/translator/tree_util/IntermTraverse.h"
23
24 using namespace sh;
25
26 ////////////////////////////////////////////////////////////////////////////////
27
// Output sink used by the Metal emitter. When ANGLE_ENABLE_ASSERTS is defined the
// emitter writes through DebugSink (which also provides the escape() used by
// EmitName); otherwise it writes directly to a TInfoSinkBase.
#if defined(ANGLE_ENABLE_ASSERTS)
using Sink = DebugSink;
#else
using Sink = TInfoSinkBase;
#endif
33
34 ////////////////////////////////////////////////////////////////////////////////
35
36 namespace
37 {
38
39 struct VarDecl
40 {
VarDecl__anonb64d01340111::VarDecl41 explicit VarDecl(const TVariable &var) : mVariable(&var), mIsField(false) {}
VarDecl__anonb64d01340111::VarDecl42 explicit VarDecl(const TField &field) : mField(&field), mIsField(true) {}
43
variable__anonb64d01340111::VarDecl44 ANGLE_INLINE const TVariable &variable() const
45 {
46 ASSERT(isVariable());
47 return *mVariable;
48 }
49
field__anonb64d01340111::VarDecl50 ANGLE_INLINE const TField &field() const
51 {
52 ASSERT(isField());
53 return *mField;
54 }
55
isVariable__anonb64d01340111::VarDecl56 ANGLE_INLINE bool isVariable() const { return !mIsField; }
57
isField__anonb64d01340111::VarDecl58 ANGLE_INLINE bool isField() const { return mIsField; }
59
type__anonb64d01340111::VarDecl60 const TType &type() const { return isField() ? *field().type() : variable().getType(); }
61
symbolType__anonb64d01340111::VarDecl62 SymbolType symbolType() const
63 {
64 return isField() ? field().symbolType() : variable().symbolType();
65 }
66
67 private:
68 union
69 {
70 const TVariable *mVariable;
71 const TField *mField;
72 };
73 bool mIsField;
74 };
75
76 class GenMetalTraverser : public TIntermTraverser
77 {
78 public:
79 ~GenMetalTraverser() override;
80
81 GenMetalTraverser(const TCompiler &compiler,
82 Sink &out,
83 IdGen &idGen,
84 const PipelineStructs &pipelineStructs,
85 SymbolEnv &symbolEnv,
86 const ShCompileOptions &compileOptions);
87
88 void visitSymbol(TIntermSymbol *) override;
89 void visitConstantUnion(TIntermConstantUnion *) override;
90 bool visitSwizzle(Visit, TIntermSwizzle *) override;
91 bool visitBinary(Visit, TIntermBinary *) override;
92 bool visitUnary(Visit, TIntermUnary *) override;
93 bool visitTernary(Visit, TIntermTernary *) override;
94 bool visitIfElse(Visit, TIntermIfElse *) override;
95 bool visitSwitch(Visit, TIntermSwitch *) override;
96 bool visitCase(Visit, TIntermCase *) override;
97 void visitFunctionPrototype(TIntermFunctionPrototype *) override;
98 bool visitFunctionDefinition(Visit, TIntermFunctionDefinition *) override;
99 bool visitAggregate(Visit, TIntermAggregate *) override;
100 bool visitBlock(Visit, TIntermBlock *) override;
101 bool visitGlobalQualifierDeclaration(Visit, TIntermGlobalQualifierDeclaration *) override;
102 bool visitDeclaration(Visit, TIntermDeclaration *) override;
103 bool visitLoop(Visit, TIntermLoop *) override;
104 bool visitForLoop(TIntermLoop *);
105 bool visitWhileLoop(TIntermLoop *);
106 bool visitDoWhileLoop(TIntermLoop *);
107 bool visitBranch(Visit, TIntermBranch *) override;
108
109 private:
110 using FuncToName = std::map<ImmutableString, Name>;
111 static FuncToName BuildFuncToName();
112
113 struct EmitVariableDeclarationConfig
114 {
115 bool isParameter = false;
116 bool isMainParameter = false;
117 bool emitPostQualifier = false;
118 bool isPacked = false;
119 bool disableStructSpecifier = false;
120 bool isUBO = false;
121 const AddressSpace *isPointer = nullptr;
122 const AddressSpace *isReference = nullptr;
123 };
124
125 struct EmitTypeConfig
126 {
127 const EmitVariableDeclarationConfig *evdConfig = nullptr;
128 };
129
130 void emitIndentation();
131 void emitOpeningPointerParen();
132 void emitClosingPointerParen();
133 void emitFunctionSignature(const TFunction &func);
134 void emitFunctionReturn(const TFunction &func);
135 void emitFunctionParameter(const TFunction &func, const TVariable ¶m);
136
137 void emitNameOf(const TField &object);
138 void emitNameOf(const TSymbol &object);
139 void emitNameOf(const VarDecl &object);
140
141 void emitBareTypeName(const TType &type, const EmitTypeConfig &etConfig);
142 void emitType(const TType &type, const EmitTypeConfig &etConfig);
143 void emitPostQualifier(const EmitVariableDeclarationConfig &evdConfig,
144 const VarDecl &decl,
145 const TQualifier qualifier);
146
147 struct FieldAnnotationIndices
148 {
149 size_t attribute = 0;
150 size_t color = 0;
151 };
152
153 void emitFieldDeclaration(const TField &field,
154 const TStructure &parent,
155 FieldAnnotationIndices &annotationIndices);
156 void emitAttributeDeclaration(const TField &field, FieldAnnotationIndices &annotationIndices);
157 void emitUniformBufferDeclaration(const TField &field,
158 FieldAnnotationIndices &annotationIndices);
159 void emitStructDeclaration(const TType &type);
160 void emitOrdinaryVariableDeclaration(const VarDecl &decl,
161 const EmitVariableDeclarationConfig &evdConfig);
162 void emitVariableDeclaration(const VarDecl &decl,
163 const EmitVariableDeclarationConfig &evdConfig);
164
165 void emitOpenBrace();
166 void emitCloseBrace();
167
168 void groupedTraverse(TIntermNode &node);
169
170 const TField &getDirectField(const TFieldListCollection &fieldsNode,
171 const TConstantUnion &index);
172 const TField &getDirectField(const TIntermTyped &fieldsNode, TIntermTyped &indexNode);
173
174 const TConstantUnion *emitConstantUnionArray(const TConstantUnion *const constUnion,
175 const size_t size);
176
177 const TConstantUnion *emitConstantUnion(const TType &type, const TConstantUnion *constUnion);
178
179 void emitSingleConstant(const TConstantUnion *const constUnion);
180
181 private:
182 Sink &mOut;
183 const TCompiler &mCompiler;
184 const PipelineStructs &mPipelineStructs;
185 SymbolEnv &mSymbolEnv;
186 IdGen &mIdGen;
187 int mIndentLevel = -1;
188 int mLastIndentationPos = -1;
189 int mOpenPointerParenCount = 0;
190 bool mParentIsSwitch = false;
191 bool isTraversingVertexMain = false;
192 bool mTemporarilyDisableSemicolon = false;
193 std::unordered_map<const TSymbol *, Name> mRenamedSymbols;
194 const FuncToName mFuncToName = BuildFuncToName();
195 size_t mMainTextureIndex = 0;
196 size_t mMainSamplerIndex = 0;
197 size_t mMainUniformBufferIndex = 0;
198 size_t mDriverUniformsBindingIndex = 0;
199 size_t mUBOArgumentBufferBindingIndex = 0;
200 bool mRasterOrderGroupsSupported = false;
201 };
202 } // anonymous namespace
203
~GenMetalTraverser()204 GenMetalTraverser::~GenMetalTraverser()
205 {
206 ASSERT(mIndentLevel == -1);
207 ASSERT(!mParentIsSwitch);
208 ASSERT(mOpenPointerParenCount == 0);
209 }
210
GenMetalTraverser(const TCompiler & compiler,Sink & out,IdGen & idGen,const PipelineStructs & pipelineStructs,SymbolEnv & symbolEnv,const ShCompileOptions & compileOptions)211 GenMetalTraverser::GenMetalTraverser(const TCompiler &compiler,
212 Sink &out,
213 IdGen &idGen,
214 const PipelineStructs &pipelineStructs,
215 SymbolEnv &symbolEnv,
216 const ShCompileOptions &compileOptions)
217 : TIntermTraverser(true, false, false),
218 mOut(out),
219 mCompiler(compiler),
220 mPipelineStructs(pipelineStructs),
221 mSymbolEnv(symbolEnv),
222 mIdGen(idGen),
223 mMainUniformBufferIndex(compileOptions.metal.defaultUniformsBindingIndex),
224 mDriverUniformsBindingIndex(compileOptions.metal.driverUniformsBindingIndex),
225 mUBOArgumentBufferBindingIndex(compileOptions.metal.UBOArgumentBufferBindingIndex),
226 mRasterOrderGroupsSupported(compileOptions.pls.fragmentSyncType ==
227 ShFragmentSynchronizationType::RasterOrderGroups_Metal)
228 {}
229
emitIndentation()230 void GenMetalTraverser::emitIndentation()
231 {
232 ASSERT(mIndentLevel >= 0);
233
234 if (mLastIndentationPos == mOut.size())
235 {
236 return; // Line is already indented.
237 }
238
239 for (int i = 0; i < mIndentLevel; ++i)
240 {
241 mOut << " ";
242 }
243
244 mLastIndentationPos = mOut.size();
245 }
246
emitOpeningPointerParen()247 void GenMetalTraverser::emitOpeningPointerParen()
248 {
249 mOut << "(*";
250 mOpenPointerParenCount++;
251 }
252
emitClosingPointerParen()253 void GenMetalTraverser::emitClosingPointerParen()
254 {
255 if (mOpenPointerParenCount > 0)
256 {
257 mOut << ")";
258 mOpenPointerParenCount--;
259 }
260 }
261
// Returns the Metal spelling used to emit operator `op`:
//  - a punctuation token for infix/prefix/postfix operators ("+", "==", "<<=", ...),
//  - "return" / "break" / "continue" for branch operators,
//  - a function name for built-ins (either "metal::..." or an "ANGLE_..." helper).
// Returns nullptr for operators not mapped here (the `default` case) and for the
// operators listed under UNREACHABLE(). Operators not yet supported trip
// UNIMPLEMENTED() and return the placeholder "TOperator_TODO".
//
// resultType is only consulted by the bit-cast operators. argType0 / argType1 are
// dereferenced without a null check for EOpPositive, EOpEqual and EOpNotEqual, so
// callers must pass them for those operators. argType2 may be null.
static const char *GetOperatorString(TOperator op,
                                     const TType &resultType,
                                     const TType *argType0,
                                     const TType *argType1,
                                     const TType *argType2)
{
    switch (op)
    {
        case TOperator::EOpComma:
            return ",";
        case TOperator::EOpAssign:
            return "=";
        case TOperator::EOpInitialize:
            return "=";
        case TOperator::EOpAddAssign:
            return "+=";
        case TOperator::EOpSubAssign:
            return "-=";
        case TOperator::EOpMulAssign:
            return "*=";
        case TOperator::EOpDivAssign:
            return "/=";
        case TOperator::EOpIModAssign:
            return "%=";
        case TOperator::EOpBitShiftLeftAssign:
            return "<<=";  // TODO: Check logical vs arithmetic shifting.
        case TOperator::EOpBitShiftRightAssign:
            return ">>=";  // TODO: Check logical vs arithmetic shifting.
        case TOperator::EOpBitwiseAndAssign:
            return "&=";
        case TOperator::EOpBitwiseXorAssign:
            return "^=";
        case TOperator::EOpBitwiseOrAssign:
            return "|=";
        case TOperator::EOpAdd:
            return "+";
        case TOperator::EOpSub:
            return "-";
        case TOperator::EOpMul:
            return "*";
        case TOperator::EOpDiv:
            return "/";
        case TOperator::EOpIMod:
            return "%";
        case TOperator::EOpBitShiftLeft:
            return "<<";  // TODO: Check logical vs arithmetic shifting.
        case TOperator::EOpBitShiftRight:
            return ">>";  // TODO: Check logical vs arithmetic shifting.
        case TOperator::EOpBitwiseAnd:
            return "&";
        case TOperator::EOpBitwiseXor:
            return "^";
        case TOperator::EOpBitwiseOr:
            return "|";
        case TOperator::EOpLessThan:
            return "<";
        case TOperator::EOpGreaterThan:
            return ">";
        case TOperator::EOpLessThanEqual:
            return "<=";
        case TOperator::EOpGreaterThanEqual:
            return ">=";
        case TOperator::EOpLessThanComponentWise:
            return "<";
        case TOperator::EOpLessThanEqualComponentWise:
            return "<=";
        case TOperator::EOpGreaterThanEqualComponentWise:
            return ">=";
        case TOperator::EOpGreaterThanComponentWise:
            return ">";
        case TOperator::EOpLogicalOr:
            return "||";
        case TOperator::EOpLogicalXor:
            // Logical xor on bools is emitted as inequality. XXX: some unusual use
            // cases might need different handling.
            return "!=/*xor*/";
        case TOperator::EOpLogicalAnd:
            return "&&";
        case TOperator::EOpNegative:
            return "-";
        case TOperator::EOpPositive:
            // Unary plus on a matrix is emitted as nothing (the operand stands
            // alone); presumably Metal lacks unary + for matrices — confirm.
            if (argType0->isMatrix())
            {
                return "";
            }
            return "+";
        case TOperator::EOpLogicalNot:
            return "!";
        case TOperator::EOpNotComponentWise:
            return "!";
        case TOperator::EOpBitwiseNot:
            return "~";
        case TOperator::EOpPostIncrement:
            return "++";
        case TOperator::EOpPostDecrement:
            return "--";
        case TOperator::EOpPreIncrement:
            return "++";
        case TOperator::EOpPreDecrement:
            return "--";
        case TOperator::EOpVectorTimesScalarAssign:
            return "*=";
        case TOperator::EOpVectorTimesMatrixAssign:
            return "*=";
        case TOperator::EOpMatrixTimesScalarAssign:
            return "*=";
        case TOperator::EOpMatrixTimesMatrixAssign:
            return "*=";
        case TOperator::EOpVectorTimesScalar:
            return "*";
        case TOperator::EOpVectorTimesMatrix:
            return "*";
        case TOperator::EOpMatrixTimesVector:
            return "*";
        case TOperator::EOpMatrixTimesScalar:
            return "*";
        case TOperator::EOpMatrixTimesMatrix:
            return "*";
        case TOperator::EOpEqualComponentWise:
            return "==";
        case TOperator::EOpNotEqualComponentWise:
            return "!=";

        case TOperator::EOpEqual:
            // Whole-value equality on aggregates is routed to ANGLE_ helper
            // functions; only scalar operands use the plain "==" token.
            if ((argType0->getStruct() && argType1->getStruct()) &&
                (argType0->isArray() && argType1->isArray()))
            {
                return "ANGLE_equalStructArray";
            }

            if ((argType0->isVector() && argType1->isVector()) ||
                (argType0->getStruct() && argType1->getStruct()) ||
                (argType0->isArray() && argType1->isArray()) ||
                (argType0->isMatrix() && argType1->isMatrix()))

            {
                return "ANGLE_equal";
            }

            return "==";

        case TOperator::EOpNotEqual:
            // Mirrors EOpEqual, except non-array structs get a dedicated helper.
            if ((argType0->getStruct() && argType1->getStruct()) &&
                (argType0->isArray() && argType1->isArray()))
            {
                return "ANGLE_notEqualStructArray";
            }

            if ((argType0->isVector() && argType1->isVector()) ||
                (argType0->isArray() && argType1->isArray()) ||
                (argType0->isMatrix() && argType1->isMatrix()))
            {
                return "ANGLE_notEqual";
            }
            else if (argType0->getStruct() && argType1->getStruct())
            {
                return "ANGLE_notEqualStruct";
            }
            return "!=";

        case TOperator::EOpKill:
            UNIMPLEMENTED();
            return "kill";
        case TOperator::EOpReturn:
            return "return";
        case TOperator::EOpBreak:
            return "break";
        case TOperator::EOpContinue:
            return "continue";

        // Built-ins emitted through ANGLE_ helper functions.
        case TOperator::EOpRadians:
            return "ANGLE_radians";
        case TOperator::EOpDegrees:
            return "ANGLE_degrees";
        case TOperator::EOpAtan:
            return "ANGLE_atan";
        case TOperator::EOpMod:
            return "ANGLE_mod";  // differs from metal::mod
        case TOperator::EOpRefract:
            return "ANGLE_refract";
        case TOperator::EOpDistance:
            return "ANGLE_distance";
        case TOperator::EOpLength:
            return "ANGLE_length";
        case TOperator::EOpDot:
            return "ANGLE_dot";
        case TOperator::EOpNormalize:
            return "ANGLE_normalize";
        case TOperator::EOpFaceforward:
            return "ANGLE_faceforward";
        case TOperator::EOpReflect:
            return "ANGLE_reflect";
        case TOperator::EOpMatrixCompMult:
            return "ANGLE_componentWiseMultiply";
        case TOperator::EOpOuterProduct:
            return "ANGLE_outerProduct";
        case TOperator::EOpSign:
            return "ANGLE_sign";

        // Built-ins that map directly onto the Metal standard library.
        case TOperator::EOpAbs:
            return "metal::abs";
        case TOperator::EOpAll:
            return "metal::all";
        case TOperator::EOpAny:
            return "metal::any";
        case TOperator::EOpSin:
            return "metal::sin";
        case TOperator::EOpCos:
            return "metal::cos";
        case TOperator::EOpTan:
            return "metal::tan";
        case TOperator::EOpAsin:
            return "metal::asin";
        case TOperator::EOpAcos:
            return "metal::acos";
        case TOperator::EOpSinh:
            return "metal::sinh";
        case TOperator::EOpCosh:
            return "metal::cosh";
        case TOperator::EOpTanh:
            return "metal::tanh";
        case TOperator::EOpAsinh:
            return "metal::asinh";
        case TOperator::EOpAcosh:
            return "metal::acosh";
        case TOperator::EOpAtanh:
            return "metal::atanh";
        case TOperator::EOpFma:
            return "metal::fma";
        case TOperator::EOpPow:
            return "metal::pow";
        case TOperator::EOpExp:
            return "metal::exp";
        case TOperator::EOpExp2:
            return "metal::exp2";
        case TOperator::EOpLog:
            return "metal::log";
        case TOperator::EOpLog2:
            return "metal::log2";
        case TOperator::EOpSqrt:
            return "metal::sqrt";
        case TOperator::EOpFloor:
            return "metal::floor";
        case TOperator::EOpTrunc:
            return "metal::trunc";
        case TOperator::EOpCeil:
            return "metal::ceil";
        case TOperator::EOpFract:
            return "metal::fract";
        case TOperator::EOpMin:
            return "metal::min";
        case TOperator::EOpMax:
            return "metal::max";
        case TOperator::EOpRound:
            return "metal::round";
        case TOperator::EOpRoundEven:
            return "metal::rint";
        case TOperator::EOpClamp:
            return "metal::clamp";  // TODO fast vs precise namespace
        case TOperator::EOpSaturate:
            return "metal::saturate";  // TODO fast vs precise namespace
        case TOperator::EOpMix:
            // mix() with a bool selector uses a dedicated helper.
            if (argType2 && argType2->getBasicType() == EbtBool)
                return "ANGLE_mix_bool";
            return "metal::mix";
        case TOperator::EOpStep:
            return "metal::step";
        case TOperator::EOpSmoothstep:
            return "metal::smoothstep";
        case TOperator::EOpModf:
            return "metal::modf";
        case TOperator::EOpIsnan:
            return "metal::isnan";
        case TOperator::EOpIsinf:
            return "metal::isinf";
        case TOperator::EOpLdexp:
            return "metal::ldexp";
        case TOperator::EOpFrexp:
            return "metal::frexp";
        case TOperator::EOpInversesqrt:
            return "metal::rsqrt";
        case TOperator::EOpCross:
            return "metal::cross";
        case TOperator::EOpDFdx:
            return "metal::dfdx";
        case TOperator::EOpDFdy:
            return "metal::dfdy";
        case TOperator::EOpFwidth:
            return "metal::fwidth";
        case TOperator::EOpTranspose:
            return "metal::transpose";
        case TOperator::EOpDeterminant:
            return "metal::determinant";

        case TOperator::EOpInverse:
            return "ANGLE_inverse";

        case TOperator::EOpInterpolateAtCentroid:
            return "ANGLE_interpolateAtCentroid";
        case TOperator::EOpInterpolateAtSample:
            return "ANGLE_interpolateAtSample";
        case TOperator::EOpInterpolateAtOffset:
            return "ANGLE_interpolateAtOffset";
        case TOperator::EOpInterpolateAtCenter:
            return "ANGLE_interpolateAtCenter";

        // Bit reinterpretation: emitted as as_type<T> where T is derived from the
        // result type (scalar uint is spelled uint32_t, vectors as <base><N>).
        case TOperator::EOpFloatBitsToInt:
        case TOperator::EOpFloatBitsToUint:
        case TOperator::EOpIntBitsToFloat:
        case TOperator::EOpUintBitsToFloat:
        {
// Both helper macros return from GetOperatorString directly.
#define RETURN_AS_TYPE_SCALAR()                   \
    do                                            \
        switch (resultType.getBasicType())        \
        {                                         \
            case TBasicType::EbtInt:              \
                return "as_type<int>";            \
            case TBasicType::EbtUInt:             \
                return "as_type<uint32_t>";       \
            case TBasicType::EbtFloat:            \
                return "as_type<float>";          \
            default:                              \
                UNIMPLEMENTED();                  \
                return "TOperator_TODO";          \
        }                                         \
    while (false)

#define RETURN_AS_TYPE(post)                      \
    do                                            \
        switch (resultType.getBasicType())        \
        {                                         \
            case TBasicType::EbtInt:              \
                return "as_type<int" post ">";    \
            case TBasicType::EbtUInt:             \
                return "as_type<uint" post ">";   \
            case TBasicType::EbtFloat:            \
                return "as_type<float" post ">";  \
            default:                              \
                UNIMPLEMENTED();                  \
                return "TOperator_TODO";          \
        }                                         \
    while (false)

            if (resultType.isScalar())
            {
                RETURN_AS_TYPE_SCALAR();
            }
            else if (resultType.isVector())
            {
                switch (resultType.getNominalSize())
                {
                    case 2:
                        RETURN_AS_TYPE("2");
                    case 3:
                        RETURN_AS_TYPE("3");
                    case 4:
                        RETURN_AS_TYPE("4");
                    default:
                        UNREACHABLE();
                        return nullptr;
                }
            }
            else
            {
                UNIMPLEMENTED();
                return "TOperator_TODO";
            }

#undef RETURN_AS_TYPE
#undef RETURN_AS_TYPE_SCALAR
        }

        case TOperator::EOpPackUnorm2x16:
            return "metal::pack_float_to_unorm2x16";
        case TOperator::EOpPackSnorm2x16:
            return "metal::pack_float_to_snorm2x16";

        case TOperator::EOpPackUnorm4x8:
            return "metal::pack_float_to_unorm4x8";
        case TOperator::EOpPackSnorm4x8:
            return "metal::pack_float_to_snorm4x8";

        case TOperator::EOpUnpackUnorm2x16:
            return "metal::unpack_unorm2x16_to_float";
        case TOperator::EOpUnpackSnorm2x16:
            return "metal::unpack_snorm2x16_to_float";

        case TOperator::EOpUnpackUnorm4x8:
            return "metal::unpack_unorm4x8_to_float";
        case TOperator::EOpUnpackSnorm4x8:
            return "metal::unpack_snorm4x8_to_float";

        case TOperator::EOpPackHalf2x16:
            return "ANGLE_pack_half_2x16";
        case TOperator::EOpUnpackHalf2x16:
            return "ANGLE_unpack_half_2x16";

        case TOperator::EOpNumSamples:
            return "metal::get_num_samples";
        case TOperator::EOpSamplePosition:
            return "metal::get_sample_position";

        // Not yet supported by the Metal backend.
        case TOperator::EOpBitfieldExtract:
        case TOperator::EOpBitfieldInsert:
        case TOperator::EOpBitfieldReverse:
        case TOperator::EOpBitCount:
        case TOperator::EOpFindLSB:
        case TOperator::EOpFindMSB:
        case TOperator::EOpUaddCarry:
        case TOperator::EOpUsubBorrow:
        case TOperator::EOpUmulExtended:
        case TOperator::EOpImulExtended:
        case TOperator::EOpBarrier:
        case TOperator::EOpMemoryBarrier:
        case TOperator::EOpMemoryBarrierAtomicCounter:
        case TOperator::EOpMemoryBarrierBuffer:
        case TOperator::EOpMemoryBarrierShared:
        case TOperator::EOpGroupMemoryBarrier:
        case TOperator::EOpAtomicAdd:
        case TOperator::EOpAtomicMin:
        case TOperator::EOpAtomicMax:
        case TOperator::EOpAtomicAnd:
        case TOperator::EOpAtomicOr:
        case TOperator::EOpAtomicXor:
        case TOperator::EOpAtomicExchange:
        case TOperator::EOpAtomicCompSwap:
        case TOperator::EOpEmitVertex:
        case TOperator::EOpEndPrimitive:
        case TOperator::EOpFtransform:
        case TOperator::EOpPackDouble2x32:
        case TOperator::EOpUnpackDouble2x32:
        case TOperator::EOpArrayLength:
            UNIMPLEMENTED();
            return "TOperator_TODO";

        // Callers are expected to handle these before asking for a spelling.
        case TOperator::EOpNull:
        case TOperator::EOpConstruct:
        case TOperator::EOpCallFunctionInAST:
        case TOperator::EOpCallInternalRawFunction:
        case TOperator::EOpIndexDirect:
        case TOperator::EOpIndexIndirect:
        case TOperator::EOpIndexDirectStruct:
        case TOperator::EOpIndexDirectInterfaceBlock:
            UNREACHABLE();
            return nullptr;
        default:
            // Any other built-in function.
            return nullptr;
    }
}
711
IsSymbolicOperator(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1)712 static bool IsSymbolicOperator(TOperator op,
713 const TType &resultType,
714 const TType *argType0,
715 const TType *argType1)
716 {
717 const char *operatorString = GetOperatorString(op, resultType, argType0, argType1, nullptr);
718 if (operatorString == nullptr)
719 {
720 return false;
721 }
722 return !std::isalnum(operatorString[0]);
723 }
724
AsSpecificBinaryNode(TIntermNode & node,TOperator op)725 static TIntermBinary *AsSpecificBinaryNode(TIntermNode &node, TOperator op)
726 {
727 TIntermBinary *binaryNode = node.getAsBinaryNode();
728 if (binaryNode)
729 {
730 return binaryNode->getOp() == op ? binaryNode : nullptr;
731 }
732 return nullptr;
733 }
734
Parenthesize(TIntermNode & node)735 static bool Parenthesize(TIntermNode &node)
736 {
737 if (node.getAsSymbolNode())
738 {
739 return false;
740 }
741 if (node.getAsConstantUnion())
742 {
743 return false;
744 }
745 if (node.getAsAggregate())
746 {
747 return false;
748 }
749 if (node.getAsSwizzleNode())
750 {
751 return false;
752 }
753
754 if (TIntermUnary *unaryNode = node.getAsUnaryNode())
755 {
756 // TODO: Use a precedence and associativity rules instead of this ad-hoc impl.
757 const TType &resultType = unaryNode->getType();
758 const TType &argType = unaryNode->getOperand()->getType();
759 return IsSymbolicOperator(unaryNode->getOp(), resultType, &argType, nullptr);
760 }
761
762 if (TIntermBinary *binaryNode = node.getAsBinaryNode())
763 {
764 // TODO: Use a precedence and associativity rules instead of this ad-hoc impl.
765 const TOperator op = binaryNode->getOp();
766 switch (op)
767 {
768 case TOperator::EOpIndexDirectStruct:
769 case TOperator::EOpIndexDirectInterfaceBlock:
770 case TOperator::EOpIndexDirect:
771 case TOperator::EOpIndexIndirect:
772 return Parenthesize(*binaryNode->getLeft());
773
774 case TOperator::EOpAssign:
775 case TOperator::EOpInitialize:
776 return AsSpecificBinaryNode(*binaryNode->getRight(), TOperator::EOpComma);
777
778 default:
779 {
780 const TType &resultType = binaryNode->getType();
781 const TType &leftType = binaryNode->getLeft()->getType();
782 const TType &rightType = binaryNode->getRight()->getType();
783 return IsSymbolicOperator(binaryNode->getOp(), resultType, &leftType, &rightType);
784 }
785 }
786 }
787
788 return true;
789 }
790
groupedTraverse(TIntermNode & node)791 void GenMetalTraverser::groupedTraverse(TIntermNode &node)
792 {
793 const bool emitParens = Parenthesize(node);
794
795 if (emitParens)
796 {
797 mOut << "(";
798 }
799
800 node.traverse(this);
801
802 if (emitParens)
803 {
804 mOut << ")";
805 }
806 }
807
emitPostQualifier(const EmitVariableDeclarationConfig & evdConfig,const VarDecl & decl,const TQualifier qualifier)808 void GenMetalTraverser::emitPostQualifier(const EmitVariableDeclarationConfig &evdConfig,
809 const VarDecl &decl,
810 const TQualifier qualifier)
811 {
812 bool isInvariant = false;
813 switch (qualifier)
814 {
815 case TQualifier::EvqPosition:
816 isInvariant = decl.type().isInvariant();
817 [[fallthrough]];
818 case TQualifier::EvqFragCoord:
819 mOut << " [[position]]";
820 break;
821
822 case TQualifier::EvqClipDistance:
823 mOut << " [[clip_distance]] [" << decl.type().getOutermostArraySize() << "]";
824 break;
825
826 case TQualifier::EvqPointSize:
827 mOut << " [[point_size]]";
828 break;
829
830 case TQualifier::EvqVertexID:
831 if (evdConfig.isMainParameter)
832 {
833 mOut << " [[vertex_id]]";
834 }
835 break;
836
837 case TQualifier::EvqPointCoord:
838 if (evdConfig.isMainParameter)
839 {
840 mOut << " [[point_coord]]";
841 }
842 break;
843
844 case TQualifier::EvqFrontFacing:
845 if (evdConfig.isMainParameter)
846 {
847 mOut << " [[front_facing]]";
848 }
849 break;
850
851 case TQualifier::EvqSampleID:
852 if (evdConfig.isMainParameter)
853 {
854 mOut << " [[sample_id]]";
855 }
856 break;
857
858 case TQualifier::EvqSampleMaskIn:
859 if (evdConfig.isMainParameter)
860 {
861 mOut << " [[sample_mask]]";
862 }
863 break;
864
865 default:
866 break;
867 }
868
869 if (isInvariant)
870 {
871 mOut << " [[invariant]]";
872
873 TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
874 reflection->hasInvariance = true;
875 }
876 }
877
EmitName(Sink & out,const Name & name)878 static void EmitName(Sink &out, const Name &name)
879 {
880 #if defined(ANGLE_ENABLE_ASSERTS)
881 DebugSink::EscapedSink escapedOut(out.escape());
882 #else
883 TInfoSinkBase &escapedOut = out;
884 #endif
885 name.emit(escapedOut);
886 }
887
emitNameOf(const TField & object)888 void GenMetalTraverser::emitNameOf(const TField &object)
889 {
890 EmitName(mOut, Name(object));
891 }
892
emitNameOf(const TSymbol & object)893 void GenMetalTraverser::emitNameOf(const TSymbol &object)
894 {
895 auto it = mRenamedSymbols.find(&object);
896 if (it == mRenamedSymbols.end())
897 {
898 EmitName(mOut, Name(object));
899 }
900 else
901 {
902 EmitName(mOut, it->second);
903 }
904 }
905
emitNameOf(const VarDecl & object)906 void GenMetalTraverser::emitNameOf(const VarDecl &object)
907 {
908 if (object.isField())
909 {
910 emitNameOf(object.field());
911 }
912 else
913 {
914 emitNameOf(object.variable());
915 }
916 }
917
emitBareTypeName(const TType & type,const EmitTypeConfig & etConfig)918 void GenMetalTraverser::emitBareTypeName(const TType &type, const EmitTypeConfig &etConfig)
919 {
920 const TBasicType basicType = type.getBasicType();
921
922 switch (basicType)
923 {
924 case TBasicType::EbtVoid:
925 case TBasicType::EbtBool:
926 case TBasicType::EbtFloat:
927 case TBasicType::EbtInt:
928 {
929 mOut << type.getBasicString();
930 break;
931 }
932 case TBasicType::EbtUInt:
933 {
934 if (type.isScalar())
935 {
936 mOut << "uint32_t";
937 }
938 else
939 {
940 mOut << type.getBasicString();
941 }
942 }
943 break;
944
945 case TBasicType::EbtStruct:
946 {
947 const TStructure &structure = *type.getStruct();
948 emitNameOf(structure);
949 }
950 break;
951
952 case TBasicType::EbtInterfaceBlock:
953 {
954 const TInterfaceBlock &interfaceBlock = *type.getInterfaceBlock();
955 emitNameOf(interfaceBlock);
956 }
957 break;
958
959 default:
960 {
961 if (IsSampler(basicType))
962 {
963 if (etConfig.evdConfig && etConfig.evdConfig->isMainParameter)
964 {
965 EmitName(mOut, GetTextureTypeName(basicType));
966 }
967 else
968 {
969 const TStructure &env = mSymbolEnv.getTextureEnv(basicType);
970 emitNameOf(env);
971 }
972 }
973 else if (IsImage(basicType))
974 {
975 mOut << "metal::texture2d<";
976 switch (type.getBasicType())
977 {
978 case EbtImage2D:
979 mOut << "float";
980 break;
981 case EbtIImage2D:
982 mOut << "int";
983 break;
984 case EbtUImage2D:
985 mOut << "uint";
986 break;
987 default:
988 UNIMPLEMENTED();
989 break;
990 }
991 if (type.getMemoryQualifier().readonly || type.getMemoryQualifier().writeonly)
992 {
993 UNIMPLEMENTED();
994 }
995 mOut << ", metal::access::read_write>";
996 }
997 else
998 {
999 UNIMPLEMENTED();
1000 }
1001 }
1002 }
1003 }
1004
// Writes the full declared spelling of `type`. Pieces, in emission order:
//  - an address-space prefix (toString(*isPointer) / toString(*isReference)) when
//    etConfig.evdConfig requests a pointer or reference;
//  - "ANGLE_tensor<" for arrays, closed later with ", <size>..." and ">";
//  - "metal::interpolant<" for interpolant types;
//  - "metal::" for vectors/matrices, and "packed_" whenever evdConfig->isPacked;
//  - the bare type name (emitBareTypeName), then the vector size "N" or the
//    matrix shape "CxR";
//  - the interpolant's ", metal::interpolation::[no_]perspective>" suffix;
//  - " *" / " &" for pointers/references.
// For UBO members (evdConfig->isUBO) the ANGLE_tensor wrapper is opened before the
// address space and closed after the " *" / " &" suffix, so the tensor wraps the
// pointer/reference type. Otherwise the tensor sits inside the address space and
// wraps only the element type.
void GenMetalTraverser::emitType(const TType &type, const EmitTypeConfig &etConfig)
{
    const bool isUBO = etConfig.evdConfig ? etConfig.evdConfig->isUBO : false;
    if (etConfig.evdConfig)
    {
        const auto &evdConfig = *etConfig.evdConfig;
        if (isUBO)
        {
            // UBO arrays: open the tensor before the address space (see header).
            if (type.isArray())
            {
                mOut << "ANGLE_tensor<";
            }
        }
        if (evdConfig.isPointer)
        {
            mOut << toString(*evdConfig.isPointer);
            mOut << " ";
        }
        else if (evdConfig.isReference)
        {
            mOut << toString(*evdConfig.isReference);
            mOut << " ";
        }
    }

    if (!isUBO)
    {
        // Non-UBO arrays: the tensor wraps just the element type.
        if (type.isArray())
        {
            mOut << "ANGLE_tensor<";
        }
    }

    if (type.isInterpolant())
    {
        mOut << "metal::interpolant<";
    }

    if (type.isVector() || type.isMatrix())
    {
        mOut << "metal::";
    }

    if (etConfig.evdConfig && etConfig.evdConfig->isPacked)
    {
        mOut << "packed_";
    }

    emitBareTypeName(type, etConfig);

    if (type.isVector())
    {
        mOut << static_cast<uint32_t>(type.getNominalSize());
    }
    else if (type.isMatrix())
    {
        mOut << static_cast<uint32_t>(type.getCols()) << "x"
             << static_cast<uint32_t>(type.getRows());
    }

    if (type.isInterpolant())
    {
        // Noperspective qualifiers select no_perspective; everything else is perspective.
        mOut << ", metal::interpolation::";
        switch (type.getQualifier())
        {
            case EvqNoPerspectiveIn:
            case EvqNoPerspectiveCentroidIn:
            case EvqNoPerspectiveSampleIn:
                mOut << "no_";
                break;
            default:
                break;
        }
        mOut << "perspective>";
    }

    if (!isUBO)
    {
        // Close the non-UBO tensor with every array dimension as a template argument.
        if (type.isArray())
        {
            for (auto size : type.getArraySizes())
            {
                mOut << ", " << size;
            }
            mOut << ">";
        }
    }

    if (etConfig.evdConfig)
    {
        const auto &evdConfig = *etConfig.evdConfig;
        if (evdConfig.isPointer)
        {
            mOut << " *";
        }
        else if (evdConfig.isReference)
        {
            mOut << " &";
        }
        if (isUBO)
        {
            // Close the UBO tensor after the pointer/reference suffix.
            if (type.isArray())
            {
                for (auto size : type.getArraySizes())
                {
                    mOut << ", " << size;
                }
                mOut << ">";
            }
        }
    }
}
1117
emitFieldDeclaration(const TField & field,const TStructure & parent,FieldAnnotationIndices & annotationIndices)1118 void GenMetalTraverser::emitFieldDeclaration(const TField &field,
1119 const TStructure &parent,
1120 FieldAnnotationIndices &annotationIndices)
1121 {
1122 const TType &type = *field.type();
1123 const TBasicType basic = type.getBasicType();
1124
1125 EmitVariableDeclarationConfig evdConfig;
1126 evdConfig.emitPostQualifier = true;
1127 evdConfig.disableStructSpecifier = true;
1128 evdConfig.isPacked = mSymbolEnv.isPacked(field);
1129 evdConfig.isUBO = mSymbolEnv.isUBO(field);
1130 evdConfig.isPointer = mSymbolEnv.isPointer(field);
1131 evdConfig.isReference = mSymbolEnv.isReference(field);
1132 emitVariableDeclaration(VarDecl(field), evdConfig);
1133
1134 const TQualifier qual = type.getQualifier();
1135 switch (qual)
1136 {
1137 case TQualifier::EvqFlatIn:
1138 if (mPipelineStructs.fragmentIn.external == &parent)
1139 {
1140 mOut << " [[flat]]";
1141 TranslatorMetalReflection *reflection =
1142 mtl::getTranslatorMetalReflection(&mCompiler);
1143 reflection->hasFlatInput = true;
1144 }
1145 break;
1146
1147 case TQualifier::EvqNoPerspectiveIn:
1148 if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1149 {
1150 mOut << " [[center_no_perspective]]";
1151 }
1152 break;
1153
1154 case TQualifier::EvqCentroidIn:
1155 if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1156 {
1157 mOut << " [[centroid_perspective]]";
1158 }
1159 break;
1160
1161 case TQualifier::EvqSampleIn:
1162 if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1163 {
1164 mOut << " [[sample_perspective]]";
1165 }
1166 break;
1167
1168 case TQualifier::EvqNoPerspectiveCentroidIn:
1169 if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1170 {
1171 mOut << " [[centroid_no_perspective]]";
1172 }
1173 break;
1174
1175 case TQualifier::EvqNoPerspectiveSampleIn:
1176 if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1177 {
1178 mOut << " [[sample_no_perspective]]";
1179 }
1180 break;
1181
1182 case TQualifier::EvqFragColor:
1183 mOut << " [[color(0)]]";
1184 break;
1185
1186 case TQualifier::EvqSecondaryFragColorEXT:
1187 case TQualifier::EvqSecondaryFragDataEXT:
1188 mOut << " [[color(0), index(1)]]";
1189 break;
1190
1191 case TQualifier::EvqFragmentOut:
1192 case TQualifier::EvqFragmentInOut:
1193 case TQualifier::EvqFragData:
1194 if (mPipelineStructs.fragmentOut.external == &parent)
1195 {
1196 if ((type.isVector() &&
1197 (basic == TBasicType::EbtInt || basic == TBasicType::EbtUInt ||
1198 basic == TBasicType::EbtFloat)) ||
1199 qual == EvqFragData)
1200 {
1201 // The OpenGL ES 3.0 spec says locations must be specified
1202 // unless there is only a single output, in which case the
1203 // location is 0. So, when we get to this point the shader
1204 // will have been rejected if locations are not specified
1205 // and there is more than one output.
1206 const TLayoutQualifier &layoutQualifier = type.getLayoutQualifier();
1207 if (layoutQualifier.locationsSpecified)
1208 {
1209 mOut << " [[color(" << layoutQualifier.location << ")";
1210 ASSERT(layoutQualifier.index >= -1 && layoutQualifier.index <= 1);
1211 if (layoutQualifier.index == 1)
1212 {
1213 mOut << ", index(1)";
1214 }
1215 }
1216 else if (qual == EvqFragData)
1217 {
1218 mOut << " [[color(" << annotationIndices.color++ << ")";
1219 }
1220 else
1221 {
1222 // Either the only output or EXT_blend_func_extended is used;
1223 // actual assignment will happen in UpdateFragmentShaderOutputs.
1224 mOut << " [[" << sh::kUnassignedFragmentOutputString;
1225 }
1226 if (mRasterOrderGroupsSupported && qual == TQualifier::EvqFragmentInOut)
1227 {
1228 // Put fragment inouts in their own raster order group for better
1229 // parallelism.
1230 // NOTE: this is not required for the reads to be ordered and coherent.
1231 // TODO(anglebug.com/7279): Consider making raster order groups a PLS layout
1232 // qualifier?
1233 mOut << ", raster_order_group(0)";
1234 }
1235 mOut << "]]";
1236 }
1237 }
1238 break;
1239
1240 case TQualifier::EvqFragDepth:
1241 mOut << " [[depth(";
1242 switch (type.getLayoutQualifier().depth)
1243 {
1244 case EdGreater:
1245 mOut << "greater";
1246 break;
1247 case EdLess:
1248 mOut << "less";
1249 break;
1250 default:
1251 mOut << "any";
1252 break;
1253 }
1254 mOut << "), function_constant(" << sh::mtl::kDepthWriteEnabledConstName << ")]]";
1255 break;
1256
1257 case TQualifier::EvqSampleMask:
1258 if (field.symbolType() == SymbolType::AngleInternal)
1259 {
1260 mOut << " [[sample_mask, function_constant("
1261 << sh::mtl::kMultisampledRenderingConstName << ")]]";
1262 }
1263 break;
1264
1265 default:
1266 break;
1267 }
1268
1269 if (IsImage(type.getBasicType()))
1270 {
1271 if (type.getArraySizeProduct() != 1)
1272 {
1273 UNIMPLEMENTED();
1274 }
1275 mOut << " [[";
1276 if (type.getLayoutQualifier().rasterOrdered)
1277 {
1278 mOut << "raster_order_group(0), ";
1279 }
1280 mOut << "texture(" << mMainTextureIndex << ")]]";
1281 TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
1282 reflection->addRWTextureBinding(type.getLayoutQualifier().binding,
1283 static_cast<int>(mMainTextureIndex));
1284 ++mMainTextureIndex;
1285 }
1286 }
1287
BuildExternalAttributeIndexMap(const TCompiler & compiler,const PipelineScoped<TStructure> & structure)1288 static std::map<Name, size_t> BuildExternalAttributeIndexMap(
1289 const TCompiler &compiler,
1290 const PipelineScoped<TStructure> &structure)
1291 {
1292 ASSERT(structure.isTotallyFull());
1293
1294 const auto &shaderVars = compiler.getAttributes();
1295 const size_t shaderVarSize = shaderVars.size();
1296 size_t shaderVarIndex = 0;
1297
1298 const auto &externalFields = structure.external->fields();
1299 const size_t externalSize = externalFields.size();
1300 size_t externalIndex = 0;
1301
1302 const auto &internalFields = structure.internal->fields();
1303 const size_t internalSize = internalFields.size();
1304 size_t internalIndex = 0;
1305
1306 // Internal fields are never split. External fields are sometimes split.
1307 ASSERT(externalSize >= internalSize);
1308
1309 // Structures do not contain any inactive fields.
1310 ASSERT(shaderVarSize >= internalSize);
1311
1312 std::map<Name, size_t> externalNameToAttributeIndex;
1313 size_t attributeIndex = 0;
1314
1315 while (internalIndex < internalSize)
1316 {
1317 const TField &internalField = *internalFields[internalIndex];
1318 const Name internalName = Name(internalField);
1319 const TType &internalType = *internalField.type();
1320 while (internalName.rawName() != shaderVars[shaderVarIndex].name &&
1321 internalName.rawName() != shaderVars[shaderVarIndex].mappedName)
1322 {
1323 // This case represents an inactive field.
1324
1325 ++shaderVarIndex;
1326 ASSERT(shaderVarIndex < shaderVarSize);
1327
1328 ++attributeIndex; // TODO: Might need to increment more if shader var type is a matrix.
1329 }
1330
1331 const size_t cols =
1332 (internalType.isMatrix() && !externalFields[externalIndex]->type()->isMatrix())
1333 ? internalType.getCols()
1334 : 1;
1335
1336 for (size_t c = 0; c < cols; ++c)
1337 {
1338 const TField &externalField = *externalFields[externalIndex];
1339 const Name externalName = Name(externalField);
1340
1341 externalNameToAttributeIndex[externalName] = attributeIndex;
1342
1343 ++externalIndex;
1344 ++attributeIndex;
1345 }
1346
1347 ++shaderVarIndex;
1348 ++internalIndex;
1349 }
1350
1351 ASSERT(shaderVarIndex <= shaderVarSize);
1352 ASSERT(externalIndex <= externalSize); // less than if padding was introduced
1353 ASSERT(internalIndex == internalSize);
1354
1355 return externalNameToAttributeIndex;
1356 }
1357
emitAttributeDeclaration(const TField & field,FieldAnnotationIndices & annotationIndices)1358 void GenMetalTraverser::emitAttributeDeclaration(const TField &field,
1359 FieldAnnotationIndices &annotationIndices)
1360 {
1361 EmitVariableDeclarationConfig evdConfig;
1362 evdConfig.disableStructSpecifier = true;
1363 emitVariableDeclaration(VarDecl(field), evdConfig);
1364 mOut << sh::kUnassignedAttributeString;
1365 }
1366
emitUniformBufferDeclaration(const TField & field,FieldAnnotationIndices & annotationIndices)1367 void GenMetalTraverser::emitUniformBufferDeclaration(const TField &field,
1368 FieldAnnotationIndices &annotationIndices)
1369 {
1370 EmitVariableDeclarationConfig evdConfig;
1371 evdConfig.disableStructSpecifier = true;
1372 evdConfig.isUBO = mSymbolEnv.isUBO(field);
1373 evdConfig.isPointer = mSymbolEnv.isPointer(field);
1374 emitVariableDeclaration(VarDecl(field), evdConfig);
1375 mOut << "[[id(" << annotationIndices.attribute << ")]]";
1376
1377 const TType &type = *field.type();
1378 const int arraySize = type.isArray() ? type.getArraySizeProduct() : 1;
1379
1380 TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
1381 ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1382 const TStructure *structure = type.getStruct();
1383 const std::string originalName = reflection->getOriginalName(structure->uniqueId().get());
1384 reflection->addUniformBufferBinding(
1385 originalName,
1386 {.bindIndex = annotationIndices.attribute, .arraySize = static_cast<size_t>(arraySize)});
1387
1388 annotationIndices.attribute += arraySize;
1389 }
1390
emitStructDeclaration(const TType & type)1391 void GenMetalTraverser::emitStructDeclaration(const TType &type)
1392 {
1393 ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1394 ASSERT(type.isStructSpecifier());
1395
1396 mOut << "struct ";
1397 emitBareTypeName(type, {});
1398
1399 mOut << "\n";
1400 emitOpenBrace();
1401
1402 const TStructure &structure = *type.getStruct();
1403 std::map<Name, size_t> fieldToAttributeIndex;
1404 const bool hasAttributeIndices = mPipelineStructs.vertexIn.external == &structure;
1405 const bool hasUniformBufferIndicies = mPipelineStructs.uniformBuffers.external == &structure;
1406 const bool reclaimUnusedAttributeIndices = mCompiler.getShaderVersion() < 300;
1407
1408 if (hasAttributeIndices)
1409 {
1410 fieldToAttributeIndex =
1411 BuildExternalAttributeIndexMap(mCompiler, mPipelineStructs.vertexIn);
1412 }
1413
1414 FieldAnnotationIndices annotationIndices;
1415
1416 for (const TField *field : structure.fields())
1417 {
1418 emitIndentation();
1419 if (hasAttributeIndices)
1420 {
1421 const auto it = fieldToAttributeIndex.find(Name(*field));
1422 if (it == fieldToAttributeIndex.end())
1423 {
1424 ASSERT(field->symbolType() == SymbolType::AngleInternal);
1425 ASSERT(field->name().beginsWith("_"));
1426 ASSERT(angle::EndsWith(field->name().data(), "_pad"));
1427 emitFieldDeclaration(*field, structure, annotationIndices);
1428 }
1429 else
1430 {
1431 ASSERT(field->symbolType() != SymbolType::AngleInternal ||
1432 !field->name().beginsWith("_") ||
1433 !angle::EndsWith(field->name().data(), "_pad"));
1434 if (!reclaimUnusedAttributeIndices)
1435 {
1436 annotationIndices.attribute = it->second;
1437 }
1438 emitAttributeDeclaration(*field, annotationIndices);
1439 }
1440 }
1441 else if (hasUniformBufferIndicies)
1442 {
1443 emitUniformBufferDeclaration(*field, annotationIndices);
1444 }
1445 else
1446 {
1447 emitFieldDeclaration(*field, structure, annotationIndices);
1448 }
1449 mOut << ";\n";
1450 }
1451
1452 emitCloseBrace();
1453 }
1454
emitOrdinaryVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1455 void GenMetalTraverser::emitOrdinaryVariableDeclaration(
1456 const VarDecl &decl,
1457 const EmitVariableDeclarationConfig &evdConfig)
1458 {
1459 EmitTypeConfig etConfig;
1460 etConfig.evdConfig = &evdConfig;
1461
1462 const TType &type = decl.type();
1463 if (type.getQualifier() == TQualifier::EvqClipDistance)
1464 {
1465 // Clip distance output uses float[n] type instead of metal::array.
1466 // The element count is emitted after the post qualifier.
1467 ASSERT(type.getBasicType() == TBasicType::EbtFloat);
1468 mOut << "float";
1469 }
1470 else if (type.getQualifier() == TQualifier::EvqSampleID && evdConfig.isMainParameter)
1471 {
1472 // Metal's [[sample_id]] must be unsigned
1473 ASSERT(type.getBasicType() == TBasicType::EbtInt);
1474 mOut << "uint32_t";
1475 }
1476 else
1477 {
1478 emitType(type, etConfig);
1479 }
1480 if (decl.symbolType() != SymbolType::Empty)
1481 {
1482 mOut << " ";
1483 emitNameOf(decl);
1484 }
1485 }
1486
emitVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1487 void GenMetalTraverser::emitVariableDeclaration(const VarDecl &decl,
1488 const EmitVariableDeclarationConfig &evdConfig)
1489 {
1490 const SymbolType symbolType = decl.symbolType();
1491 const TType &type = decl.type();
1492 const TBasicType basicType = type.getBasicType();
1493
1494 switch (basicType)
1495 {
1496 case TBasicType::EbtStruct:
1497 {
1498 if (type.isStructSpecifier() && !evdConfig.disableStructSpecifier)
1499 {
1500 // It's invalid to declare a struct inside a function argument. When emitting a
1501 // function parameter, the callsite should set evdConfig.disableStructSpecifier.
1502 ASSERT(!evdConfig.isParameter);
1503 emitStructDeclaration(type);
1504 if (symbolType != SymbolType::Empty)
1505 {
1506 mOut << " ";
1507 emitNameOf(decl);
1508 }
1509 }
1510 else
1511 {
1512 emitOrdinaryVariableDeclaration(decl, evdConfig);
1513 }
1514 }
1515 break;
1516
1517 default:
1518 {
1519 ASSERT(symbolType != SymbolType::Empty || evdConfig.isParameter);
1520 emitOrdinaryVariableDeclaration(decl, evdConfig);
1521 }
1522 }
1523
1524 if (evdConfig.emitPostQualifier)
1525 {
1526 emitPostQualifier(evdConfig, decl, type.getQualifier());
1527 }
1528 }
1529
visitSymbol(TIntermSymbol * symbolNode)1530 void GenMetalTraverser::visitSymbol(TIntermSymbol *symbolNode)
1531 {
1532 const TVariable &var = symbolNode->variable();
1533 const TType &type = var.getType();
1534 ASSERT(var.symbolType() != SymbolType::Empty);
1535
1536 if (type.getBasicType() == TBasicType::EbtVoid)
1537 {
1538 mOut << "/*";
1539 emitNameOf(var);
1540 mOut << "*/";
1541 }
1542 else
1543 {
1544 emitNameOf(var);
1545 }
1546 }
1547
emitSingleConstant(const TConstantUnion * const constUnion)1548 void GenMetalTraverser::emitSingleConstant(const TConstantUnion *const constUnion)
1549 {
1550 switch (constUnion->getType())
1551 {
1552 case TBasicType::EbtBool:
1553 {
1554 mOut << (constUnion->getBConst() ? "true" : "false");
1555 }
1556 break;
1557
1558 case TBasicType::EbtFloat:
1559 {
1560 float value = constUnion->getFConst();
1561 if (std::isnan(value))
1562 {
1563 mOut << "NAN";
1564 }
1565 else if (std::isinf(value))
1566 {
1567 if (value < 0)
1568 {
1569 mOut << "-";
1570 }
1571 mOut << "INFINITY";
1572 }
1573 else
1574 {
1575 mOut << value << "f";
1576 }
1577 }
1578 break;
1579
1580 case TBasicType::EbtInt:
1581 {
1582 mOut << constUnion->getIConst();
1583 }
1584 break;
1585
1586 case TBasicType::EbtUInt:
1587 {
1588 mOut << constUnion->getUConst() << "u";
1589 }
1590 break;
1591
1592 default:
1593 {
1594 UNIMPLEMENTED();
1595 }
1596 }
1597 }
1598
// Emits `size` consecutive scalar constants, starting at constUnion, as a comma-separated list.
// Returns a pointer one past the last constant emitted, so callers can continue from there.
const TConstantUnion *GenMetalTraverser::emitConstantUnionArray(
    const TConstantUnion *const constUnion,
    const size_t size)
{
    const TConstantUnion *cursor = constUnion;
    for (size_t emitted = 0; emitted < size; ++emitted, ++cursor)
    {
        if (emitted > 0)
        {
            mOut << ", ";
        }
        emitSingleConstant(cursor);
    }
    return cursor;
}
1615
// Emits the constant value(s) of `type`, reading from constUnionBegin. Returns a pointer just
// past the last TConstantUnion consumed.
//  - Struct types are emitted as `Type{field0, field1, ...}`, recursing into each field's type.
//  - Other types with more than one component are wrapped as `Type(c0, c1, ...)`;
//    single-component values are emitted bare.
const TConstantUnion *GenMetalTraverser::emitConstantUnion(const TType &type,
                                                           const TConstantUnion *constUnionBegin)
{
    const TStructure *structure = type.getStruct();
    if (structure == nullptr)
    {
        const size_t componentCount = type.getObjectSize();
        const bool needsConstructor = componentCount > 1;
        if (needsConstructor)
        {
            emitType(type, EmitTypeConfig{nullptr});
            mOut << "(";
        }
        const TConstantUnion *next = emitConstantUnionArray(constUnionBegin, componentCount);
        if (needsConstructor)
        {
            mOut << ")";
        }
        return next;
    }

    emitType(type, EmitTypeConfig{nullptr});
    mOut << "{";
    const TConstantUnion *cursor = constUnionBegin;
    const TFieldList &fields = structure->fields();
    for (size_t fieldIndex = 0; fieldIndex < fields.size(); ++fieldIndex)
    {
        if (fieldIndex > 0)
        {
            mOut << ", ";
        }
        cursor = emitConstantUnion(*fields[fieldIndex]->type(), cursor);
    }
    mOut << "}";
    return cursor;
}
1656
visitConstantUnion(TIntermConstantUnion * constValueNode)1657 void GenMetalTraverser::visitConstantUnion(TIntermConstantUnion *constValueNode)
1658 {
1659 emitConstantUnion(constValueNode->getType(), constValueNode->getConstantValue());
1660 }
1661
visitSwizzle(Visit,TIntermSwizzle * swizzleNode)1662 bool GenMetalTraverser::visitSwizzle(Visit, TIntermSwizzle *swizzleNode)
1663 {
1664 groupedTraverse(*swizzleNode->getOperand());
1665 mOut << ".";
1666
1667 {
1668 #if defined(ANGLE_ENABLE_ASSERTS)
1669 DebugSink::EscapedSink escapedOut(mOut.escape());
1670 TInfoSinkBase &out = escapedOut.get();
1671 #else
1672 TInfoSinkBase &out = mOut;
1673 #endif
1674 swizzleNode->writeOffsetsAsXYZW(&out);
1675 }
1676
1677 return false;
1678 }
1679
getDirectField(const TFieldListCollection & fieldListCollection,const TConstantUnion & index)1680 const TField &GenMetalTraverser::getDirectField(const TFieldListCollection &fieldListCollection,
1681 const TConstantUnion &index)
1682 {
1683 ASSERT(index.getType() == TBasicType::EbtInt);
1684
1685 const TFieldList &fieldList = fieldListCollection.fields();
1686 const int indexVal = index.getIConst();
1687 const TField &field = *fieldList[indexVal];
1688
1689 return field;
1690 }
1691
getDirectField(const TIntermTyped & fieldsNode,TIntermTyped & indexNode)1692 const TField &GenMetalTraverser::getDirectField(const TIntermTyped &fieldsNode,
1693 TIntermTyped &indexNode)
1694 {
1695 const TType &fieldsType = fieldsNode.getType();
1696
1697 const TFieldListCollection *fieldListCollection = fieldsType.getStruct();
1698 if (fieldListCollection == nullptr)
1699 {
1700 fieldListCollection = fieldsType.getInterfaceBlock();
1701 }
1702 ASSERT(fieldListCollection);
1703
1704 const TIntermConstantUnion *indexNode_ = indexNode.getAsConstantUnion();
1705 ASSERT(indexNode_);
1706 const TConstantUnion &index = *indexNode_->getConstantValue();
1707
1708 return getDirectField(*fieldListCollection, index);
1709 }
1710
visitBinary(Visit,TIntermBinary * binaryNode)1711 bool GenMetalTraverser::visitBinary(Visit, TIntermBinary *binaryNode)
1712 {
1713 const TOperator op = binaryNode->getOp();
1714 TIntermTyped &leftNode = *binaryNode->getLeft();
1715 TIntermTyped &rightNode = *binaryNode->getRight();
1716
1717 switch (op)
1718 {
1719 case TOperator::EOpIndexDirectStruct:
1720 case TOperator::EOpIndexDirectInterfaceBlock:
1721 {
1722 const TField &field = getDirectField(leftNode, rightNode);
1723 if (mSymbolEnv.isPointer(field) && mSymbolEnv.isUBO(field))
1724 {
1725 emitOpeningPointerParen();
1726 }
1727 groupedTraverse(leftNode);
1728 if (!mSymbolEnv.isPointer(field))
1729 {
1730 emitClosingPointerParen();
1731 }
1732 mOut << ".";
1733 emitNameOf(field);
1734 }
1735 break;
1736
1737 case TOperator::EOpIndexDirect:
1738 case TOperator::EOpIndexIndirect:
1739 {
1740 TType leftType = leftNode.getType();
1741 groupedTraverse(leftNode);
1742 mOut << "[";
1743 {
1744 mOut << "ANGLE_int_clamp(";
1745 groupedTraverse(rightNode);
1746 mOut << ", 0, ";
1747 if (leftType.isUnsizedArray())
1748 {
1749 groupedTraverse(leftNode);
1750 mOut << ".size()";
1751 }
1752 else
1753 {
1754 uint32_t maxSize;
1755 if (leftType.isArray())
1756 {
1757 maxSize = leftType.getOutermostArraySize() - 1;
1758 }
1759 else
1760 {
1761 maxSize = leftType.getNominalSize() - 1;
1762 }
1763 mOut << maxSize;
1764 }
1765 mOut << ")";
1766 }
1767 mOut << "]";
1768 }
1769 break;
1770
1771 default:
1772 {
1773 const TType &resultType = binaryNode->getType();
1774 const TType &leftType = leftNode.getType();
1775 const TType &rightType = rightNode.getType();
1776
1777 if (IsSymbolicOperator(op, resultType, &leftType, &rightType))
1778 {
1779 groupedTraverse(leftNode);
1780 if (op != TOperator::EOpComma)
1781 {
1782 mOut << " ";
1783 }
1784 else
1785 {
1786 emitClosingPointerParen();
1787 }
1788 mOut << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << " ";
1789 groupedTraverse(rightNode);
1790 }
1791 else
1792 {
1793 emitClosingPointerParen();
1794 mOut << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << "(";
1795 leftNode.traverse(this);
1796 mOut << ", ";
1797 rightNode.traverse(this);
1798 mOut << ")";
1799 }
1800 }
1801 }
1802
1803 return false;
1804 }
1805
IsPostfix(TOperator op)1806 static bool IsPostfix(TOperator op)
1807 {
1808 switch (op)
1809 {
1810 case TOperator::EOpPostIncrement:
1811 case TOperator::EOpPostDecrement:
1812 return true;
1813
1814 default:
1815 return false;
1816 }
1817 }
1818
visitUnary(Visit,TIntermUnary * unaryNode)1819 bool GenMetalTraverser::visitUnary(Visit, TIntermUnary *unaryNode)
1820 {
1821 const TOperator op = unaryNode->getOp();
1822 const TType &resultType = unaryNode->getType();
1823
1824 TIntermTyped &arg = *unaryNode->getOperand();
1825 const TType &argType = arg.getType();
1826
1827 const char *name = GetOperatorString(op, resultType, &argType, nullptr, nullptr);
1828
1829 if (IsSymbolicOperator(op, resultType, &argType, nullptr))
1830 {
1831 const bool postfix = IsPostfix(op);
1832 if (!postfix)
1833 {
1834 mOut << name;
1835 }
1836 groupedTraverse(arg);
1837 if (postfix)
1838 {
1839 mOut << name;
1840 }
1841 }
1842 else
1843 {
1844 mOut << name << "(";
1845 arg.traverse(this);
1846 mOut << ")";
1847 }
1848
1849 return false;
1850 }
1851
visitTernary(Visit,TIntermTernary * conditionalNode)1852 bool GenMetalTraverser::visitTernary(Visit, TIntermTernary *conditionalNode)
1853 {
1854 groupedTraverse(*conditionalNode->getCondition());
1855 mOut << " ? ";
1856 groupedTraverse(*conditionalNode->getTrueExpression());
1857 mOut << " : ";
1858 groupedTraverse(*conditionalNode->getFalseExpression());
1859
1860 return false;
1861 }
1862
visitIfElse(Visit,TIntermIfElse * ifThenElseNode)1863 bool GenMetalTraverser::visitIfElse(Visit, TIntermIfElse *ifThenElseNode)
1864 {
1865 TIntermTyped &condNode = *ifThenElseNode->getCondition();
1866 TIntermBlock *thenNode = ifThenElseNode->getTrueBlock();
1867 TIntermBlock *elseNode = ifThenElseNode->getFalseBlock();
1868
1869 emitIndentation();
1870 mOut << "if (";
1871 condNode.traverse(this);
1872 mOut << ")";
1873
1874 if (thenNode)
1875 {
1876 mOut << "\n";
1877 thenNode->traverse(this);
1878 }
1879 else
1880 {
1881 mOut << " {}";
1882 }
1883
1884 if (elseNode)
1885 {
1886 mOut << "\n";
1887 emitIndentation();
1888 mOut << "else\n";
1889 elseNode->traverse(this);
1890 }
1891 else
1892 {
1893 // Always emit "else" even when empty block to avoid nested if-stmt issues.
1894 mOut << " else {}";
1895 }
1896
1897 return false;
1898 }
1899
visitSwitch(Visit,TIntermSwitch * switchNode)1900 bool GenMetalTraverser::visitSwitch(Visit, TIntermSwitch *switchNode)
1901 {
1902 emitIndentation();
1903 mOut << "switch (";
1904 switchNode->getInit()->traverse(this);
1905 mOut << ")\n";
1906
1907 ASSERT(!mParentIsSwitch);
1908 mParentIsSwitch = true;
1909 switchNode->getStatementList()->traverse(this);
1910 mParentIsSwitch = false;
1911
1912 return false;
1913 }
1914
visitCase(Visit,TIntermCase * caseNode)1915 bool GenMetalTraverser::visitCase(Visit, TIntermCase *caseNode)
1916 {
1917 emitIndentation();
1918
1919 if (caseNode->hasCondition())
1920 {
1921 TIntermTyped *condExpr = caseNode->getCondition();
1922 mOut << "case ";
1923 condExpr->traverse(this);
1924 mOut << ":";
1925 }
1926 else
1927 {
1928 mOut << "default:\n";
1929 }
1930
1931 return false;
1932 }
1933
emitFunctionSignature(const TFunction & func)1934 void GenMetalTraverser::emitFunctionSignature(const TFunction &func)
1935 {
1936 const bool isMain = func.isMain();
1937
1938 emitFunctionReturn(func);
1939
1940 mOut << " ";
1941 emitNameOf(func);
1942 if (isMain)
1943 {
1944 mOut << "0";
1945 }
1946 mOut << "(";
1947
1948 bool emitComma = false;
1949 const size_t paramCount = func.getParamCount();
1950 for (size_t i = 0; i < paramCount; ++i)
1951 {
1952 if (emitComma)
1953 {
1954 mOut << ", ";
1955 }
1956 emitComma = true;
1957
1958 const TVariable ¶m = *func.getParam(i);
1959 emitFunctionParameter(func, param);
1960 }
1961
1962 if (isTraversingVertexMain)
1963 {
1964 mOut << " @@XFB-Bindings@@ ";
1965 }
1966
1967 mOut << ")";
1968 }
1969
emitFunctionReturn(const TFunction & func)1970 void GenMetalTraverser::emitFunctionReturn(const TFunction &func)
1971 {
1972 const bool isMain = func.isMain();
1973 bool isVertexMain = false;
1974 const TType &returnType = func.getReturnType();
1975 if (isMain)
1976 {
1977 const TStructure *structure = returnType.getStruct();
1978 if (mPipelineStructs.fragmentOut.matches(*structure))
1979 {
1980 if (mCompiler.isEarlyFragmentTestsSpecified())
1981 {
1982 mOut << "[[early_fragment_tests]]\n";
1983 }
1984 mOut << "fragment ";
1985 }
1986 else if (mPipelineStructs.vertexOut.matches(*structure))
1987 {
1988 ASSERT(structure != nullptr);
1989 mOut << "vertex __VERTEX_OUT(";
1990 isVertexMain = true;
1991 }
1992 else
1993 {
1994 UNIMPLEMENTED();
1995 }
1996 }
1997 emitType(returnType, EmitTypeConfig());
1998 if (isVertexMain)
1999 mOut << ") ";
2000 }
2001
emitFunctionParameter(const TFunction & func,const TVariable & param)2002 void GenMetalTraverser::emitFunctionParameter(const TFunction &func, const TVariable ¶m)
2003 {
2004 const bool isMain = func.isMain();
2005
2006 const TType &type = param.getType();
2007 const TStructure *structure = type.getStruct();
2008
2009 EmitVariableDeclarationConfig evdConfig;
2010 evdConfig.isParameter = true;
2011 evdConfig.disableStructSpecifier = true; // It's invalid to declare a struct in a function arg.
2012 evdConfig.isMainParameter = isMain;
2013 evdConfig.emitPostQualifier = isMain;
2014 evdConfig.isUBO = mSymbolEnv.isUBO(param);
2015 evdConfig.isPointer = mSymbolEnv.isPointer(param);
2016 evdConfig.isReference = mSymbolEnv.isReference(param);
2017 emitVariableDeclaration(VarDecl(param), evdConfig);
2018
2019 if (isMain)
2020 {
2021 TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
2022 if (structure)
2023 {
2024 if (mPipelineStructs.fragmentIn.matches(*structure) ||
2025 mPipelineStructs.vertexIn.matches(*structure))
2026 {
2027 mOut << " [[stage_in]]";
2028 }
2029 else if (mPipelineStructs.angleUniforms.matches(*structure))
2030 {
2031 mOut << " [[buffer(" << mDriverUniformsBindingIndex << ")]]";
2032 }
2033 else if (mPipelineStructs.uniformBuffers.matches(*structure))
2034 {
2035 mOut << " [[buffer(" << mUBOArgumentBufferBindingIndex << ")]]";
2036 reflection->hasUBOs = true;
2037 }
2038 else if (mPipelineStructs.userUniforms.matches(*structure))
2039 {
2040 mOut << " [[buffer(" << mMainUniformBufferIndex << ")]]";
2041 reflection->addUserUniformBufferBinding(param.name().data(),
2042 mMainUniformBufferIndex);
2043 mMainUniformBufferIndex += type.getArraySizeProduct();
2044 }
2045 else if (structure->name() == "metal::sampler")
2046 {
2047 mOut << " [[sampler(" << (mMainSamplerIndex) << ")]]";
2048 const std::string originalName =
2049 reflection->getOriginalName(param.uniqueId().get());
2050 reflection->addSamplerBinding(originalName, mMainSamplerIndex);
2051 mMainSamplerIndex += type.getArraySizeProduct();
2052 }
2053 }
2054 else if (IsSampler(type.getBasicType()))
2055 {
2056 mOut << " [[texture(" << (mMainTextureIndex) << ")]]";
2057 const std::string originalName = reflection->getOriginalName(param.uniqueId().get());
2058 reflection->addTextureBinding(originalName, mMainTextureIndex);
2059 mMainTextureIndex += type.getArraySizeProduct();
2060 }
2061 else if (Name(param) == Pipeline{Pipeline::Type::InstanceId, nullptr}.getStructInstanceName(
2062 Pipeline::Variant::Modified))
2063 {
2064 mOut << " [[instance_id]]";
2065 }
2066 else if (Name(param) == kBaseInstanceName)
2067 {
2068 mOut << " [[base_instance]]";
2069 }
2070 }
2071 }
2072
visitFunctionPrototype(TIntermFunctionPrototype * funcProtoNode)2073 void GenMetalTraverser::visitFunctionPrototype(TIntermFunctionPrototype *funcProtoNode)
2074 {
2075 const TFunction &func = *funcProtoNode->getFunction();
2076
2077 emitIndentation();
2078 emitFunctionSignature(func);
2079 }
2080
visitFunctionDefinition(Visit,TIntermFunctionDefinition * funcDefNode)2081 bool GenMetalTraverser::visitFunctionDefinition(Visit, TIntermFunctionDefinition *funcDefNode)
2082 {
2083 const TFunction &func = *funcDefNode->getFunction();
2084 TIntermBlock &body = *funcDefNode->getBody();
2085 if (func.isMain())
2086 {
2087 const TType &returnType = func.getReturnType();
2088 const TStructure *structure = returnType.getStruct();
2089 isTraversingVertexMain = (mPipelineStructs.vertexOut.matches(*structure)) &&
2090 mCompiler.getShaderType() == GL_VERTEX_SHADER;
2091 }
2092 emitIndentation();
2093 emitFunctionSignature(func);
2094 mOut << "\n";
2095 body.traverse(this);
2096 if (isTraversingVertexMain)
2097 {
2098 isTraversingVertexMain = false;
2099 }
2100 return false;
2101 }
2102
BuildFuncToName()2103 GenMetalTraverser::FuncToName GenMetalTraverser::BuildFuncToName()
2104 {
2105 FuncToName map;
2106
2107 auto putAngle = [&](const char *nameStr) {
2108 const ImmutableString name(nameStr);
2109 ASSERT(map.find(name) == map.end());
2110 map[name] = Name(nameStr, SymbolType::AngleInternal);
2111 };
2112
2113 putAngle("texelFetch");
2114 putAngle("texelFetchOffset");
2115 putAngle("texture");
2116 putAngle("texture1D");
2117 putAngle("texture1DLod");
2118 putAngle("texture1DProjLod");
2119 putAngle("texture2D");
2120 putAngle("texture2DGradEXT");
2121 putAngle("texture2DLod");
2122 putAngle("texture2DLodEXT");
2123 putAngle("texture2DProj");
2124 putAngle("texture2DProjGradEXT");
2125 putAngle("texture2DProjLod");
2126 putAngle("texture2DProjLodEXT");
2127 putAngle("texture2DRect");
2128 putAngle("texture2DRectProj");
2129 putAngle("texture3D");
2130 putAngle("texture3DLod");
2131 putAngle("texture3DProjLod");
2132 putAngle("textureCube");
2133 putAngle("textureCubeGradEXT");
2134 putAngle("textureCubeLod");
2135 putAngle("textureCubeLodEXT");
2136 putAngle("textureCubeProjLod");
2137 putAngle("textureGrad");
2138 putAngle("textureGradOffset");
2139 putAngle("textureLod");
2140 putAngle("textureLodOffset");
2141 putAngle("textureOffset");
2142 putAngle("textureProj");
2143 putAngle("textureProjGrad");
2144 putAngle("textureProjGradOffset");
2145 putAngle("textureProjLod");
2146 putAngle("textureProjLodOffset");
2147 putAngle("textureProjOffset");
2148 putAngle("textureSize");
2149 putAngle("imageLoad");
2150 putAngle("imageStore");
2151 putAngle("memoryBarrierImage");
2152
2153 return map;
2154 }
2155
visitAggregate(Visit,TIntermAggregate * aggregateNode)2156 bool GenMetalTraverser::visitAggregate(Visit, TIntermAggregate *aggregateNode)
2157 {
2158 const TIntermSequence &args = *aggregateNode->getSequence();
2159
2160 auto emitArgList = [&](const char *open, const char *close) {
2161 mOut << open;
2162
2163 bool emitComma = false;
2164 for (TIntermNode *arg : args)
2165 {
2166 if (emitComma)
2167 {
2168 emitClosingPointerParen();
2169 mOut << ", ";
2170 }
2171 emitComma = true;
2172 arg->traverse(this);
2173 }
2174
2175 mOut << close;
2176 };
2177
2178 const TType &retType = aggregateNode->getType();
2179
2180 if (aggregateNode->isConstructor())
2181 {
2182 const bool isStandalone = getParentNode()->getAsBlock();
2183 if (isStandalone)
2184 {
2185 // Prevent constructor from being interpreted as a declaration by wrapping in parens.
2186 // This can happen if given something like:
2187 // int(symbol); // <- This will be treated like `int symbol;`... don't want that.
2188 // So instead emit:
2189 // (int(symbol));
2190 mOut << "(";
2191 }
2192
2193 const EmitTypeConfig etConfig;
2194
2195 if (retType.isArray())
2196 {
2197 emitType(retType, etConfig);
2198 emitArgList("{", "}");
2199 }
2200 else if (retType.getStruct())
2201 {
2202 emitType(retType, etConfig);
2203 emitArgList("{", "}");
2204 }
2205 else
2206 {
2207 emitType(retType, etConfig);
2208 emitArgList("(", ")");
2209 }
2210
2211 if (isStandalone)
2212 {
2213 mOut << ")";
2214 }
2215
2216 return false;
2217 }
2218 else
2219 {
2220 const TOperator op = aggregateNode->getOp();
2221 if (op == EOpAtan)
2222 {
2223 TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
2224 reflection->hasAtan = true;
2225 }
2226 switch (op)
2227 {
2228 case TOperator::EOpCallFunctionInAST:
2229 case TOperator::EOpCallInternalRawFunction:
2230 {
2231 const TFunction &func = *aggregateNode->getFunction();
2232 emitNameOf(func);
2233 //'@' symbol in name specifices a macro substitution marker.
2234 if (!func.name().contains("@"))
2235 {
2236 emitArgList("(", ")");
2237 }
2238 else
2239 {
2240 mTemporarilyDisableSemicolon =
2241 true; // Disable semicolon for macro substitution.
2242 }
2243 return false;
2244 }
2245
2246 default:
2247 {
2248 auto getArgType = [&](size_t index) -> const TType * {
2249 if (index < args.size())
2250 {
2251 TIntermTyped *arg = args[index]->getAsTyped();
2252 ASSERT(arg);
2253 return &arg->getType();
2254 }
2255 return nullptr;
2256 };
2257
2258 const TType *argType0 = getArgType(0);
2259 const TType *argType1 = getArgType(1);
2260 const TType *argType2 = getArgType(2);
2261
2262 const char *opName = GetOperatorString(op, retType, argType0, argType1, argType2);
2263
2264 if (IsSymbolicOperator(op, retType, argType0, argType1))
2265 {
2266 switch (args.size())
2267 {
2268 case 1:
2269 {
2270 TIntermNode &operandNode = *aggregateNode->getChildNode(0);
2271 if (IsPostfix(op))
2272 {
2273 mOut << opName;
2274 groupedTraverse(operandNode);
2275 }
2276 else
2277 {
2278 groupedTraverse(operandNode);
2279 mOut << opName;
2280 }
2281 return false;
2282 }
2283
2284 case 2:
2285 {
2286 TIntermNode &leftNode = *aggregateNode->getChildNode(0);
2287 TIntermNode &rightNode = *aggregateNode->getChildNode(1);
2288 groupedTraverse(leftNode);
2289 mOut << " " << opName << " ";
2290 groupedTraverse(rightNode);
2291 return false;
2292 }
2293
2294 default:
2295 UNREACHABLE();
2296 return false;
2297 }
2298 }
2299 else if (opName == nullptr)
2300 {
2301 const TFunction &func = *aggregateNode->getFunction();
2302 auto it = mFuncToName.find(func.name());
2303 ASSERT(it != mFuncToName.end());
2304 EmitName(mOut, it->second);
2305 emitArgList("(", ")");
2306 return false;
2307 }
2308 else
2309 {
2310 mOut << opName;
2311 emitArgList("(", ")");
2312 return false;
2313 }
2314 }
2315 }
2316 }
2317 }
2318
emitOpenBrace()2319 void GenMetalTraverser::emitOpenBrace()
2320 {
2321 ASSERT(mIndentLevel >= 0);
2322
2323 emitIndentation();
2324 mOut << "{\n";
2325 ++mIndentLevel;
2326 }
2327
emitCloseBrace()2328 void GenMetalTraverser::emitCloseBrace()
2329 {
2330 ASSERT(mIndentLevel >= 1);
2331
2332 --mIndentLevel;
2333 emitIndentation();
2334 mOut << "}";
2335 }
2336
RequiresSemicolonTerminator(TIntermNode & node)2337 static bool RequiresSemicolonTerminator(TIntermNode &node)
2338 {
2339 if (node.getAsBlock())
2340 {
2341 return false;
2342 }
2343 if (node.getAsLoopNode())
2344 {
2345 return false;
2346 }
2347 if (node.getAsSwitchNode())
2348 {
2349 return false;
2350 }
2351 if (node.getAsIfElseNode())
2352 {
2353 return false;
2354 }
2355 if (node.getAsFunctionDefinition())
2356 {
2357 return false;
2358 }
2359 if (node.getAsCaseNode())
2360 {
2361 return false;
2362 }
2363
2364 return true;
2365 }
2366
NewlinePad(TIntermNode & node)2367 static bool NewlinePad(TIntermNode &node)
2368 {
2369 if (node.getAsFunctionDefinition())
2370 {
2371 return true;
2372 }
2373 if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
2374 {
2375 ASSERT(declNode->getChildCount() == 1);
2376 TIntermNode &childNode = *declNode->getChildNode(0);
2377 if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
2378 {
2379 const TVariable &var = symbolNode->variable();
2380 return var.getType().isStructSpecifier();
2381 }
2382 return false;
2383 }
2384 return false;
2385 }
2386
visitBlock(Visit,TIntermBlock * blockNode)2387 bool GenMetalTraverser::visitBlock(Visit, TIntermBlock *blockNode)
2388 {
2389 ASSERT(mIndentLevel >= -1);
2390 const bool isGlobalScope = mIndentLevel == -1;
2391 const bool parentIsSwitch = mParentIsSwitch;
2392 mParentIsSwitch = false;
2393
2394 if (isGlobalScope)
2395 {
2396 ++mIndentLevel;
2397 }
2398 else
2399 {
2400 emitOpenBrace();
2401 if (parentIsSwitch)
2402 {
2403 ++mIndentLevel;
2404 }
2405 }
2406
2407 TIntermNode *prevStmtNode = nullptr;
2408
2409 const size_t stmtCount = blockNode->getChildCount();
2410 for (size_t i = 0; i < stmtCount; ++i)
2411 {
2412 TIntermNode &stmtNode = *blockNode->getChildNode(i);
2413
2414 if (isGlobalScope && prevStmtNode && (NewlinePad(*prevStmtNode) || NewlinePad(stmtNode)))
2415 {
2416 mOut << "\n";
2417 }
2418 const bool isCase = stmtNode.getAsCaseNode();
2419 mIndentLevel -= isCase;
2420 emitIndentation();
2421 mIndentLevel += isCase;
2422 stmtNode.traverse(this);
2423 if (RequiresSemicolonTerminator(stmtNode) && !mTemporarilyDisableSemicolon)
2424 {
2425 mOut << ";";
2426 }
2427 mTemporarilyDisableSemicolon = false;
2428 mOut << "\n";
2429
2430 prevStmtNode = &stmtNode;
2431 }
2432
2433 if (isGlobalScope)
2434 {
2435 ASSERT(mIndentLevel == 0);
2436 --mIndentLevel;
2437 }
2438 else
2439 {
2440 if (parentIsSwitch)
2441 {
2442 ASSERT(mIndentLevel >= 1);
2443 --mIndentLevel;
2444 }
2445 emitCloseBrace();
2446 mParentIsSwitch = parentIsSwitch;
2447 }
2448
2449 return false;
2450 }
2451
visitGlobalQualifierDeclaration(Visit,TIntermGlobalQualifierDeclaration *)2452 bool GenMetalTraverser::visitGlobalQualifierDeclaration(Visit, TIntermGlobalQualifierDeclaration *)
2453 {
2454 return false;
2455 }
2456
visitDeclaration(Visit,TIntermDeclaration * declNode)2457 bool GenMetalTraverser::visitDeclaration(Visit, TIntermDeclaration *declNode)
2458 {
2459 ASSERT(declNode->getChildCount() == 1);
2460 TIntermNode &node = *declNode->getChildNode(0);
2461
2462 EmitVariableDeclarationConfig evdConfig;
2463
2464 if (TIntermSymbol *symbolNode = node.getAsSymbolNode())
2465 {
2466 const TVariable &var = symbolNode->variable();
2467 emitVariableDeclaration(VarDecl(var), evdConfig);
2468 }
2469 else if (TIntermBinary *initNode = node.getAsBinaryNode())
2470 {
2471 ASSERT(initNode->getOp() == TOperator::EOpInitialize);
2472 TIntermSymbol *leftSymbolNode = initNode->getLeft()->getAsSymbolNode();
2473 TIntermTyped *valueNode = initNode->getRight()->getAsTyped();
2474 ASSERT(leftSymbolNode && valueNode);
2475
2476 if (getRootNode() == getParentBlock())
2477 {
2478 // DeferGlobalInitializers should have turned non-const global initializers into
2479 // deferred initializers. Note that variables marked as EvqGlobal can be treated as
2480 // EvqConst in some ANGLE code but not actually have their qualifier actually changed to
2481 // EvqConst. Thus just assume all EvqGlobal are actually EvqConst for all code run after
2482 // DeferGlobalInitializers.
2483 mOut << "constant ";
2484 }
2485
2486 const TVariable &var = leftSymbolNode->variable();
2487 const Name varName(var);
2488
2489 if (ExpressionContainsName(varName, *valueNode))
2490 {
2491 mRenamedSymbols[&var] = mIdGen.createNewName(varName);
2492 }
2493
2494 emitVariableDeclaration(VarDecl(var), evdConfig);
2495 mOut << " = ";
2496 groupedTraverse(*valueNode);
2497 }
2498 else
2499 {
2500 UNREACHABLE();
2501 }
2502
2503 return false;
2504 }
2505
visitLoop(Visit,TIntermLoop * loopNode)2506 bool GenMetalTraverser::visitLoop(Visit, TIntermLoop *loopNode)
2507 {
2508 const TLoopType loopType = loopNode->getType();
2509
2510 switch (loopType)
2511 {
2512 case TLoopType::ELoopFor:
2513 return visitForLoop(loopNode);
2514 case TLoopType::ELoopWhile:
2515 return visitWhileLoop(loopNode);
2516 case TLoopType::ELoopDoWhile:
2517 return visitDoWhileLoop(loopNode);
2518 }
2519 }
2520
visitForLoop(TIntermLoop * loopNode)2521 bool GenMetalTraverser::visitForLoop(TIntermLoop *loopNode)
2522 {
2523 ASSERT(loopNode->getType() == TLoopType::ELoopFor);
2524
2525 TIntermNode *initNode = loopNode->getInit();
2526 TIntermTyped *condNode = loopNode->getCondition();
2527 TIntermTyped *exprNode = loopNode->getExpression();
2528 TIntermBlock *bodyNode = loopNode->getBody();
2529 ASSERT(bodyNode);
2530
2531 mOut << "for (";
2532
2533 if (initNode)
2534 {
2535 initNode->traverse(this);
2536 }
2537 else
2538 {
2539 mOut << " ";
2540 }
2541
2542 mOut << "; ";
2543
2544 if (condNode)
2545 {
2546 condNode->traverse(this);
2547 }
2548
2549 mOut << "; ";
2550
2551 if (exprNode)
2552 {
2553 exprNode->traverse(this);
2554 }
2555
2556 mOut << ")\n";
2557
2558 bodyNode->traverse(this);
2559
2560 return false;
2561 }
2562
visitWhileLoop(TIntermLoop * loopNode)2563 bool GenMetalTraverser::visitWhileLoop(TIntermLoop *loopNode)
2564 {
2565 ASSERT(loopNode->getType() == TLoopType::ELoopWhile);
2566
2567 TIntermNode *initNode = loopNode->getInit();
2568 TIntermTyped *condNode = loopNode->getCondition();
2569 TIntermTyped *exprNode = loopNode->getExpression();
2570 TIntermBlock *bodyNode = loopNode->getBody();
2571 ASSERT(condNode && bodyNode);
2572 ASSERT(!initNode && !exprNode);
2573
2574 emitIndentation();
2575 mOut << "while (";
2576 condNode->traverse(this);
2577 mOut << ")\n";
2578 bodyNode->traverse(this);
2579
2580 return false;
2581 }
2582
visitDoWhileLoop(TIntermLoop * loopNode)2583 bool GenMetalTraverser::visitDoWhileLoop(TIntermLoop *loopNode)
2584 {
2585 ASSERT(loopNode->getType() == TLoopType::ELoopDoWhile);
2586
2587 TIntermNode *initNode = loopNode->getInit();
2588 TIntermTyped *condNode = loopNode->getCondition();
2589 TIntermTyped *exprNode = loopNode->getExpression();
2590 TIntermBlock *bodyNode = loopNode->getBody();
2591 ASSERT(condNode && bodyNode);
2592 ASSERT(!initNode && !exprNode);
2593
2594 emitIndentation();
2595 mOut << "do\n";
2596 bodyNode->traverse(this);
2597 mOut << "\n";
2598 emitIndentation();
2599 mOut << "while (";
2600 condNode->traverse(this);
2601 mOut << ");";
2602
2603 return false;
2604 }
2605
visitBranch(Visit,TIntermBranch * branchNode)2606 bool GenMetalTraverser::visitBranch(Visit, TIntermBranch *branchNode)
2607 {
2608 const TOperator flowOp = branchNode->getFlowOp();
2609 TIntermTyped *exprNode = branchNode->getExpression();
2610
2611 emitIndentation();
2612
2613 switch (flowOp)
2614 {
2615 case TOperator::EOpKill:
2616 {
2617 ASSERT(exprNode == nullptr);
2618 mOut << "metal::discard_fragment()";
2619 }
2620 break;
2621
2622 case TOperator::EOpReturn:
2623 {
2624 if (isTraversingVertexMain)
2625 {
2626 mOut << "#if TRANSFORM_FEEDBACK_ENABLED\n";
2627 emitIndentation();
2628 mOut << "return;\n";
2629 emitIndentation();
2630 mOut << "#else\n";
2631 emitIndentation();
2632 }
2633 mOut << "return";
2634 if (exprNode)
2635 {
2636 mOut << " ";
2637 exprNode->traverse(this);
2638 mOut << ";";
2639 }
2640 if (isTraversingVertexMain)
2641 {
2642 mOut << "\n";
2643 emitIndentation();
2644 mOut << "#endif\n";
2645 mTemporarilyDisableSemicolon = true;
2646 }
2647 }
2648 break;
2649
2650 case TOperator::EOpBreak:
2651 {
2652 ASSERT(exprNode == nullptr);
2653 mOut << "break";
2654 }
2655 break;
2656
2657 case TOperator::EOpContinue:
2658 {
2659 ASSERT(exprNode == nullptr);
2660 mOut << "continue";
2661 }
2662 break;
2663
2664 default:
2665 {
2666 UNREACHABLE();
2667 }
2668 }
2669
2670 return false;
2671 }
2672
2673 static size_t emitMetalCallCount = 0;
2674
EmitMetal(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const PipelineStructs & pipelineStructs,SymbolEnv & symbolEnv,const ProgramPreludeConfig & ppc,const ShCompileOptions & compileOptions)2675 bool sh::EmitMetal(TCompiler &compiler,
2676 TIntermBlock &root,
2677 IdGen &idGen,
2678 const PipelineStructs &pipelineStructs,
2679 SymbolEnv &symbolEnv,
2680 const ProgramPreludeConfig &ppc,
2681 const ShCompileOptions &compileOptions)
2682 {
2683 TInfoSinkBase &out = compiler.getInfoSink().obj;
2684
2685 {
2686 ++emitMetalCallCount;
2687 std::string filenameProto = angle::GetEnvironmentVar("GMD_FIXED_EMIT");
2688 if (!filenameProto.empty())
2689 {
2690 if (filenameProto != "/dev/null")
2691 {
2692 auto tryOpen = [&](char const *ext) {
2693 auto filename = filenameProto;
2694 filename += std::to_string(emitMetalCallCount);
2695 filename += ".";
2696 filename += ext;
2697 return fopen(filename.c_str(), "rb");
2698 };
2699 FILE *file = tryOpen("metal");
2700 if (!file)
2701 {
2702 file = tryOpen("cpp");
2703 }
2704 ASSERT(file);
2705
2706 fseek(file, 0, SEEK_END);
2707 size_t fileSize = ftell(file);
2708 fseek(file, 0, SEEK_SET);
2709
2710 std::vector<char> buff;
2711 buff.resize(fileSize + 1);
2712 fread(buff.data(), fileSize, 1, file);
2713 buff.back() = '\0';
2714
2715 fclose(file);
2716
2717 out << buff.data();
2718 }
2719
2720 return true;
2721 }
2722 }
2723
2724 out << "\n\n";
2725
2726 if (!EmitProgramPrelude(root, out, ppc))
2727 {
2728 return false;
2729 }
2730
2731 {
2732 #if defined(ANGLE_ENABLE_ASSERTS)
2733 DebugSink outWrapper(out, angle::GetBoolEnvironmentVar("GMD_STDOUT"));
2734 outWrapper.watch(angle::GetEnvironmentVar("GMD_WATCH_STRING"));
2735 #else
2736 TInfoSinkBase &outWrapper = out;
2737 #endif
2738 GenMetalTraverser gen(compiler, outWrapper, idGen, pipelineStructs, symbolEnv,
2739 compileOptions);
2740 root.traverse(&gen);
2741 }
2742
2743 out << "\n";
2744
2745 return true;
2746 }
2747