1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include <cctype>
8 #include <map>
9
10 #include "common/system_utils.h"
11 #include "compiler/translator/BaseTypes.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/OutputTree.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/msl/AstHelpers.h"
16 #include "compiler/translator/msl/DebugSink.h"
17 #include "compiler/translator/msl/EmitMetal.h"
18 #include "compiler/translator/msl/Layout.h"
19 #include "compiler/translator/msl/Name.h"
20 #include "compiler/translator/msl/ProgramPrelude.h"
21 #include "compiler/translator/msl/RewritePipelines.h"
22 #include "compiler/translator/msl/TranslatorMSL.h"
23 #include "compiler/translator/tree_util/IntermTraverse.h"
24
25 using namespace sh;
26
27 ////////////////////////////////////////////////////////////////////////////////
28
// Output sink type used throughout this file. Assert-enabled builds write through DebugSink
// (whose escape() is used by EmitName); other builds write straight to TInfoSinkBase.
#if defined(ANGLE_ENABLE_ASSERTS)
using Sink = DebugSink;
#else
using Sink = TInfoSinkBase;
#endif
34
35 ////////////////////////////////////////////////////////////////////////////////
36
37 namespace
38 {
39
40 struct VarDecl
41 {
VarDecl__anonc2fa8c430111::VarDecl42 explicit VarDecl(const TVariable &var) : mVariable(&var), mIsField(false) {}
VarDecl__anonc2fa8c430111::VarDecl43 explicit VarDecl(const TField &field) : mField(&field), mIsField(true) {}
44
variable__anonc2fa8c430111::VarDecl45 ANGLE_INLINE const TVariable &variable() const
46 {
47 ASSERT(isVariable());
48 return *mVariable;
49 }
50
field__anonc2fa8c430111::VarDecl51 ANGLE_INLINE const TField &field() const
52 {
53 ASSERT(isField());
54 return *mField;
55 }
56
isVariable__anonc2fa8c430111::VarDecl57 ANGLE_INLINE bool isVariable() const { return !mIsField; }
58
isField__anonc2fa8c430111::VarDecl59 ANGLE_INLINE bool isField() const { return mIsField; }
60
type__anonc2fa8c430111::VarDecl61 const TType &type() const { return isField() ? *field().type() : variable().getType(); }
62
symbolType__anonc2fa8c430111::VarDecl63 SymbolType symbolType() const
64 {
65 return isField() ? field().symbolType() : variable().symbolType();
66 }
67
68 private:
69 union
70 {
71 const TVariable *mVariable;
72 const TField *mField;
73 };
74 bool mIsField;
75 };
76
77 class GenMetalTraverser : public TIntermTraverser
78 {
79 public:
80 ~GenMetalTraverser() override;
81
82 GenMetalTraverser(const TCompiler &compiler,
83 Sink &out,
84 IdGen &idGen,
85 const PipelineStructs &pipelineStructs,
86 SymbolEnv &symbolEnv,
87 const ShCompileOptions &compileOptions);
88
89 void visitSymbol(TIntermSymbol *) override;
90 void visitConstantUnion(TIntermConstantUnion *) override;
91 bool visitSwizzle(Visit, TIntermSwizzle *) override;
92 bool visitBinary(Visit, TIntermBinary *) override;
93 bool visitUnary(Visit, TIntermUnary *) override;
94 bool visitTernary(Visit, TIntermTernary *) override;
95 bool visitIfElse(Visit, TIntermIfElse *) override;
96 bool visitSwitch(Visit, TIntermSwitch *) override;
97 bool visitCase(Visit, TIntermCase *) override;
98 void visitFunctionPrototype(TIntermFunctionPrototype *) override;
99 bool visitFunctionDefinition(Visit, TIntermFunctionDefinition *) override;
100 bool visitAggregate(Visit, TIntermAggregate *) override;
101 bool visitBlock(Visit, TIntermBlock *) override;
102 bool visitGlobalQualifierDeclaration(Visit, TIntermGlobalQualifierDeclaration *) override;
103 bool visitDeclaration(Visit, TIntermDeclaration *) override;
104 bool visitLoop(Visit, TIntermLoop *) override;
105 bool visitForLoop(TIntermLoop *);
106 bool visitWhileLoop(TIntermLoop *);
107 bool visitDoWhileLoop(TIntermLoop *);
108 bool visitBranch(Visit, TIntermBranch *) override;
109
110 private:
111 using FuncToName = std::map<ImmutableString, Name>;
112 static FuncToName BuildFuncToName();
113
114 struct EmitVariableDeclarationConfig
115 {
116 bool isParameter = false;
117 bool isMainParameter = false;
118 bool emitPostQualifier = false;
119 bool isPacked = false;
120 bool disableStructSpecifier = false;
121 bool isUBO = false;
122 const AddressSpace *isPointer = nullptr;
123 const AddressSpace *isReference = nullptr;
124 };
125
126 struct EmitTypeConfig
127 {
128 const EmitVariableDeclarationConfig *evdConfig = nullptr;
129 };
130
131 void emitIndentation();
132 void emitOpeningPointerParen();
133 void emitClosingPointerParen();
134 void emitFunctionSignature(const TFunction &func);
135 void emitFunctionReturn(const TFunction &func);
136 void emitFunctionParameter(const TFunction &func, const TVariable ¶m);
137
138 void emitNameOf(const TField &object);
139 void emitNameOf(const TSymbol &object);
140 void emitNameOf(const VarDecl &object);
141
142 void emitBareTypeName(const TType &type, const EmitTypeConfig &etConfig);
143 void emitType(const TType &type, const EmitTypeConfig &etConfig);
144 void emitPostQualifier(const EmitVariableDeclarationConfig &evdConfig,
145 const VarDecl &decl,
146 const TQualifier qualifier);
147
148 void emitLoopBody(TIntermBlock *bodyNode);
149
150 struct FieldAnnotationIndices
151 {
152 size_t attribute = 0;
153 size_t color = 0;
154 };
155
156 void emitFieldDeclaration(const TField &field,
157 const TStructure &parent,
158 FieldAnnotationIndices &annotationIndices);
159 void emitAttributeDeclaration(const TField &field, FieldAnnotationIndices &annotationIndices);
160 void emitUniformBufferDeclaration(const TField &field,
161 FieldAnnotationIndices &annotationIndices);
162 void emitStructDeclaration(const TType &type);
163 void emitOrdinaryVariableDeclaration(const VarDecl &decl,
164 const EmitVariableDeclarationConfig &evdConfig);
165 void emitVariableDeclaration(const VarDecl &decl,
166 const EmitVariableDeclarationConfig &evdConfig);
167
168 void emitOpenBrace();
169 void emitCloseBrace();
170
171 void groupedTraverse(TIntermNode &node);
172
173 const TField &getDirectField(const TFieldListCollection &fieldsNode,
174 const TConstantUnion &index);
175 const TField &getDirectField(const TIntermTyped &fieldsNode, TIntermTyped &indexNode);
176
177 const TConstantUnion *emitConstantUnionArray(const TConstantUnion *const constUnion,
178 const size_t size);
179
180 const TConstantUnion *emitConstantUnion(const TType &type, const TConstantUnion *constUnion);
181
182 void emitSingleConstant(const TConstantUnion *const constUnion);
183
184 private:
185 Sink &mOut;
186 const TCompiler &mCompiler;
187 const PipelineStructs &mPipelineStructs;
188 SymbolEnv &mSymbolEnv;
189 IdGen &mIdGen;
190 int mIndentLevel = -1;
191 int mLastIndentationPos = -1;
192 int mOpenPointerParenCount = 0;
193 bool mParentIsSwitch = false;
194 bool isTraversingVertexMain = false;
195 bool mTemporarilyDisableSemicolon = false;
196 std::unordered_map<const TSymbol *, Name> mRenamedSymbols;
197 const FuncToName mFuncToName = BuildFuncToName();
198 size_t mMainTextureIndex = 0;
199 size_t mMainSamplerIndex = 0;
200 size_t mMainUniformBufferIndex = 0;
201 size_t mDriverUniformsBindingIndex = 0;
202 size_t mUBOArgumentBufferBindingIndex = 0;
203 bool mRasterOrderGroupsSupported = false;
204 bool mInjectAsmStatementIntoLoopBodies = false;
205 };
206 } // anonymous namespace
207
~GenMetalTraverser()208 GenMetalTraverser::~GenMetalTraverser()
209 {
210 ASSERT(mIndentLevel == -1);
211 ASSERT(!mParentIsSwitch);
212 ASSERT(mOpenPointerParenCount == 0);
213 }
214
GenMetalTraverser(const TCompiler & compiler,Sink & out,IdGen & idGen,const PipelineStructs & pipelineStructs,SymbolEnv & symbolEnv,const ShCompileOptions & compileOptions)215 GenMetalTraverser::GenMetalTraverser(const TCompiler &compiler,
216 Sink &out,
217 IdGen &idGen,
218 const PipelineStructs &pipelineStructs,
219 SymbolEnv &symbolEnv,
220 const ShCompileOptions &compileOptions)
221 : TIntermTraverser(true, false, false),
222 mOut(out),
223 mCompiler(compiler),
224 mPipelineStructs(pipelineStructs),
225 mSymbolEnv(symbolEnv),
226 mIdGen(idGen),
227 mMainUniformBufferIndex(compileOptions.metal.defaultUniformsBindingIndex),
228 mDriverUniformsBindingIndex(compileOptions.metal.driverUniformsBindingIndex),
229 mUBOArgumentBufferBindingIndex(compileOptions.metal.UBOArgumentBufferBindingIndex),
230 mRasterOrderGroupsSupported(compileOptions.pls.fragmentSyncType ==
231 ShFragmentSynchronizationType::RasterOrderGroups_Metal),
232 mInjectAsmStatementIntoLoopBodies(compileOptions.metal.injectAsmStatementIntoLoopBodies)
233 {}
234
emitIndentation()235 void GenMetalTraverser::emitIndentation()
236 {
237 ASSERT(mIndentLevel >= 0);
238
239 if (mLastIndentationPos == mOut.size())
240 {
241 return; // Line is already indented.
242 }
243
244 for (int i = 0; i < mIndentLevel; ++i)
245 {
246 mOut << " ";
247 }
248
249 mLastIndentationPos = mOut.size();
250 }
251
emitOpeningPointerParen()252 void GenMetalTraverser::emitOpeningPointerParen()
253 {
254 mOut << "(*";
255 mOpenPointerParenCount++;
256 }
257
emitClosingPointerParen()258 void GenMetalTraverser::emitClosingPointerParen()
259 {
260 if (mOpenPointerParenCount > 0)
261 {
262 mOut << ")";
263 mOpenPointerParenCount--;
264 }
265 }
266
// Maps an AST operator to its Metal spelling.
//
// The result is either a symbolic operator ("+", "==", ...; "" for unary plus on a matrix) or
// the name of a function: a Metal library function ("metal::...", "as_type<...>") or an
// ANGLE_* helper. Operand types pick between variants, e.g. vector vs. scalar forms, the
// struct/array equality helpers, and the bool-selector form of mix.
//
// |argType0|, |argType1| and |argType2| are the operand types. Trailing ones may be nullptr
// when the operator has fewer operands (the EOpAtan case relies on argType1 being nullptr for
// the one-argument form). Returns nullptr for built-ins this table does not handle. Operators
// that must never reach this function trigger UNREACHABLE().
static const char *GetOperatorString(TOperator op,
                                     const TType &resultType,
                                     const TType *argType0,
                                     const TType *argType1,
                                     const TType *argType2)
{
    switch (op)
    {
        case TOperator::EOpComma:
            return ",";
        case TOperator::EOpAssign:
            return "=";
        case TOperator::EOpInitialize:
            return "=";
        case TOperator::EOpAddAssign:
            return "+=";
        case TOperator::EOpSubAssign:
            return "-=";
        case TOperator::EOpMulAssign:
            return "*=";
        case TOperator::EOpDivAssign:
            return "/=";
        case TOperator::EOpIModAssign:
            return "%=";
        case TOperator::EOpBitShiftLeftAssign:
            return "<<=";  // TODO: Check logical vs arithmetic shifting.
        case TOperator::EOpBitShiftRightAssign:
            return ">>=";  // TODO: Check logical vs arithmetic shifting.
        case TOperator::EOpBitwiseAndAssign:
            return "&=";
        case TOperator::EOpBitwiseXorAssign:
            return "^=";
        case TOperator::EOpBitwiseOrAssign:
            return "|=";
        case TOperator::EOpAdd:
            return "+";
        case TOperator::EOpSub:
            return "-";
        case TOperator::EOpMul:
            return "*";
        case TOperator::EOpDiv:
            return "/";
        case TOperator::EOpIMod:
            return "%";
        case TOperator::EOpBitShiftLeft:
            return "<<";  // TODO: Check logical vs arithmetic shifting.
        case TOperator::EOpBitShiftRight:
            return ">>";  // TODO: Check logical vs arithmetic shifting.
        case TOperator::EOpBitwiseAnd:
            return "&";
        case TOperator::EOpBitwiseXor:
            return "^";
        case TOperator::EOpBitwiseOr:
            return "|";
        case TOperator::EOpLessThan:
            return "<";
        case TOperator::EOpGreaterThan:
            return ">";
        case TOperator::EOpLessThanEqual:
            return "<=";
        case TOperator::EOpGreaterThanEqual:
            return ">=";
        case TOperator::EOpLessThanComponentWise:
            return "<";
        case TOperator::EOpLessThanEqualComponentWise:
            return "<=";
        case TOperator::EOpGreaterThanEqualComponentWise:
            return ">=";
        case TOperator::EOpGreaterThanComponentWise:
            return ">";
        case TOperator::EOpLogicalOr:
            return "||";
        case TOperator::EOpLogicalXor:
            return "!=/*xor*/";  // XXX: This might need to be handled differently for some obtuse
                                 // use case.
        case TOperator::EOpLogicalAnd:
            return "&&";
        case TOperator::EOpNegative:
            return "-";
        case TOperator::EOpPositive:
            // Unary plus on a matrix is emitted as nothing at all.
            if (argType0->isMatrix())
            {
                return "";
            }
            return "+";
        case TOperator::EOpLogicalNot:
            return "!";
        case TOperator::EOpNotComponentWise:
            return "!";
        case TOperator::EOpBitwiseNot:
            return "~";
        case TOperator::EOpPostIncrement:
            return "++";
        case TOperator::EOpPostDecrement:
            return "--";
        case TOperator::EOpPreIncrement:
            return "++";
        case TOperator::EOpPreDecrement:
            return "--";
        case TOperator::EOpVectorTimesScalarAssign:
            return "*=";
        case TOperator::EOpVectorTimesMatrixAssign:
            return "*=";
        case TOperator::EOpMatrixTimesScalarAssign:
            return "*=";
        case TOperator::EOpMatrixTimesMatrixAssign:
            return "*=";
        case TOperator::EOpVectorTimesScalar:
            return "*";
        case TOperator::EOpVectorTimesMatrix:
            return "*";
        case TOperator::EOpMatrixTimesVector:
            return "*";
        case TOperator::EOpMatrixTimesScalar:
            return "*";
        case TOperator::EOpMatrixTimesMatrix:
            return "*";
        case TOperator::EOpEqualComponentWise:
            return "==";
        case TOperator::EOpNotEqualComponentWise:
            return "!=";

        // Whole-value equality: aggregates (vectors, structs, arrays, matrices) compare through
        // ANGLE helpers; only scalar operands use the plain operator.
        case TOperator::EOpEqual:
            if ((argType0->getStruct() && argType1->getStruct()) &&
                (argType0->isArray() && argType1->isArray()))
            {
                return "ANGLE_equalStructArray";
            }

            if ((argType0->isVector() && argType1->isVector()) ||
                (argType0->getStruct() && argType1->getStruct()) ||
                (argType0->isArray() && argType1->isArray()) ||
                (argType0->isMatrix() && argType1->isMatrix()))

            {
                return "ANGLE_equal";
            }

            return "==";

        // Unlike EOpEqual, a non-array struct pair uses the dedicated ANGLE_notEqualStruct.
        case TOperator::EOpNotEqual:
            if ((argType0->getStruct() && argType1->getStruct()) &&
                (argType0->isArray() && argType1->isArray()))
            {
                return "ANGLE_notEqualStructArray";
            }

            if ((argType0->isVector() && argType1->isVector()) ||
                (argType0->isArray() && argType1->isArray()) ||
                (argType0->isMatrix() && argType1->isMatrix()))
            {
                return "ANGLE_notEqual";
            }
            else if (argType0->getStruct() && argType1->getStruct())
            {
                return "ANGLE_notEqualStruct";
            }
            return "!=";

        case TOperator::EOpKill:
            UNIMPLEMENTED();
            return "kill";
        case TOperator::EOpReturn:
            return "return";
        case TOperator::EOpBreak:
            return "break";
        case TOperator::EOpContinue:
            return "continue";

        case TOperator::EOpRadians:
            return "ANGLE_radians";
        case TOperator::EOpDegrees:
            return "ANGLE_degrees";
        case TOperator::EOpAtan:
            // One-argument atan(y_over_x) vs. two-argument atan(y, x).
            return argType1 == nullptr ? "metal::atan" : "metal::atan2";
        case TOperator::EOpMod:
            return "ANGLE_mod";  // differs from metal::mod
        // For the geometric functions below, scalar operands fall back to an equivalent
        // scalar form (e.g. length(x) == abs(x), normalize(x) == sign(x), dot(a, b) == a * b).
        case TOperator::EOpRefract:
            return argType0->isVector() ? "metal::refract" : "ANGLE_refract_scalar";
        case TOperator::EOpDistance:
            return argType0->isVector() ? "metal::distance" : "ANGLE_distance_scalar";
        case TOperator::EOpLength:
            return argType0->isVector() ? "metal::length" : "metal::abs";
        case TOperator::EOpDot:
            return argType0->isVector() ? "metal::dot" : "*";
        case TOperator::EOpNormalize:
            return argType0->isVector() ? "metal::fast::normalize" : "metal::sign";
        case TOperator::EOpFaceforward:
            return argType0->isVector() ? "metal::faceforward" : "ANGLE_faceforward_scalar";
        case TOperator::EOpReflect:
            return argType0->isVector() ? "metal::reflect" : "ANGLE_reflect_scalar";
        case TOperator::EOpMatrixCompMult:
            return "ANGLE_componentWiseMultiply";
        case TOperator::EOpOuterProduct:
            return "ANGLE_outerProduct";
        case TOperator::EOpSign:
            // Float operands use Metal's sign; everything else goes through an ANGLE helper.
            return argType0->getBasicType() == EbtFloat ? "metal::sign" : "ANGLE_sign_int";

        case TOperator::EOpAbs:
            return "metal::abs";
        case TOperator::EOpAll:
            return "metal::all";
        case TOperator::EOpAny:
            return "metal::any";
        case TOperator::EOpSin:
            return "metal::sin";
        case TOperator::EOpCos:
            return "metal::cos";
        case TOperator::EOpTan:
            return "metal::tan";
        case TOperator::EOpAsin:
            return "metal::asin";
        case TOperator::EOpAcos:
            return "metal::acos";
        case TOperator::EOpSinh:
            return "metal::sinh";
        case TOperator::EOpCosh:
            return "metal::cosh";
        case TOperator::EOpTanh:
            // highp results select the precise:: variant.
            return resultType.getPrecision() == TPrecision::EbpHigh ? "metal::precise::tanh"
                                                                    : "metal::tanh";
        case TOperator::EOpAsinh:
            return "metal::asinh";
        case TOperator::EOpAcosh:
            return "metal::acosh";
        case TOperator::EOpAtanh:
            return "metal::atanh";
        case TOperator::EOpFma:
            return "metal::fma";
        case TOperator::EOpPow:
            return "metal::powr";  // GLSL's pow excludes negative x
        case TOperator::EOpExp:
            return "metal::exp";
        case TOperator::EOpExp2:
            return "metal::exp2";
        case TOperator::EOpLog:
            return "metal::log";
        case TOperator::EOpLog2:
            return "metal::log2";
        case TOperator::EOpSqrt:
            return "metal::sqrt";
        case TOperator::EOpFloor:
            return "metal::floor";
        case TOperator::EOpTrunc:
            return "metal::trunc";
        case TOperator::EOpCeil:
            return "metal::ceil";
        case TOperator::EOpFract:
            return "metal::fract";
        case TOperator::EOpMin:
            return "metal::min";
        case TOperator::EOpMax:
            return "metal::max";
        case TOperator::EOpRound:
            return "metal::round";
        case TOperator::EOpRoundEven:
            return "metal::rint";
        case TOperator::EOpClamp:
            return "metal::clamp";  // TODO fast vs precise namespace
        case TOperator::EOpSaturate:
            return "metal::saturate";  // TODO fast vs precise namespace
        case TOperator::EOpMix:
            // A bool selector with a non-scalar second operand needs the ANGLE helper.
            if (!argType1->isScalar() && argType2 && argType2->getBasicType() == EbtBool)
            {
                return "ANGLE_mix_bool";
            }
            return "metal::mix";
        case TOperator::EOpStep:
            return "metal::step";
        case TOperator::EOpSmoothstep:
            return "metal::smoothstep";
        case TOperator::EOpModf:
            return "metal::modf";
        case TOperator::EOpIsnan:
            return "metal::isnan";
        case TOperator::EOpIsinf:
            return "metal::isinf";
        case TOperator::EOpLdexp:
            return "metal::ldexp";
        case TOperator::EOpFrexp:
            return "metal::frexp";
        case TOperator::EOpInversesqrt:
            return "metal::rsqrt";
        case TOperator::EOpCross:
            return "metal::cross";
        case TOperator::EOpDFdx:
            return "metal::dfdx";
        case TOperator::EOpDFdy:
            return "metal::dfdy";
        case TOperator::EOpFwidth:
            return "metal::fwidth";
        case TOperator::EOpTranspose:
            return "metal::transpose";
        case TOperator::EOpDeterminant:
            return "metal::determinant";

        case TOperator::EOpInverse:
            return "ANGLE_inverse";

        case TOperator::EOpInterpolateAtCentroid:
            return "ANGLE_interpolateAtCentroid";
        case TOperator::EOpInterpolateAtSample:
            return "ANGLE_interpolateAtSample";
        case TOperator::EOpInterpolateAtOffset:
            return "ANGLE_interpolateAtOffset";
        case TOperator::EOpInterpolateAtCenter:
            return "ANGLE_interpolateAtCenter";

        // Bit casts become as_type<T>, where T is the result type: a scalar, or a 2/3/4-wide
        // vector. The helper macros below are local to this case and #undef'd afterwards.
        case TOperator::EOpFloatBitsToInt:
        case TOperator::EOpFloatBitsToUint:
        case TOperator::EOpIntBitsToFloat:
        case TOperator::EOpUintBitsToFloat:
        {
#define RETURN_AS_TYPE_SCALAR() \
    do \
        switch (resultType.getBasicType()) \
        { \
            case TBasicType::EbtInt: \
                return "as_type<int>"; \
            case TBasicType::EbtUInt: \
                return "as_type<uint32_t>"; \
            case TBasicType::EbtFloat: \
                return "as_type<float>"; \
            default: \
                UNIMPLEMENTED(); \
                return "TOperator_TODO"; \
        } \
    while (false)

#define RETURN_AS_TYPE(post) \
    do \
        switch (resultType.getBasicType()) \
        { \
            case TBasicType::EbtInt: \
                return "as_type<int" post ">"; \
            case TBasicType::EbtUInt: \
                return "as_type<uint" post ">"; \
            case TBasicType::EbtFloat: \
                return "as_type<float" post ">"; \
            default: \
                UNIMPLEMENTED(); \
                return "TOperator_TODO"; \
        } \
    while (false)

            if (resultType.isScalar())
            {
                RETURN_AS_TYPE_SCALAR();
            }
            else if (resultType.isVector())
            {
                switch (resultType.getNominalSize())
                {
                    case 2:
                        RETURN_AS_TYPE("2");
                    case 3:
                        RETURN_AS_TYPE("3");
                    case 4:
                        RETURN_AS_TYPE("4");
                    default:
                        UNREACHABLE();
                        return nullptr;
                }
            }
            else
            {
                UNIMPLEMENTED();
                return "TOperator_TODO";
            }

#undef RETURN_AS_TYPE
#undef RETURN_AS_TYPE_SCALAR
        }

        case TOperator::EOpPackUnorm2x16:
            return "metal::pack_float_to_unorm2x16";
        case TOperator::EOpPackSnorm2x16:
            return "metal::pack_float_to_snorm2x16";

        case TOperator::EOpPackUnorm4x8:
            return "metal::pack_float_to_unorm4x8";
        case TOperator::EOpPackSnorm4x8:
            return "metal::pack_float_to_snorm4x8";

        case TOperator::EOpUnpackUnorm2x16:
            return "metal::unpack_unorm2x16_to_float";
        case TOperator::EOpUnpackSnorm2x16:
            return "metal::unpack_snorm2x16_to_float";

        case TOperator::EOpUnpackUnorm4x8:
            return "metal::unpack_unorm4x8_to_float";
        case TOperator::EOpUnpackSnorm4x8:
            return "metal::unpack_snorm4x8_to_float";

        case TOperator::EOpPackHalf2x16:
            return "ANGLE_pack_half_2x16";
        case TOperator::EOpUnpackHalf2x16:
            return "ANGLE_unpack_half_2x16";

        case TOperator::EOpNumSamples:
            return "metal::get_num_samples";
        case TOperator::EOpSamplePosition:
            return "metal::get_sample_position";

        // Not supported by this backend yet: flagged and replaced with a placeholder name.
        case TOperator::EOpBitfieldExtract:
        case TOperator::EOpBitfieldInsert:
        case TOperator::EOpBitfieldReverse:
        case TOperator::EOpBitCount:
        case TOperator::EOpFindLSB:
        case TOperator::EOpFindMSB:
        case TOperator::EOpUaddCarry:
        case TOperator::EOpUsubBorrow:
        case TOperator::EOpUmulExtended:
        case TOperator::EOpImulExtended:
        case TOperator::EOpBarrier:
        case TOperator::EOpMemoryBarrier:
        case TOperator::EOpMemoryBarrierAtomicCounter:
        case TOperator::EOpMemoryBarrierBuffer:
        case TOperator::EOpMemoryBarrierShared:
        case TOperator::EOpGroupMemoryBarrier:
        case TOperator::EOpAtomicAdd:
        case TOperator::EOpAtomicMin:
        case TOperator::EOpAtomicMax:
        case TOperator::EOpAtomicAnd:
        case TOperator::EOpAtomicOr:
        case TOperator::EOpAtomicXor:
        case TOperator::EOpAtomicExchange:
        case TOperator::EOpAtomicCompSwap:
        case TOperator::EOpEmitVertex:
        case TOperator::EOpEndPrimitive:
        case TOperator::EOpFtransform:
        case TOperator::EOpPackDouble2x32:
        case TOperator::EOpUnpackDouble2x32:
        case TOperator::EOpArrayLength:
            UNIMPLEMENTED();
            return "TOperator_TODO";

        // Constructors, calls and indexing have no operator spelling; callers must not ask.
        case TOperator::EOpNull:
        case TOperator::EOpConstruct:
        case TOperator::EOpCallFunctionInAST:
        case TOperator::EOpCallInternalRawFunction:
        case TOperator::EOpIndexDirect:
        case TOperator::EOpIndexIndirect:
        case TOperator::EOpIndexDirectStruct:
        case TOperator::EOpIndexDirectInterfaceBlock:
            UNREACHABLE();
            return nullptr;
        default:
            // Any other built-in function.
            return nullptr;
    }
}
719
IsSymbolicOperator(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1)720 static bool IsSymbolicOperator(TOperator op,
721 const TType &resultType,
722 const TType *argType0,
723 const TType *argType1)
724 {
725 const char *operatorString = GetOperatorString(op, resultType, argType0, argType1, nullptr);
726 if (operatorString == nullptr)
727 {
728 return false;
729 }
730 return !std::isalnum(operatorString[0]);
731 }
732
AsSpecificBinaryNode(TIntermNode & node,TOperator op)733 static TIntermBinary *AsSpecificBinaryNode(TIntermNode &node, TOperator op)
734 {
735 TIntermBinary *binaryNode = node.getAsBinaryNode();
736 if (binaryNode)
737 {
738 return binaryNode->getOp() == op ? binaryNode : nullptr;
739 }
740 return nullptr;
741 }
742
Parenthesize(TIntermNode & node)743 static bool Parenthesize(TIntermNode &node)
744 {
745 if (node.getAsSymbolNode())
746 {
747 return false;
748 }
749 if (node.getAsConstantUnion())
750 {
751 return false;
752 }
753 if (node.getAsAggregate())
754 {
755 return false;
756 }
757 if (node.getAsSwizzleNode())
758 {
759 return false;
760 }
761
762 if (TIntermUnary *unaryNode = node.getAsUnaryNode())
763 {
764 // TODO: Use a precedence and associativity rules instead of this ad-hoc impl.
765 const TType &resultType = unaryNode->getType();
766 const TType &argType = unaryNode->getOperand()->getType();
767 return IsSymbolicOperator(unaryNode->getOp(), resultType, &argType, nullptr);
768 }
769
770 if (TIntermBinary *binaryNode = node.getAsBinaryNode())
771 {
772 // TODO: Use a precedence and associativity rules instead of this ad-hoc impl.
773 const TOperator op = binaryNode->getOp();
774 switch (op)
775 {
776 case TOperator::EOpIndexDirectStruct:
777 case TOperator::EOpIndexDirectInterfaceBlock:
778 case TOperator::EOpIndexDirect:
779 case TOperator::EOpIndexIndirect:
780 return Parenthesize(*binaryNode->getLeft());
781
782 case TOperator::EOpAssign:
783 case TOperator::EOpInitialize:
784 return AsSpecificBinaryNode(*binaryNode->getRight(), TOperator::EOpComma);
785
786 default:
787 {
788 const TType &resultType = binaryNode->getType();
789 const TType &leftType = binaryNode->getLeft()->getType();
790 const TType &rightType = binaryNode->getRight()->getType();
791 return IsSymbolicOperator(binaryNode->getOp(), resultType, &leftType, &rightType);
792 }
793 }
794 }
795
796 return true;
797 }
798
groupedTraverse(TIntermNode & node)799 void GenMetalTraverser::groupedTraverse(TIntermNode &node)
800 {
801 const bool emitParens = Parenthesize(node);
802
803 if (emitParens)
804 {
805 mOut << "(";
806 }
807
808 node.traverse(this);
809
810 if (emitParens)
811 {
812 mOut << ")";
813 }
814 }
815
emitPostQualifier(const EmitVariableDeclarationConfig & evdConfig,const VarDecl & decl,const TQualifier qualifier)816 void GenMetalTraverser::emitPostQualifier(const EmitVariableDeclarationConfig &evdConfig,
817 const VarDecl &decl,
818 const TQualifier qualifier)
819 {
820 bool isInvariant = false;
821 switch (qualifier)
822 {
823 case TQualifier::EvqPosition:
824 isInvariant = decl.type().isInvariant();
825 [[fallthrough]];
826 case TQualifier::EvqFragCoord:
827 mOut << " [[position]]";
828 break;
829
830 case TQualifier::EvqClipDistance:
831 mOut << " [[clip_distance]] [" << decl.type().getOutermostArraySize() << "]";
832 break;
833
834 case TQualifier::EvqPointSize:
835 mOut << " [[point_size]]";
836 break;
837
838 case TQualifier::EvqVertexID:
839 if (evdConfig.isMainParameter)
840 {
841 mOut << " [[vertex_id]]";
842 }
843 break;
844
845 case TQualifier::EvqPointCoord:
846 if (evdConfig.isMainParameter)
847 {
848 mOut << " [[point_coord]]";
849 }
850 break;
851
852 case TQualifier::EvqFrontFacing:
853 if (evdConfig.isMainParameter)
854 {
855 mOut << " [[front_facing]]";
856 }
857 break;
858
859 case TQualifier::EvqSampleID:
860 if (evdConfig.isMainParameter)
861 {
862 mOut << " [[sample_id]]";
863 }
864 break;
865
866 case TQualifier::EvqSampleMaskIn:
867 if (evdConfig.isMainParameter)
868 {
869 mOut << " [[sample_mask]]";
870 }
871 break;
872
873 default:
874 break;
875 }
876
877 if (isInvariant)
878 {
879 mOut << " [[invariant]]";
880
881 TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
882 reflection->hasInvariance = true;
883 }
884 }
885
emitLoopBody(TIntermBlock * bodyNode)886 void GenMetalTraverser::emitLoopBody(TIntermBlock *bodyNode)
887 {
888 if (mInjectAsmStatementIntoLoopBodies)
889 {
890 emitOpenBrace();
891
892 emitIndentation();
893 mOut << "__asm__(\"\");\n";
894 }
895
896 bodyNode->traverse(this);
897
898 if (mInjectAsmStatementIntoLoopBodies)
899 {
900 emitCloseBrace();
901 }
902 }
903
EmitName(Sink & out,const Name & name)904 static void EmitName(Sink &out, const Name &name)
905 {
906 #if defined(ANGLE_ENABLE_ASSERTS)
907 DebugSink::EscapedSink escapedOut(out.escape());
908 #else
909 TInfoSinkBase &escapedOut = out;
910 #endif
911 name.emit(escapedOut);
912 }
913
emitNameOf(const TField & object)914 void GenMetalTraverser::emitNameOf(const TField &object)
915 {
916 EmitName(mOut, Name(object));
917 }
918
emitNameOf(const TSymbol & object)919 void GenMetalTraverser::emitNameOf(const TSymbol &object)
920 {
921 auto it = mRenamedSymbols.find(&object);
922 if (it == mRenamedSymbols.end())
923 {
924 EmitName(mOut, Name(object));
925 }
926 else
927 {
928 EmitName(mOut, it->second);
929 }
930 }
931
emitNameOf(const VarDecl & object)932 void GenMetalTraverser::emitNameOf(const VarDecl &object)
933 {
934 if (object.isField())
935 {
936 emitNameOf(object.field());
937 }
938 else
939 {
940 emitNameOf(object.variable());
941 }
942 }
943
emitBareTypeName(const TType & type,const EmitTypeConfig & etConfig)944 void GenMetalTraverser::emitBareTypeName(const TType &type, const EmitTypeConfig &etConfig)
945 {
946 const TBasicType basicType = type.getBasicType();
947
948 switch (basicType)
949 {
950 case TBasicType::EbtVoid:
951 case TBasicType::EbtBool:
952 case TBasicType::EbtFloat:
953 case TBasicType::EbtInt:
954 {
955 mOut << type.getBasicString();
956 break;
957 }
958 case TBasicType::EbtUInt:
959 {
960 if (type.isScalar())
961 {
962 mOut << "uint32_t";
963 }
964 else
965 {
966 mOut << type.getBasicString();
967 }
968 }
969 break;
970
971 case TBasicType::EbtStruct:
972 {
973 const TStructure &structure = *type.getStruct();
974 emitNameOf(structure);
975 }
976 break;
977
978 case TBasicType::EbtInterfaceBlock:
979 {
980 const TInterfaceBlock &interfaceBlock = *type.getInterfaceBlock();
981 emitNameOf(interfaceBlock);
982 }
983 break;
984
985 default:
986 {
987 if (IsSampler(basicType))
988 {
989 if (etConfig.evdConfig && etConfig.evdConfig->isMainParameter)
990 {
991 EmitName(mOut, GetTextureTypeName(basicType));
992 }
993 else
994 {
995 const TStructure &env = mSymbolEnv.getTextureEnv(basicType);
996 emitNameOf(env);
997 }
998 }
999 else if (IsImage(basicType))
1000 {
1001 mOut << "metal::texture2d<";
1002 switch (type.getBasicType())
1003 {
1004 case EbtImage2D:
1005 mOut << "float";
1006 break;
1007 case EbtIImage2D:
1008 mOut << "int";
1009 break;
1010 case EbtUImage2D:
1011 mOut << "uint";
1012 break;
1013 default:
1014 UNIMPLEMENTED();
1015 break;
1016 }
1017 if (type.getMemoryQualifier().readonly || type.getMemoryQualifier().writeonly)
1018 {
1019 UNIMPLEMENTED();
1020 }
1021 mOut << ", metal::access::read_write>";
1022 }
1023 else
1024 {
1025 UNIMPLEMENTED();
1026 }
1027 }
1028 }
1029 }
1030
// Emits the Metal spelling of `type`, e.g. "metal::float4", "metal::packed_float3",
// "metal::float4x3", "ANGLE_tensor<float, 3>" or
// "metal::interpolant<metal::float2, metal::interpolation::perspective>".
//
// When etConfig.evdConfig is set, its isPointer / isReference qualifier is written before the type
// and " *" / " &" after it. Arrays are wrapped in ANGLE_tensor<...>; the placement of that wrapper
// depends on evdConfig->isUBO:
//   - UBO:      "ANGLE_tensor<<qualifier> T *, N>"   (tensor of pointers)
//   - non-UBO:  "<qualifier> ANGLE_tensor<T, N> *"   (pointer to a tensor)
void GenMetalTraverser::emitType(const TType &type, const EmitTypeConfig &etConfig)
{
    const bool isUBO = etConfig.evdConfig ? etConfig.evdConfig->isUBO : false;
    if (etConfig.evdConfig)
    {
        const auto &evdConfig = *etConfig.evdConfig;
        if (isUBO)
        {
            // UBO arrays: the tensor wrapper encloses the qualifier and declarator.
            if (type.isArray())
            {
                mOut << "ANGLE_tensor<";
            }
        }
        // Address-space qualifier for pointer / reference declarations.
        if (evdConfig.isPointer)
        {
            mOut << toString(*evdConfig.isPointer);
            mOut << " ";
        }
        else if (evdConfig.isReference)
        {
            mOut << toString(*evdConfig.isReference);
            mOut << " ";
        }
    }

    // Non-UBO arrays: the tensor wrapper sits inside the qualifier.
    if (!isUBO)
    {
        if (type.isArray())
        {
            mOut << "ANGLE_tensor<";
        }
    }

    if (type.isInterpolant())
    {
        mOut << "metal::interpolant<";
    }

    if (type.isVector() || type.isMatrix())
    {
        mOut << "metal::";
    }

    if (etConfig.evdConfig && etConfig.evdConfig->isPacked)
    {
        mOut << "packed_";
    }

    emitBareTypeName(type, etConfig);

    // Vector width ("float4") or matrix shape as columns x rows ("float4x3").
    if (type.isVector())
    {
        mOut << static_cast<uint32_t>(type.getNominalSize());
    }
    else if (type.isMatrix())
    {
        mOut << static_cast<uint32_t>(type.getCols()) << "x"
             << static_cast<uint32_t>(type.getRows());
    }

    // Close metal::interpolant<>, choosing no_perspective for the noperspective qualifiers.
    if (type.isInterpolant())
    {
        mOut << ", metal::interpolation::";
        switch (type.getQualifier())
        {
            case EvqNoPerspectiveIn:
            case EvqNoPerspectiveCentroidIn:
            case EvqNoPerspectiveSampleIn:
                mOut << "no_";
                break;
            default:
                break;
        }
        mOut << "perspective>";
    }

    // Non-UBO arrays: close the tensor with every array dimension.
    if (!isUBO)
    {
        if (type.isArray())
        {
            for (auto size : type.getArraySizes())
            {
                mOut << ", " << size;
            }
            mOut << ">";
        }
    }

    if (etConfig.evdConfig)
    {
        const auto &evdConfig = *etConfig.evdConfig;
        if (evdConfig.isPointer)
        {
            mOut << " *";
        }
        else if (evdConfig.isReference)
        {
            mOut << " &";
        }
        // UBO arrays: close the tensor after the pointer/reference declarator.
        if (isUBO)
        {
            if (type.isArray())
            {
                for (auto size : type.getArraySizes())
                {
                    mOut << ", " << size;
                }
                mOut << ">";
            }
        }
    }
}
1143
// Emits one struct member ("<type> <name>" plus its post qualifier), followed by the Metal
// attribute(s) implied by the member's qualifier and by which pipeline struct `parent` is:
//   - Fragment inputs (only when `parent` is the external fragment-input struct): [[flat]] or one
//     of the [[center|centroid|sample]_[no_]perspective]] interpolation attributes. The
//     interpolation attributes are skipped for interpolant-typed members, whose interpolation is
//     already encoded in the metal::interpolant<> type written by emitType().
//   - Fragment outputs: [[color(N)]] / [[color(N), index(1)]], [[depth(...)]] and
//     [[sample_mask]]; depth and sample mask are gated on function constants.
//   - Image types: [[texture(N)]], optionally in raster order group 0. The RW texture binding is
//     recorded in the translator reflection and mMainTextureIndex is advanced.
// annotationIndices.color hands out sequential color indices to gl_FragData outputs that have no
// explicit location.
void GenMetalTraverser::emitFieldDeclaration(const TField &field,
                                             const TStructure &parent,
                                             FieldAnnotationIndices &annotationIndices)
{
    const TType &type         = *field.type();
    const TBasicType basic = type.getBasicType();

    EmitVariableDeclarationConfig evdConfig;
    evdConfig.emitPostQualifier      = true;
    evdConfig.disableStructSpecifier = true;
    evdConfig.isPacked               = mSymbolEnv.isPacked(field);
    evdConfig.isUBO                  = mSymbolEnv.isUBO(field);
    evdConfig.isPointer              = mSymbolEnv.isPointer(field);
    evdConfig.isReference            = mSymbolEnv.isReference(field);
    emitVariableDeclaration(VarDecl(field), evdConfig);

    const TQualifier qual = type.getQualifier();
    switch (qual)
    {
        // Flat inputs are annotated and flagged in reflection.
        case TQualifier::EvqFlatIn:
            if (mPipelineStructs.fragmentIn.external == &parent)
            {
                mOut << " [[flat]]";
                TranslatorMetalReflection *reflection =
                    mtl::getTranslatorMetalReflection(&mCompiler);
                reflection->hasFlatInput = true;
            }
            break;

        // Interpolation attributes for non-interpolant fragment inputs.
        case TQualifier::EvqNoPerspectiveIn:
            if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
            {
                mOut << " [[center_no_perspective]]";
            }
            break;

        case TQualifier::EvqCentroidIn:
            if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
            {
                mOut << " [[centroid_perspective]]";
            }
            break;

        case TQualifier::EvqSampleIn:
            if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
            {
                mOut << " [[sample_perspective]]";
            }
            break;

        case TQualifier::EvqNoPerspectiveCentroidIn:
            if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
            {
                mOut << " [[centroid_no_perspective]]";
            }
            break;

        case TQualifier::EvqNoPerspectiveSampleIn:
            if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
            {
                mOut << " [[sample_no_perspective]]";
            }
            break;

        case TQualifier::EvqFragColor:
            mOut << " [[color(0)]]";
            break;

        // Dual-source blending: second color output at index 1 of attachment 0.
        case TQualifier::EvqSecondaryFragColorEXT:
        case TQualifier::EvqSecondaryFragDataEXT:
            mOut << " [[color(0), index(1)]]";
            break;

        case TQualifier::EvqFragmentOut:
        case TQualifier::EvqFragmentInOut:
        case TQualifier::EvqFragData:
            if (mPipelineStructs.fragmentOut.external == &parent ||
                mPipelineStructs.fragmentOut.externalExtra == &parent)
            {
                // NOTE(review): non-vector user outputs (other than gl_FragData) get no color
                // attribute here; presumably handled elsewhere — confirm.
                if ((type.isVector() &&
                     (basic == TBasicType::EbtInt || basic == TBasicType::EbtUInt ||
                      basic == TBasicType::EbtFloat)) ||
                    qual == EvqFragData)
                {
                    // The OpenGL ES 3.0 spec says locations must be specified
                    // unless there is only a single output, in which case the
                    // location is 0. So, when we get to this point the shader
                    // will have been rejected if locations are not specified
                    // and there is more than one output.
                    const TLayoutQualifier &layoutQualifier = type.getLayoutQualifier();
                    if (layoutQualifier.locationsSpecified)
                    {
                        mOut << " [[color(" << layoutQualifier.location << ")";
                        ASSERT(layoutQualifier.index >= -1 && layoutQualifier.index <= 1);
                        if (layoutQualifier.index == 1)
                        {
                            mOut << ", index(1)";
                        }
                    }
                    else if (qual == EvqFragData)
                    {
                        // gl_FragData elements take consecutive color attachments.
                        mOut << " [[color(" << annotationIndices.color++ << ")";
                    }
                    else
                    {
                        // Either the only output or EXT_blend_func_extended is used;
                        // actual assignment will happen in UpdateFragmentShaderOutputs.
                        mOut << " [[" << sh::kUnassignedFragmentOutputString;
                    }
                    if (mRasterOrderGroupsSupported && qual == TQualifier::EvqFragmentInOut)
                    {
                        // Put fragment inouts in their own raster order group for better
                        // parallelism.
                        // NOTE: this is not required for the reads to be ordered and coherent.
                        // TODO(anglebug.com/7279): Consider making raster order groups a PLS layout
                        //                          qualifier?
                        mOut << ", raster_order_group(0)";
                    }
                    mOut << "]]";
                }
            }
            break;

        // Depth output; the layout's depth qualifier selects the Metal depth argument, and the
        // write is gated on a function constant.
        case TQualifier::EvqFragDepth:
            mOut << " [[depth(";
            switch (type.getLayoutQualifier().depth)
            {
                case EdGreater:
                    mOut << "greater";
                    break;
                case EdLess:
                    mOut << "less";
                    break;
                default:
                    mOut << "any";
                    break;
            }
            mOut << "), function_constant(" << sh::mtl::kDepthWriteEnabledConstName << ")]]";
            break;

        // Only the ANGLE-internal sample mask member is bound to [[sample_mask]].
        case TQualifier::EvqSampleMask:
            if (field.symbolType() == SymbolType::AngleInternal)
            {
                mOut << " [[sample_mask, function_constant("
                     << sh::mtl::kSampleMaskWriteEnabledConstName << ")]]";
            }
            break;

        default:
            break;
    }

    // Image members are bound to the next free texture slot; arrays of images are not supported.
    if (IsImage(type.getBasicType()))
    {
        if (type.getArraySizeProduct() != 1)
        {
            UNIMPLEMENTED();
        }
        mOut << " [[";
        if (type.getLayoutQualifier().rasterOrdered)
        {
            mOut << "raster_order_group(0), ";
        }
        mOut << "texture(" << mMainTextureIndex << ")]]";
        TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
        reflection->addRWTextureBinding(type.getLayoutQualifier().binding,
                                        static_cast<int>(mMainTextureIndex));
        ++mMainTextureIndex;
    }
}
1314
// Maps each field name of the external vertex-input struct to a Metal vertex attribute index.
//
// Walks three sequences in lockstep:
//   - the compiler's attribute list (shaderVars);
//   - the internal vertex-input struct fields;
//   - the external vertex-input struct fields.
// Attributes in shaderVars that don't match the current internal field are treated as inactive;
// they are skipped but still consume an attribute index.
// When an internal matrix field was split into per-column external fields (the external field is
// not a matrix), each column gets its own consecutive attribute index.
// External fields that are never visited (padding) get no entry in the returned map.
//
// NOTE(review): indexing into shaderVars and externalFields is only bounds-checked by ASSERT, so
// release builds rely on every internal field having a matching attribute — confirm upstream.
static std::map<Name, size_t> BuildExternalAttributeIndexMap(
    const TCompiler &compiler,
    const PipelineScoped<TStructure> &structure)
{
    ASSERT(structure.isTotallyFull());

    const auto &shaderVars     = compiler.getAttributes();
    const size_t shaderVarSize = shaderVars.size();
    size_t shaderVarIndex      = 0;

    const auto &externalFields = structure.external->fields();
    const size_t externalSize  = externalFields.size();
    size_t externalIndex       = 0;

    const auto &internalFields = structure.internal->fields();
    const size_t internalSize  = internalFields.size();
    size_t internalIndex       = 0;

    // Internal fields are never split. External fields are sometimes split.
    ASSERT(externalSize >= internalSize);

    // Structures do not contain any inactive fields.
    ASSERT(shaderVarSize >= internalSize);

    std::map<Name, size_t> externalNameToAttributeIndex;
    size_t attributeIndex = 0;

    while (internalIndex < internalSize)
    {
        const TField &internalField = *internalFields[internalIndex];
        const Name internalName     = Name(internalField);
        const TType &internalType   = *internalField.type();
        // Advance through shaderVars (by original or mapped name) until the current internal
        // field is found.
        while (internalName.rawName() != shaderVars[shaderVarIndex].name &&
               internalName.rawName() != shaderVars[shaderVarIndex].mappedName)
        {
            // This case represents an inactive field.

            ++shaderVarIndex;
            ASSERT(shaderVarIndex < shaderVarSize);

            ++attributeIndex;  // TODO: Might need to increment more if shader var type is a matrix.
        }

        // A matrix whose external counterpart is not a matrix was split into one external field
        // per column.
        const size_t cols =
            (internalType.isMatrix() && !externalFields[externalIndex]->type()->isMatrix())
                ? internalType.getCols()
                : 1;

        for (size_t c = 0; c < cols; ++c)
        {
            const TField &externalField = *externalFields[externalIndex];
            const Name externalName     = Name(externalField);

            externalNameToAttributeIndex[externalName] = attributeIndex;

            ++externalIndex;
            ++attributeIndex;
        }

        ++shaderVarIndex;
        ++internalIndex;
    }

    ASSERT(shaderVarIndex <= shaderVarSize);
    ASSERT(externalIndex <= externalSize);  // less than if padding was introduced
    ASSERT(internalIndex == internalSize);

    return externalNameToAttributeIndex;
}
1384
emitAttributeDeclaration(const TField & field,FieldAnnotationIndices & annotationIndices)1385 void GenMetalTraverser::emitAttributeDeclaration(const TField &field,
1386 FieldAnnotationIndices &annotationIndices)
1387 {
1388 EmitVariableDeclarationConfig evdConfig;
1389 evdConfig.disableStructSpecifier = true;
1390 emitVariableDeclaration(VarDecl(field), evdConfig);
1391 mOut << sh::kUnassignedAttributeString;
1392 }
1393
emitUniformBufferDeclaration(const TField & field,FieldAnnotationIndices & annotationIndices)1394 void GenMetalTraverser::emitUniformBufferDeclaration(const TField &field,
1395 FieldAnnotationIndices &annotationIndices)
1396 {
1397 EmitVariableDeclarationConfig evdConfig;
1398 evdConfig.disableStructSpecifier = true;
1399 evdConfig.isUBO = mSymbolEnv.isUBO(field);
1400 evdConfig.isPointer = mSymbolEnv.isPointer(field);
1401 emitVariableDeclaration(VarDecl(field), evdConfig);
1402 mOut << "[[id(" << annotationIndices.attribute << ")]]";
1403
1404 const TType &type = *field.type();
1405 const int arraySize = type.isArray() ? type.getArraySizeProduct() : 1;
1406
1407 TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
1408 ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1409 const TStructure *structure = type.getStruct();
1410 const std::string originalName = reflection->getOriginalName(structure->uniqueId().get());
1411 reflection->addUniformBufferBinding(
1412 originalName,
1413 {.bindIndex = annotationIndices.attribute, .arraySize = static_cast<size_t>(arraySize)});
1414
1415 annotationIndices.attribute += arraySize;
1416 }
1417
emitStructDeclaration(const TType & type)1418 void GenMetalTraverser::emitStructDeclaration(const TType &type)
1419 {
1420 ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1421 ASSERT(type.isStructSpecifier());
1422
1423 mOut << "struct ";
1424 emitBareTypeName(type, {});
1425
1426 mOut << "\n";
1427 emitOpenBrace();
1428
1429 const TStructure &structure = *type.getStruct();
1430 std::map<Name, size_t> fieldToAttributeIndex;
1431 const bool hasAttributeIndices = mPipelineStructs.vertexIn.external == &structure;
1432 const bool hasUniformBufferIndicies = mPipelineStructs.uniformBuffers.external == &structure;
1433 const bool reclaimUnusedAttributeIndices = mCompiler.getShaderVersion() < 300;
1434
1435 if (hasAttributeIndices)
1436 {
1437 // When attribute aliasing is supported, external attribute struct is filled post-link.
1438 if (mCompiler.supportsAttributeAliasing())
1439 {
1440 mtl::getTranslatorMetalReflection(&mCompiler)->hasAttributeAliasing = true;
1441 mOut << "@@Attrib-Bindings@@\n";
1442 emitCloseBrace();
1443 return;
1444 }
1445
1446 fieldToAttributeIndex =
1447 BuildExternalAttributeIndexMap(mCompiler, mPipelineStructs.vertexIn);
1448 }
1449
1450 FieldAnnotationIndices annotationIndices;
1451
1452 for (const TField *field : structure.fields())
1453 {
1454 emitIndentation();
1455 if (hasAttributeIndices)
1456 {
1457 const auto it = fieldToAttributeIndex.find(Name(*field));
1458 if (it == fieldToAttributeIndex.end())
1459 {
1460 ASSERT(field->symbolType() == SymbolType::AngleInternal);
1461 ASSERT(field->name().beginsWith("_"));
1462 ASSERT(angle::EndsWith(field->name().data(), "_pad"));
1463 emitFieldDeclaration(*field, structure, annotationIndices);
1464 }
1465 else
1466 {
1467 ASSERT(field->symbolType() != SymbolType::AngleInternal ||
1468 !field->name().beginsWith("_") ||
1469 !angle::EndsWith(field->name().data(), "_pad"));
1470 if (!reclaimUnusedAttributeIndices)
1471 {
1472 annotationIndices.attribute = it->second;
1473 }
1474 emitAttributeDeclaration(*field, annotationIndices);
1475 }
1476 }
1477 else if (hasUniformBufferIndicies)
1478 {
1479 emitUniformBufferDeclaration(*field, annotationIndices);
1480 }
1481 else
1482 {
1483 emitFieldDeclaration(*field, structure, annotationIndices);
1484 }
1485 mOut << ";\n";
1486 }
1487
1488 emitCloseBrace();
1489 }
1490
emitOrdinaryVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1491 void GenMetalTraverser::emitOrdinaryVariableDeclaration(
1492 const VarDecl &decl,
1493 const EmitVariableDeclarationConfig &evdConfig)
1494 {
1495 EmitTypeConfig etConfig;
1496 etConfig.evdConfig = &evdConfig;
1497
1498 const TType &type = decl.type();
1499 if (type.getQualifier() == TQualifier::EvqClipDistance)
1500 {
1501 // Clip distance output uses float[n] type instead of metal::array.
1502 // The element count is emitted after the post qualifier.
1503 ASSERT(type.getBasicType() == TBasicType::EbtFloat);
1504 mOut << "float";
1505 }
1506 else if (type.getQualifier() == TQualifier::EvqSampleID && evdConfig.isMainParameter)
1507 {
1508 // Metal's [[sample_id]] must be unsigned
1509 ASSERT(type.getBasicType() == TBasicType::EbtInt);
1510 mOut << "uint32_t";
1511 }
1512 else
1513 {
1514 emitType(type, etConfig);
1515 }
1516 if (decl.symbolType() != SymbolType::Empty)
1517 {
1518 mOut << " ";
1519 emitNameOf(decl);
1520 }
1521 }
1522
emitVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1523 void GenMetalTraverser::emitVariableDeclaration(const VarDecl &decl,
1524 const EmitVariableDeclarationConfig &evdConfig)
1525 {
1526 const SymbolType symbolType = decl.symbolType();
1527 const TType &type = decl.type();
1528 const TBasicType basicType = type.getBasicType();
1529
1530 switch (basicType)
1531 {
1532 case TBasicType::EbtStruct:
1533 {
1534 if (type.isStructSpecifier() && !evdConfig.disableStructSpecifier)
1535 {
1536 // It's invalid to declare a struct inside a function argument. When emitting a
1537 // function parameter, the callsite should set evdConfig.disableStructSpecifier.
1538 ASSERT(!evdConfig.isParameter);
1539 emitStructDeclaration(type);
1540 if (symbolType != SymbolType::Empty)
1541 {
1542 mOut << " ";
1543 emitNameOf(decl);
1544 }
1545 }
1546 else
1547 {
1548 emitOrdinaryVariableDeclaration(decl, evdConfig);
1549 }
1550 }
1551 break;
1552
1553 default:
1554 {
1555 ASSERT(symbolType != SymbolType::Empty || evdConfig.isParameter);
1556 emitOrdinaryVariableDeclaration(decl, evdConfig);
1557 }
1558 }
1559
1560 if (evdConfig.emitPostQualifier)
1561 {
1562 emitPostQualifier(evdConfig, decl, type.getQualifier());
1563 }
1564 }
1565
visitSymbol(TIntermSymbol * symbolNode)1566 void GenMetalTraverser::visitSymbol(TIntermSymbol *symbolNode)
1567 {
1568 const TVariable &var = symbolNode->variable();
1569 const TType &type = var.getType();
1570 ASSERT(var.symbolType() != SymbolType::Empty);
1571
1572 if (type.getBasicType() == TBasicType::EbtVoid)
1573 {
1574 mOut << "/*";
1575 emitNameOf(var);
1576 mOut << "*/";
1577 }
1578 else
1579 {
1580 emitNameOf(var);
1581 }
1582 }
1583
emitSingleConstant(const TConstantUnion * const constUnion)1584 void GenMetalTraverser::emitSingleConstant(const TConstantUnion *const constUnion)
1585 {
1586 switch (constUnion->getType())
1587 {
1588 case TBasicType::EbtBool:
1589 {
1590 mOut << (constUnion->getBConst() ? "true" : "false");
1591 }
1592 break;
1593
1594 case TBasicType::EbtFloat:
1595 {
1596 float value = constUnion->getFConst();
1597 if (std::isnan(value))
1598 {
1599 mOut << "NAN";
1600 }
1601 else if (std::isinf(value))
1602 {
1603 if (value < 0)
1604 {
1605 mOut << "-";
1606 }
1607 mOut << "INFINITY";
1608 }
1609 else
1610 {
1611 mOut << value << "f";
1612 }
1613 }
1614 break;
1615
1616 case TBasicType::EbtInt:
1617 {
1618 mOut << constUnion->getIConst();
1619 }
1620 break;
1621
1622 case TBasicType::EbtUInt:
1623 {
1624 mOut << constUnion->getUConst() << "u";
1625 }
1626 break;
1627
1628 default:
1629 {
1630 UNIMPLEMENTED();
1631 }
1632 }
1633 }
1634
// Emits `size` consecutive scalar constants starting at constUnion as a ", "-separated list.
// Returns a pointer just past the last constant consumed.
const TConstantUnion *GenMetalTraverser::emitConstantUnionArray(
    const TConstantUnion *const constUnion,
    const size_t size)
{
    const TConstantUnion *const end = constUnion + size;
    for (const TConstantUnion *cursor = constUnion; cursor != end; ++cursor)
    {
        if (cursor != constUnion)
        {
            mOut << ", ";
        }
        emitSingleConstant(cursor);
    }
    return end;
}
1651
// Emits a constant of `type` whose scalar components start at constUnionBegin. Returns a pointer
// just past the components consumed. Forms:
//   - structs: "T{field0, field1, ...}", recursing per field;
//   - multi-component values: "T(c0, c1, ...)";
//   - single scalars: the bare literal.
// NOTE(review): for an array-of-struct type the field loop runs only once, so only one element's
// worth of components is emitted. Confirm such constants never reach this emitter.
const TConstantUnion *GenMetalTraverser::emitConstantUnion(const TType &type,
                                                           const TConstantUnion *constUnionBegin)
{
    const TStructure *const structure = type.getStruct();
    if (structure == nullptr)
    {
        const size_t componentCount = type.getObjectSize();
        if (componentCount <= 1)
        {
            return emitConstantUnionArray(constUnionBegin, componentCount);
        }
        emitType(type, EmitTypeConfig{nullptr});
        mOut << "(";
        const TConstantUnion *const next = emitConstantUnionArray(constUnionBegin, componentCount);
        mOut << ")";
        return next;
    }

    emitType(type, EmitTypeConfig{nullptr});
    mOut << "{";
    const TFieldList &fields     = structure->fields();
    const TConstantUnion *cursor = constUnionBegin;
    for (size_t i = 0; i < fields.size(); ++i)
    {
        if (i != 0)
        {
            mOut << ", ";
        }
        cursor = emitConstantUnion(*fields[i]->type(), cursor);
    }
    mOut << "}";
    return cursor;
}
1692
visitConstantUnion(TIntermConstantUnion * constValueNode)1693 void GenMetalTraverser::visitConstantUnion(TIntermConstantUnion *constValueNode)
1694 {
1695 emitConstantUnion(constValueNode->getType(), constValueNode->getConstantValue());
1696 }
1697
visitSwizzle(Visit,TIntermSwizzle * swizzleNode)1698 bool GenMetalTraverser::visitSwizzle(Visit, TIntermSwizzle *swizzleNode)
1699 {
1700 groupedTraverse(*swizzleNode->getOperand());
1701 mOut << ".";
1702
1703 {
1704 #if defined(ANGLE_ENABLE_ASSERTS)
1705 DebugSink::EscapedSink escapedOut(mOut.escape());
1706 TInfoSinkBase &out = escapedOut.get();
1707 #else
1708 TInfoSinkBase &out = mOut;
1709 #endif
1710 swizzleNode->writeOffsetsAsXYZW(&out);
1711 }
1712
1713 return false;
1714 }
1715
getDirectField(const TFieldListCollection & fieldListCollection,const TConstantUnion & index)1716 const TField &GenMetalTraverser::getDirectField(const TFieldListCollection &fieldListCollection,
1717 const TConstantUnion &index)
1718 {
1719 ASSERT(index.getType() == TBasicType::EbtInt);
1720
1721 const TFieldList &fieldList = fieldListCollection.fields();
1722 const int indexVal = index.getIConst();
1723 const TField &field = *fieldList[indexVal];
1724
1725 return field;
1726 }
1727
getDirectField(const TIntermTyped & fieldsNode,TIntermTyped & indexNode)1728 const TField &GenMetalTraverser::getDirectField(const TIntermTyped &fieldsNode,
1729 TIntermTyped &indexNode)
1730 {
1731 const TType &fieldsType = fieldsNode.getType();
1732
1733 const TFieldListCollection *fieldListCollection = fieldsType.getStruct();
1734 if (fieldListCollection == nullptr)
1735 {
1736 fieldListCollection = fieldsType.getInterfaceBlock();
1737 }
1738 ASSERT(fieldListCollection);
1739
1740 const TIntermConstantUnion *indexNode_ = indexNode.getAsConstantUnion();
1741 ASSERT(indexNode_);
1742 const TConstantUnion &index = *indexNode_->getConstantValue();
1743
1744 return getDirectField(*fieldListCollection, index);
1745 }
1746
visitBinary(Visit,TIntermBinary * binaryNode)1747 bool GenMetalTraverser::visitBinary(Visit, TIntermBinary *binaryNode)
1748 {
1749 const TOperator op = binaryNode->getOp();
1750 TIntermTyped &leftNode = *binaryNode->getLeft();
1751 TIntermTyped &rightNode = *binaryNode->getRight();
1752
1753 switch (op)
1754 {
1755 case TOperator::EOpIndexDirectStruct:
1756 case TOperator::EOpIndexDirectInterfaceBlock:
1757 {
1758 const TField &field = getDirectField(leftNode, rightNode);
1759 if (mSymbolEnv.isPointer(field) && mSymbolEnv.isUBO(field))
1760 {
1761 emitOpeningPointerParen();
1762 }
1763 groupedTraverse(leftNode);
1764 if (!mSymbolEnv.isPointer(field))
1765 {
1766 emitClosingPointerParen();
1767 }
1768 mOut << ".";
1769 emitNameOf(field);
1770 }
1771 break;
1772
1773 case TOperator::EOpIndexDirect:
1774 case TOperator::EOpIndexIndirect:
1775 {
1776 TType leftType = leftNode.getType();
1777 groupedTraverse(leftNode);
1778 mOut << "[";
1779 const TConstantUnion *constIndex = rightNode.getConstantValue();
1780 // TODO(anglebug.com/8491): Convert type and bound checks to
1781 // assertions after AST validation is enabled for MSL translation.
1782 if (!leftType.isUnsizedArray() && constIndex != nullptr &&
1783 constIndex->getType() == EbtInt && constIndex->getIConst() >= 0 &&
1784 constIndex->getIConst() < static_cast<int>(leftType.isArray()
1785 ? leftType.getOutermostArraySize()
1786 : leftType.getNominalSize()))
1787 {
1788 emitSingleConstant(constIndex);
1789 }
1790 else
1791 {
1792 mOut << "ANGLE_int_clamp(";
1793 groupedTraverse(rightNode);
1794 mOut << ", 0, ";
1795 if (leftType.isUnsizedArray())
1796 {
1797 groupedTraverse(leftNode);
1798 mOut << ".size()";
1799 }
1800 else
1801 {
1802 uint32_t maxSize;
1803 if (leftType.isArray())
1804 {
1805 maxSize = leftType.getOutermostArraySize() - 1;
1806 }
1807 else
1808 {
1809 maxSize = leftType.getNominalSize() - 1;
1810 }
1811 mOut << maxSize;
1812 }
1813 mOut << ")";
1814 }
1815 mOut << "]";
1816 }
1817 break;
1818
1819 default:
1820 {
1821 const TType &resultType = binaryNode->getType();
1822 const TType &leftType = leftNode.getType();
1823 const TType &rightType = rightNode.getType();
1824
1825 if (IsSymbolicOperator(op, resultType, &leftType, &rightType))
1826 {
1827 groupedTraverse(leftNode);
1828 if (op != TOperator::EOpComma)
1829 {
1830 mOut << " ";
1831 }
1832 else
1833 {
1834 emitClosingPointerParen();
1835 }
1836 mOut << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << " ";
1837 groupedTraverse(rightNode);
1838 }
1839 else
1840 {
1841 emitClosingPointerParen();
1842 mOut << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << "(";
1843 leftNode.traverse(this);
1844 mOut << ", ";
1845 rightNode.traverse(this);
1846 mOut << ")";
1847 }
1848 }
1849 }
1850
1851 return false;
1852 }
1853
IsPostfix(TOperator op)1854 static bool IsPostfix(TOperator op)
1855 {
1856 switch (op)
1857 {
1858 case TOperator::EOpPostIncrement:
1859 case TOperator::EOpPostDecrement:
1860 return true;
1861
1862 default:
1863 return false;
1864 }
1865 }
1866
visitUnary(Visit,TIntermUnary * unaryNode)1867 bool GenMetalTraverser::visitUnary(Visit, TIntermUnary *unaryNode)
1868 {
1869 const TOperator op = unaryNode->getOp();
1870 const TType &resultType = unaryNode->getType();
1871
1872 TIntermTyped &arg = *unaryNode->getOperand();
1873 const TType &argType = arg.getType();
1874
1875 const char *name = GetOperatorString(op, resultType, &argType, nullptr, nullptr);
1876
1877 if (IsSymbolicOperator(op, resultType, &argType, nullptr))
1878 {
1879 const bool postfix = IsPostfix(op);
1880 if (!postfix)
1881 {
1882 mOut << name;
1883 }
1884 groupedTraverse(arg);
1885 if (postfix)
1886 {
1887 mOut << name;
1888 }
1889 }
1890 else
1891 {
1892 mOut << name << "(";
1893 arg.traverse(this);
1894 mOut << ")";
1895 }
1896
1897 return false;
1898 }
1899
visitTernary(Visit,TIntermTernary * conditionalNode)1900 bool GenMetalTraverser::visitTernary(Visit, TIntermTernary *conditionalNode)
1901 {
1902 groupedTraverse(*conditionalNode->getCondition());
1903 mOut << " ? ";
1904 groupedTraverse(*conditionalNode->getTrueExpression());
1905 mOut << " : ";
1906 groupedTraverse(*conditionalNode->getFalseExpression());
1907
1908 return false;
1909 }
1910
visitIfElse(Visit,TIntermIfElse * ifThenElseNode)1911 bool GenMetalTraverser::visitIfElse(Visit, TIntermIfElse *ifThenElseNode)
1912 {
1913 TIntermTyped &condNode = *ifThenElseNode->getCondition();
1914 TIntermBlock *thenNode = ifThenElseNode->getTrueBlock();
1915 TIntermBlock *elseNode = ifThenElseNode->getFalseBlock();
1916
1917 emitIndentation();
1918 mOut << "if (";
1919 condNode.traverse(this);
1920 mOut << ")";
1921
1922 if (thenNode)
1923 {
1924 mOut << "\n";
1925 thenNode->traverse(this);
1926 }
1927 else
1928 {
1929 mOut << " {}";
1930 }
1931
1932 if (elseNode)
1933 {
1934 mOut << "\n";
1935 emitIndentation();
1936 mOut << "else\n";
1937 elseNode->traverse(this);
1938 }
1939 else
1940 {
1941 // Always emit "else" even when empty block to avoid nested if-stmt issues.
1942 mOut << " else {}";
1943 }
1944
1945 return false;
1946 }
1947
visitSwitch(Visit,TIntermSwitch * switchNode)1948 bool GenMetalTraverser::visitSwitch(Visit, TIntermSwitch *switchNode)
1949 {
1950 emitIndentation();
1951 mOut << "switch (";
1952 switchNode->getInit()->traverse(this);
1953 mOut << ")\n";
1954
1955 ASSERT(!mParentIsSwitch);
1956 mParentIsSwitch = true;
1957 switchNode->getStatementList()->traverse(this);
1958 mParentIsSwitch = false;
1959
1960 return false;
1961 }
1962
visitCase(Visit,TIntermCase * caseNode)1963 bool GenMetalTraverser::visitCase(Visit, TIntermCase *caseNode)
1964 {
1965 emitIndentation();
1966
1967 if (caseNode->hasCondition())
1968 {
1969 TIntermTyped *condExpr = caseNode->getCondition();
1970 mOut << "case ";
1971 condExpr->traverse(this);
1972 mOut << ":";
1973 }
1974 else
1975 {
1976 mOut << "default:\n";
1977 }
1978
1979 return false;
1980 }
1981
emitFunctionSignature(const TFunction & func)1982 void GenMetalTraverser::emitFunctionSignature(const TFunction &func)
1983 {
1984 const bool isMain = func.isMain();
1985
1986 emitFunctionReturn(func);
1987
1988 mOut << " ";
1989 emitNameOf(func);
1990 if (isMain)
1991 {
1992 mOut << "0";
1993 }
1994 mOut << "(";
1995
1996 bool emitComma = false;
1997 const size_t paramCount = func.getParamCount();
1998 for (size_t i = 0; i < paramCount; ++i)
1999 {
2000 if (emitComma)
2001 {
2002 mOut << ", ";
2003 }
2004 emitComma = true;
2005
2006 const TVariable ¶m = *func.getParam(i);
2007 emitFunctionParameter(func, param);
2008 }
2009
2010 if (isTraversingVertexMain)
2011 {
2012 mOut << " @@XFB-Bindings@@ ";
2013 }
2014
2015 mOut << ")";
2016 }
2017
emitFunctionReturn(const TFunction & func)2018 void GenMetalTraverser::emitFunctionReturn(const TFunction &func)
2019 {
2020 const bool isMain = func.isMain();
2021 bool isVertexMain = false;
2022 const TType &returnType = func.getReturnType();
2023 if (isMain)
2024 {
2025 const TStructure *structure = returnType.getStruct();
2026 if (mPipelineStructs.fragmentOut.matches(*structure))
2027 {
2028 if (mCompiler.isEarlyFragmentTestsSpecified())
2029 {
2030 mOut << "[[early_fragment_tests]]\n";
2031 }
2032 mOut << "fragment ";
2033 }
2034 else if (mPipelineStructs.vertexOut.matches(*structure))
2035 {
2036 ASSERT(structure != nullptr);
2037 mOut << "vertex __VERTEX_OUT(";
2038 isVertexMain = true;
2039 }
2040 else
2041 {
2042 UNIMPLEMENTED();
2043 }
2044 }
2045 emitType(returnType, EmitTypeConfig());
2046 if (isVertexMain)
2047 mOut << ") ";
2048 }
2049
emitFunctionParameter(const TFunction & func,const TVariable & param)2050 void GenMetalTraverser::emitFunctionParameter(const TFunction &func, const TVariable ¶m)
2051 {
2052 const bool isMain = func.isMain();
2053
2054 const TType &type = param.getType();
2055 const TStructure *structure = type.getStruct();
2056
2057 EmitVariableDeclarationConfig evdConfig;
2058 evdConfig.isParameter = true;
2059 evdConfig.disableStructSpecifier = true; // It's invalid to declare a struct in a function arg.
2060 evdConfig.isMainParameter = isMain;
2061 evdConfig.emitPostQualifier = isMain;
2062 evdConfig.isUBO = mSymbolEnv.isUBO(param);
2063 evdConfig.isPointer = mSymbolEnv.isPointer(param);
2064 evdConfig.isReference = mSymbolEnv.isReference(param);
2065 emitVariableDeclaration(VarDecl(param), evdConfig);
2066
2067 if (isMain)
2068 {
2069 TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
2070 if (structure)
2071 {
2072 if (mPipelineStructs.fragmentIn.matches(*structure) ||
2073 mPipelineStructs.vertexIn.matches(*structure))
2074 {
2075 mOut << " [[stage_in]]";
2076 }
2077 else if (mPipelineStructs.angleUniforms.matches(*structure))
2078 {
2079 mOut << " [[buffer(" << mDriverUniformsBindingIndex << ")]]";
2080 }
2081 else if (mPipelineStructs.uniformBuffers.matches(*structure))
2082 {
2083 mOut << " [[buffer(" << mUBOArgumentBufferBindingIndex << ")]]";
2084 reflection->hasUBOs = true;
2085 }
2086 else if (mPipelineStructs.userUniforms.matches(*structure))
2087 {
2088 mOut << " [[buffer(" << mMainUniformBufferIndex << ")]]";
2089 reflection->addUserUniformBufferBinding(param.name().data(),
2090 mMainUniformBufferIndex);
2091 mMainUniformBufferIndex += type.getArraySizeProduct();
2092 }
2093 else if (structure->name() == "metal::sampler")
2094 {
2095 mOut << " [[sampler(" << (mMainSamplerIndex) << ")]]";
2096 const std::string originalName =
2097 reflection->getOriginalName(param.uniqueId().get());
2098 reflection->addSamplerBinding(originalName, mMainSamplerIndex);
2099 mMainSamplerIndex += type.getArraySizeProduct();
2100 }
2101 }
2102 else if (IsSampler(type.getBasicType()))
2103 {
2104 mOut << " [[texture(" << (mMainTextureIndex) << ")]]";
2105 const std::string originalName = reflection->getOriginalName(param.uniqueId().get());
2106 reflection->addTextureBinding(originalName, mMainTextureIndex);
2107 mMainTextureIndex += type.getArraySizeProduct();
2108 }
2109 else if (Name(param) == Pipeline{Pipeline::Type::InstanceId, nullptr}.getStructInstanceName(
2110 Pipeline::Variant::Modified))
2111 {
2112 mOut << " [[instance_id]]";
2113 }
2114 else if (Name(param) == kBaseInstanceName)
2115 {
2116 mOut << " [[base_instance]]";
2117 }
2118 }
2119 }
2120
visitFunctionPrototype(TIntermFunctionPrototype * funcProtoNode)2121 void GenMetalTraverser::visitFunctionPrototype(TIntermFunctionPrototype *funcProtoNode)
2122 {
2123 const TFunction &func = *funcProtoNode->getFunction();
2124
2125 emitIndentation();
2126 emitFunctionSignature(func);
2127 }
2128
visitFunctionDefinition(Visit,TIntermFunctionDefinition * funcDefNode)2129 bool GenMetalTraverser::visitFunctionDefinition(Visit, TIntermFunctionDefinition *funcDefNode)
2130 {
2131 const TFunction &func = *funcDefNode->getFunction();
2132 TIntermBlock &body = *funcDefNode->getBody();
2133 if (func.isMain())
2134 {
2135 const TType &returnType = func.getReturnType();
2136 const TStructure *structure = returnType.getStruct();
2137 isTraversingVertexMain = (mPipelineStructs.vertexOut.matches(*structure)) &&
2138 mCompiler.getShaderType() == GL_VERTEX_SHADER;
2139 }
2140 emitIndentation();
2141 emitFunctionSignature(func);
2142 mOut << "\n";
2143 body.traverse(this);
2144 if (isTraversingVertexMain)
2145 {
2146 isTraversingVertexMain = false;
2147 }
2148 return false;
2149 }
2150
BuildFuncToName()2151 GenMetalTraverser::FuncToName GenMetalTraverser::BuildFuncToName()
2152 {
2153 FuncToName map;
2154
2155 auto putAngle = [&](const char *nameStr) {
2156 const ImmutableString name(nameStr);
2157 ASSERT(map.find(name) == map.end());
2158 map[name] = Name(nameStr, SymbolType::AngleInternal);
2159 };
2160
2161 putAngle("texelFetch");
2162 putAngle("texelFetchOffset");
2163 putAngle("texture");
2164 putAngle("texture1D");
2165 putAngle("texture1DLod");
2166 putAngle("texture1DProjLod");
2167 putAngle("texture2D");
2168 putAngle("texture2DGradEXT");
2169 putAngle("texture2DLod");
2170 putAngle("texture2DLodEXT");
2171 putAngle("texture2DProj");
2172 putAngle("texture2DProjGradEXT");
2173 putAngle("texture2DProjLod");
2174 putAngle("texture2DProjLodEXT");
2175 putAngle("texture2DRect");
2176 putAngle("texture2DRectProj");
2177 putAngle("texture3D");
2178 putAngle("texture3DLod");
2179 putAngle("texture3DProjLod");
2180 putAngle("textureCube");
2181 putAngle("textureCubeGradEXT");
2182 putAngle("textureCubeLod");
2183 putAngle("textureCubeLodEXT");
2184 putAngle("textureCubeProjLod");
2185 putAngle("textureGrad");
2186 putAngle("textureGradOffset");
2187 putAngle("textureLod");
2188 putAngle("textureLodOffset");
2189 putAngle("textureOffset");
2190 putAngle("textureProj");
2191 putAngle("textureProjGrad");
2192 putAngle("textureProjGradOffset");
2193 putAngle("textureProjLod");
2194 putAngle("textureProjLodOffset");
2195 putAngle("textureProjOffset");
2196 putAngle("textureSize");
2197 putAngle("imageLoad");
2198 putAngle("imageStore");
2199 putAngle("memoryBarrierImage");
2200
2201 return map;
2202 }
2203
visitAggregate(Visit,TIntermAggregate * aggregateNode)2204 bool GenMetalTraverser::visitAggregate(Visit, TIntermAggregate *aggregateNode)
2205 {
2206 const TIntermSequence &args = *aggregateNode->getSequence();
2207
2208 auto emitArgList = [&](const char *open, const char *close) {
2209 mOut << open;
2210
2211 bool emitComma = false;
2212 for (TIntermNode *arg : args)
2213 {
2214 if (emitComma)
2215 {
2216 emitClosingPointerParen();
2217 mOut << ", ";
2218 }
2219 emitComma = true;
2220 arg->traverse(this);
2221 }
2222
2223 mOut << close;
2224 };
2225
2226 const TType &retType = aggregateNode->getType();
2227
2228 if (aggregateNode->isConstructor())
2229 {
2230 const bool isStandalone = getParentNode()->getAsBlock();
2231 if (isStandalone)
2232 {
2233 // Prevent constructor from being interpreted as a declaration by wrapping in parens.
2234 // This can happen if given something like:
2235 // int(symbol); // <- This will be treated like `int symbol;`... don't want that.
2236 // So instead emit:
2237 // (int(symbol));
2238 mOut << "(";
2239 }
2240
2241 const EmitTypeConfig etConfig;
2242
2243 if (retType.isArray())
2244 {
2245 emitType(retType, etConfig);
2246 emitArgList("{", "}");
2247 }
2248 else if (retType.getStruct())
2249 {
2250 emitType(retType, etConfig);
2251 emitArgList("{", "}");
2252 }
2253 else
2254 {
2255 emitType(retType, etConfig);
2256 emitArgList("(", ")");
2257 }
2258
2259 if (isStandalone)
2260 {
2261 mOut << ")";
2262 }
2263
2264 return false;
2265 }
2266 else
2267 {
2268 const TOperator op = aggregateNode->getOp();
2269 switch (op)
2270 {
2271 case TOperator::EOpCallFunctionInAST:
2272 case TOperator::EOpCallInternalRawFunction:
2273 {
2274 const TFunction &func = *aggregateNode->getFunction();
2275 emitNameOf(func);
2276 //'@' symbol in name specifices a macro substitution marker.
2277 if (!func.name().contains("@"))
2278 {
2279 emitArgList("(", ")");
2280 }
2281 else
2282 {
2283 mTemporarilyDisableSemicolon =
2284 true; // Disable semicolon for macro substitution.
2285 }
2286 return false;
2287 }
2288
2289 default:
2290 {
2291 auto getArgType = [&](size_t index) -> const TType * {
2292 if (index < args.size())
2293 {
2294 TIntermTyped *arg = args[index]->getAsTyped();
2295 ASSERT(arg);
2296 return &arg->getType();
2297 }
2298 return nullptr;
2299 };
2300
2301 const TType *argType0 = getArgType(0);
2302 const TType *argType1 = getArgType(1);
2303 const TType *argType2 = getArgType(2);
2304
2305 const char *opName = GetOperatorString(op, retType, argType0, argType1, argType2);
2306
2307 if (IsSymbolicOperator(op, retType, argType0, argType1))
2308 {
2309 switch (args.size())
2310 {
2311 case 1:
2312 {
2313 TIntermNode &operandNode = *aggregateNode->getChildNode(0);
2314 if (IsPostfix(op))
2315 {
2316 mOut << opName;
2317 groupedTraverse(operandNode);
2318 }
2319 else
2320 {
2321 groupedTraverse(operandNode);
2322 mOut << opName;
2323 }
2324 return false;
2325 }
2326
2327 case 2:
2328 {
2329 TIntermNode &leftNode = *aggregateNode->getChildNode(0);
2330 TIntermNode &rightNode = *aggregateNode->getChildNode(1);
2331 groupedTraverse(leftNode);
2332 mOut << " " << opName << " ";
2333 groupedTraverse(rightNode);
2334 return false;
2335 }
2336
2337 default:
2338 UNREACHABLE();
2339 return false;
2340 }
2341 }
2342 else if (opName == nullptr)
2343 {
2344 const TFunction &func = *aggregateNode->getFunction();
2345 auto it = mFuncToName.find(func.name());
2346 ASSERT(it != mFuncToName.end());
2347 EmitName(mOut, it->second);
2348 emitArgList("(", ")");
2349 return false;
2350 }
2351 else
2352 {
2353 mOut << opName;
2354 emitArgList("(", ")");
2355 return false;
2356 }
2357 }
2358 }
2359 }
2360 }
2361
emitOpenBrace()2362 void GenMetalTraverser::emitOpenBrace()
2363 {
2364 ASSERT(mIndentLevel >= 0);
2365
2366 emitIndentation();
2367 mOut << "{\n";
2368 ++mIndentLevel;
2369 }
2370
emitCloseBrace()2371 void GenMetalTraverser::emitCloseBrace()
2372 {
2373 ASSERT(mIndentLevel >= 1);
2374
2375 --mIndentLevel;
2376 emitIndentation();
2377 mOut << "}";
2378 }
2379
RequiresSemicolonTerminator(TIntermNode & node)2380 static bool RequiresSemicolonTerminator(TIntermNode &node)
2381 {
2382 if (node.getAsBlock())
2383 {
2384 return false;
2385 }
2386 if (node.getAsLoopNode())
2387 {
2388 return false;
2389 }
2390 if (node.getAsSwitchNode())
2391 {
2392 return false;
2393 }
2394 if (node.getAsIfElseNode())
2395 {
2396 return false;
2397 }
2398 if (node.getAsFunctionDefinition())
2399 {
2400 return false;
2401 }
2402 if (node.getAsCaseNode())
2403 {
2404 return false;
2405 }
2406
2407 return true;
2408 }
2409
NewlinePad(TIntermNode & node)2410 static bool NewlinePad(TIntermNode &node)
2411 {
2412 if (node.getAsFunctionDefinition())
2413 {
2414 return true;
2415 }
2416 if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
2417 {
2418 ASSERT(declNode->getChildCount() == 1);
2419 TIntermNode &childNode = *declNode->getChildNode(0);
2420 if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
2421 {
2422 const TVariable &var = symbolNode->variable();
2423 return var.getType().isStructSpecifier();
2424 }
2425 return false;
2426 }
2427 return false;
2428 }
2429
visitBlock(Visit,TIntermBlock * blockNode)2430 bool GenMetalTraverser::visitBlock(Visit, TIntermBlock *blockNode)
2431 {
2432 ASSERT(mIndentLevel >= -1);
2433 const bool isGlobalScope = mIndentLevel == -1;
2434 const bool parentIsSwitch = mParentIsSwitch;
2435 mParentIsSwitch = false;
2436
2437 if (isGlobalScope)
2438 {
2439 ++mIndentLevel;
2440 }
2441 else
2442 {
2443 emitOpenBrace();
2444 if (parentIsSwitch)
2445 {
2446 ++mIndentLevel;
2447 }
2448 }
2449
2450 TIntermNode *prevStmtNode = nullptr;
2451
2452 const size_t stmtCount = blockNode->getChildCount();
2453 for (size_t i = 0; i < stmtCount; ++i)
2454 {
2455 TIntermNode &stmtNode = *blockNode->getChildNode(i);
2456
2457 if (isGlobalScope && prevStmtNode && (NewlinePad(*prevStmtNode) || NewlinePad(stmtNode)))
2458 {
2459 mOut << "\n";
2460 }
2461 const bool isCase = stmtNode.getAsCaseNode();
2462 mIndentLevel -= isCase;
2463 emitIndentation();
2464 mIndentLevel += isCase;
2465 stmtNode.traverse(this);
2466 if (RequiresSemicolonTerminator(stmtNode) && !mTemporarilyDisableSemicolon)
2467 {
2468 mOut << ";";
2469 }
2470 mTemporarilyDisableSemicolon = false;
2471 mOut << "\n";
2472
2473 prevStmtNode = &stmtNode;
2474 }
2475
2476 if (isGlobalScope)
2477 {
2478 ASSERT(mIndentLevel == 0);
2479 --mIndentLevel;
2480 }
2481 else
2482 {
2483 if (parentIsSwitch)
2484 {
2485 ASSERT(mIndentLevel >= 1);
2486 --mIndentLevel;
2487 }
2488 emitCloseBrace();
2489 mParentIsSwitch = parentIsSwitch;
2490 }
2491
2492 return false;
2493 }
2494
visitGlobalQualifierDeclaration(Visit,TIntermGlobalQualifierDeclaration *)2495 bool GenMetalTraverser::visitGlobalQualifierDeclaration(Visit, TIntermGlobalQualifierDeclaration *)
2496 {
2497 return false;
2498 }
2499
visitDeclaration(Visit,TIntermDeclaration * declNode)2500 bool GenMetalTraverser::visitDeclaration(Visit, TIntermDeclaration *declNode)
2501 {
2502 ASSERT(declNode->getChildCount() == 1);
2503 TIntermNode &node = *declNode->getChildNode(0);
2504
2505 EmitVariableDeclarationConfig evdConfig;
2506
2507 if (TIntermSymbol *symbolNode = node.getAsSymbolNode())
2508 {
2509 const TVariable &var = symbolNode->variable();
2510 emitVariableDeclaration(VarDecl(var), evdConfig);
2511 if (var.getType().isArray() && var.getType().getQualifier() == EvqTemporary)
2512 {
2513 // The translator frontend injects a loop-based init for user arrays when the source
2514 // shader is using ESSL 1.00. Some Metal drivers may fail to access elements of such
2515 // arrays at runtime depending on the array size. An empty literal initializer added
2516 // to the generated MSL bypasses the issue. The frontend may be further optimized to
2517 // skip the loop-based init when targeting MSL.
2518 mOut << "{}";
2519 }
2520 }
2521 else if (TIntermBinary *initNode = node.getAsBinaryNode())
2522 {
2523 ASSERT(initNode->getOp() == TOperator::EOpInitialize);
2524 TIntermSymbol *leftSymbolNode = initNode->getLeft()->getAsSymbolNode();
2525 TIntermTyped *valueNode = initNode->getRight()->getAsTyped();
2526 ASSERT(leftSymbolNode && valueNode);
2527
2528 if (getRootNode() == getParentBlock())
2529 {
2530 // DeferGlobalInitializers should have turned non-const global initializers into
2531 // deferred initializers. Note that variables marked as EvqGlobal can be treated as
2532 // EvqConst in some ANGLE code but not actually have their qualifier actually changed to
2533 // EvqConst. Thus just assume all EvqGlobal are actually EvqConst for all code run after
2534 // DeferGlobalInitializers.
2535 mOut << "constant ";
2536 }
2537
2538 const TVariable &var = leftSymbolNode->variable();
2539 const Name varName(var);
2540
2541 if (ExpressionContainsName(varName, *valueNode))
2542 {
2543 mRenamedSymbols[&var] = mIdGen.createNewName(varName);
2544 }
2545
2546 emitVariableDeclaration(VarDecl(var), evdConfig);
2547 mOut << " = ";
2548 groupedTraverse(*valueNode);
2549 }
2550 else
2551 {
2552 UNREACHABLE();
2553 }
2554
2555 return false;
2556 }
2557
visitLoop(Visit,TIntermLoop * loopNode)2558 bool GenMetalTraverser::visitLoop(Visit, TIntermLoop *loopNode)
2559 {
2560 const TLoopType loopType = loopNode->getType();
2561
2562 switch (loopType)
2563 {
2564 case TLoopType::ELoopFor:
2565 return visitForLoop(loopNode);
2566 case TLoopType::ELoopWhile:
2567 return visitWhileLoop(loopNode);
2568 case TLoopType::ELoopDoWhile:
2569 return visitDoWhileLoop(loopNode);
2570 }
2571 }
2572
visitForLoop(TIntermLoop * loopNode)2573 bool GenMetalTraverser::visitForLoop(TIntermLoop *loopNode)
2574 {
2575 ASSERT(loopNode->getType() == TLoopType::ELoopFor);
2576
2577 TIntermNode *initNode = loopNode->getInit();
2578 TIntermTyped *condNode = loopNode->getCondition();
2579 TIntermTyped *exprNode = loopNode->getExpression();
2580
2581 mOut << "for (";
2582
2583 if (initNode)
2584 {
2585 initNode->traverse(this);
2586 }
2587 else
2588 {
2589 mOut << " ";
2590 }
2591
2592 mOut << "; ";
2593
2594 if (condNode)
2595 {
2596 condNode->traverse(this);
2597 }
2598
2599 mOut << "; ";
2600
2601 if (exprNode)
2602 {
2603 exprNode->traverse(this);
2604 }
2605
2606 mOut << ")\n";
2607
2608 emitLoopBody(loopNode->getBody());
2609
2610 return false;
2611 }
2612
visitWhileLoop(TIntermLoop * loopNode)2613 bool GenMetalTraverser::visitWhileLoop(TIntermLoop *loopNode)
2614 {
2615 ASSERT(loopNode->getType() == TLoopType::ELoopWhile);
2616
2617 TIntermNode *initNode = loopNode->getInit();
2618 TIntermTyped *condNode = loopNode->getCondition();
2619 TIntermTyped *exprNode = loopNode->getExpression();
2620 ASSERT(condNode);
2621 ASSERT(!initNode && !exprNode);
2622
2623 emitIndentation();
2624 mOut << "while (";
2625 condNode->traverse(this);
2626 mOut << ")\n";
2627 emitLoopBody(loopNode->getBody());
2628
2629 return false;
2630 }
2631
visitDoWhileLoop(TIntermLoop * loopNode)2632 bool GenMetalTraverser::visitDoWhileLoop(TIntermLoop *loopNode)
2633 {
2634 ASSERT(loopNode->getType() == TLoopType::ELoopDoWhile);
2635
2636 TIntermNode *initNode = loopNode->getInit();
2637 TIntermTyped *condNode = loopNode->getCondition();
2638 TIntermTyped *exprNode = loopNode->getExpression();
2639 ASSERT(condNode);
2640 ASSERT(!initNode && !exprNode);
2641
2642 emitIndentation();
2643 mOut << "do\n";
2644 emitLoopBody(loopNode->getBody());
2645 mOut << "\n";
2646 emitIndentation();
2647 mOut << "while (";
2648 condNode->traverse(this);
2649 mOut << ");";
2650
2651 return false;
2652 }
2653
visitBranch(Visit,TIntermBranch * branchNode)2654 bool GenMetalTraverser::visitBranch(Visit, TIntermBranch *branchNode)
2655 {
2656 const TOperator flowOp = branchNode->getFlowOp();
2657 TIntermTyped *exprNode = branchNode->getExpression();
2658
2659 emitIndentation();
2660
2661 switch (flowOp)
2662 {
2663 case TOperator::EOpKill:
2664 {
2665 ASSERT(exprNode == nullptr);
2666 mOut << "metal::discard_fragment()";
2667 }
2668 break;
2669
2670 case TOperator::EOpReturn:
2671 {
2672 if (isTraversingVertexMain)
2673 {
2674 mOut << "#if TRANSFORM_FEEDBACK_ENABLED\n";
2675 emitIndentation();
2676 mOut << "return;\n";
2677 emitIndentation();
2678 mOut << "#else\n";
2679 emitIndentation();
2680 }
2681 mOut << "return";
2682 if (exprNode)
2683 {
2684 mOut << " ";
2685 exprNode->traverse(this);
2686 mOut << ";";
2687 }
2688 if (isTraversingVertexMain)
2689 {
2690 mOut << "\n";
2691 emitIndentation();
2692 mOut << "#endif\n";
2693 mTemporarilyDisableSemicolon = true;
2694 }
2695 }
2696 break;
2697
2698 case TOperator::EOpBreak:
2699 {
2700 ASSERT(exprNode == nullptr);
2701 mOut << "break";
2702 }
2703 break;
2704
2705 case TOperator::EOpContinue:
2706 {
2707 ASSERT(exprNode == nullptr);
2708 mOut << "continue";
2709 }
2710 break;
2711
2712 default:
2713 {
2714 UNREACHABLE();
2715 }
2716 }
2717
2718 return false;
2719 }
2720
2721 static size_t emitMetalCallCount = 0;
2722
EmitMetal(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const PipelineStructs & pipelineStructs,SymbolEnv & symbolEnv,const ProgramPreludeConfig & ppc,const ShCompileOptions & compileOptions)2723 bool sh::EmitMetal(TCompiler &compiler,
2724 TIntermBlock &root,
2725 IdGen &idGen,
2726 const PipelineStructs &pipelineStructs,
2727 SymbolEnv &symbolEnv,
2728 const ProgramPreludeConfig &ppc,
2729 const ShCompileOptions &compileOptions)
2730 {
2731 TInfoSinkBase &out = compiler.getInfoSink().obj;
2732
2733 {
2734 ++emitMetalCallCount;
2735 std::string filenameProto = angle::GetEnvironmentVar("GMD_FIXED_EMIT");
2736 if (!filenameProto.empty())
2737 {
2738 if (filenameProto != "/dev/null")
2739 {
2740 auto tryOpen = [&](char const *ext) {
2741 auto filename = filenameProto;
2742 filename += std::to_string(emitMetalCallCount);
2743 filename += ".";
2744 filename += ext;
2745 return fopen(filename.c_str(), "rb");
2746 };
2747 FILE *file = tryOpen("metal");
2748 if (!file)
2749 {
2750 file = tryOpen("cpp");
2751 }
2752 ASSERT(file);
2753
2754 fseek(file, 0, SEEK_END);
2755 size_t fileSize = ftell(file);
2756 fseek(file, 0, SEEK_SET);
2757
2758 std::vector<char> buff;
2759 buff.resize(fileSize + 1);
2760 fread(buff.data(), fileSize, 1, file);
2761 buff.back() = '\0';
2762
2763 fclose(file);
2764
2765 out << buff.data();
2766 }
2767
2768 return true;
2769 }
2770 }
2771
2772 out << "\n\n";
2773
2774 if (!EmitProgramPrelude(root, out, ppc))
2775 {
2776 return false;
2777 }
2778
2779 {
2780 #if defined(ANGLE_ENABLE_ASSERTS)
2781 DebugSink outWrapper(out, angle::GetBoolEnvironmentVar("GMD_STDOUT"));
2782 outWrapper.watch(angle::GetEnvironmentVar("GMD_WATCH_STRING"));
2783 #else
2784 TInfoSinkBase &outWrapper = out;
2785 #endif
2786 GenMetalTraverser gen(compiler, outWrapper, idGen, pipelineStructs, symbolEnv,
2787 compileOptions);
2788 root.traverse(&gen);
2789 }
2790
2791 out << "\n";
2792
2793 return true;
2794 }
2795