• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/TranslatorMetalDirect.h"
8 
9 #include "angle_gl.h"
10 #include "common/utilities.h"
11 #include "compiler/translator/BuiltinsWorkaroundGLSL.h"
12 #include "compiler/translator/DriverUniformMetal.h"
13 #include "compiler/translator/ImmutableStringBuilder.h"
14 #include "compiler/translator/OutputGLSLBase.h"
15 #include "compiler/translator/StaticType.h"
16 #include "compiler/translator/TranslatorMetalDirect/AddExplicitTypeCasts.h"
17 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
18 #include "compiler/translator/TranslatorMetalDirect/EmitMetal.h"
19 #include "compiler/translator/TranslatorMetalDirect/FixTypeConstructors.h"
20 #include "compiler/translator/TranslatorMetalDirect/HoistConstants.h"
21 #include "compiler/translator/TranslatorMetalDirect/IntroduceVertexIndexID.h"
22 #include "compiler/translator/TranslatorMetalDirect/Name.h"
23 #include "compiler/translator/TranslatorMetalDirect/NameEmbeddedUniformStructsMetal.h"
24 #include "compiler/translator/TranslatorMetalDirect/ReduceInterfaceBlocks.h"
25 #include "compiler/translator/TranslatorMetalDirect/RewriteCaseDeclarations.h"
26 #include "compiler/translator/TranslatorMetalDirect/RewriteOutArgs.h"
27 #include "compiler/translator/TranslatorMetalDirect/RewritePipelines.h"
28 #include "compiler/translator/TranslatorMetalDirect/RewriteUnaddressableReferences.h"
29 #include "compiler/translator/TranslatorMetalDirect/SeparateCompoundExpressions.h"
30 #include "compiler/translator/TranslatorMetalDirect/SeparateCompoundStructDeclarations.h"
31 #include "compiler/translator/TranslatorMetalDirect/SymbolEnv.h"
32 #include "compiler/translator/TranslatorMetalDirect/ToposortStructs.h"
33 #include "compiler/translator/TranslatorMetalDirect/TranslatorMetalUtils.h"
34 #include "compiler/translator/TranslatorMetalDirect/WrapMain.h"
35 #include "compiler/translator/tree_ops/ConvertUnsupportedConstructorsToFunctionCalls.h"
36 #include "compiler/translator/tree_ops/InitializeVariables.h"
37 #include "compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.h"
38 #include "compiler/translator/tree_ops/NameNamelessUniformBuffers.h"
39 #include "compiler/translator/tree_ops/RemoveAtomicCounterBuiltins.h"
40 #include "compiler/translator/tree_ops/RemoveInactiveInterfaceVariables.h"
41 #include "compiler/translator/tree_ops/RewriteArrayOfArrayOfOpaqueUniforms.h"
42 #include "compiler/translator/tree_ops/RewriteAtomicCounters.h"
43 #include "compiler/translator/tree_ops/RewriteCubeMapSamplersAs2DArray.h"
44 #include "compiler/translator/tree_ops/RewriteDfdy.h"
45 #include "compiler/translator/tree_ops/RewriteStructSamplers.h"
46 #include "compiler/translator/tree_ops/SeparateStructFromUniformDeclarations.h"
47 #include "compiler/translator/tree_ops/apple/RewriteRowMajorMatrices.h"
48 #include "compiler/translator/tree_util/BuiltIn.h"
49 #include "compiler/translator/tree_util/DriverUniform.h"
50 #include "compiler/translator/tree_util/FindFunction.h"
51 #include "compiler/translator/tree_util/FindMain.h"
52 #include "compiler/translator/tree_util/FindSymbolNode.h"
53 #include "compiler/translator/tree_util/IntermNode_util.h"
54 #include "compiler/translator/tree_util/ReplaceClipCullDistanceVariable.h"
55 #include "compiler/translator/tree_util/ReplaceVariable.h"
56 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
57 #include "compiler/translator/tree_util/SpecializationConstant.h"
58 #include "compiler/translator/util.h"
59 
60 namespace sh
61 {
62 
63 namespace
64 {
65 
66 constexpr Name kSampleMaskWriteFuncName("writeSampleMask", SymbolType::AngleInternal);
67 constexpr Name kFlippedPointCoordName("flippedPointCoord", SymbolType::AngleInternal);
68 constexpr Name kFlippedFragCoordName("flippedFragCoord", SymbolType::AngleInternal);
69 
70 constexpr const TVariable kgl_VertexIDMetal(BuiltInId::gl_VertexID,
71                                             ImmutableString("gl_VertexID"),
72                                             SymbolType::BuiltIn,
73                                             TExtension::UNDEFINED,
74                                             StaticType::Get<EbtUInt, EbpHigh, EvqVertexID, 1, 1>());
75 
76 class DeclareStructTypesTraverser : public TIntermTraverser
77 {
78   public:
DeclareStructTypesTraverser(TOutputMSL * outputMSL)79     explicit DeclareStructTypesTraverser(TOutputMSL *outputMSL)
80         : TIntermTraverser(true, false, false), mOutputMSL(outputMSL)
81     {}
82 
visitDeclaration(Visit visit,TIntermDeclaration * node)83     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
84     {
85         ASSERT(visit == PreVisit);
86         if (!mInGlobalScope)
87         {
88             return false;
89         }
90 
91         const TIntermSequence &sequence = *(node->getSequence());
92         TIntermTyped *declarator        = sequence.front()->getAsTyped();
93         const TType &type               = declarator->getType();
94 
95         if (type.isStructSpecifier())
96         {
97             const TStructure *structure = type.getStruct();
98 
99             // Embedded structs should be parsed away by now.
100             ASSERT(structure->symbolType() != SymbolType::Empty);
101             // outputMSL->writeStructType(structure);
102 
103             TIntermSymbol *symbolNode = declarator->getAsSymbolNode();
104             if (symbolNode && symbolNode->variable().symbolType() == SymbolType::Empty)
105             {
106                 // Remove the struct specifier declaration from the tree so it isn't parsed again.
107                 TIntermSequence emptyReplacement;
108                 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
109                                                 std::move(emptyReplacement));
110             }
111         }
112         // TODO: REMOVE, used to remove 'unsued' warning
113         mOutputMSL = nullptr;
114 
115         return false;
116     }
117 
118   private:
119     TOutputMSL *mOutputMSL;
120 };
121 
122 class DeclareDefaultUniformsTraverser : public TIntermTraverser
123 {
124   public:
DeclareDefaultUniformsTraverser(TInfoSinkBase * sink,ShHashFunction64 hashFunction,NameMap * nameMap)125     DeclareDefaultUniformsTraverser(TInfoSinkBase *sink,
126                                     ShHashFunction64 hashFunction,
127                                     NameMap *nameMap)
128         : TIntermTraverser(true, true, true),
129           mSink(sink),
130           mHashFunction(hashFunction),
131           mNameMap(nameMap),
132           mInDefaultUniform(false)
133     {}
134 
visitDeclaration(Visit visit,TIntermDeclaration * node)135     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
136     {
137         const TIntermSequence &sequence = *(node->getSequence());
138 
139         // TODO(jmadill): Compound declarations.
140         ASSERT(sequence.size() == 1);
141 
142         TIntermTyped *variable = sequence.front()->getAsTyped();
143         const TType &type      = variable->getType();
144         bool isUniform         = type.getQualifier() == EvqUniform && !type.isInterfaceBlock() &&
145                          !IsOpaqueType(type.getBasicType());
146 
147         if (visit == PreVisit)
148         {
149             if (isUniform)
150             {
151                 (*mSink) << "    " << GetTypeName(type, mHashFunction, mNameMap) << " ";
152                 mInDefaultUniform = true;
153             }
154         }
155         else if (visit == InVisit)
156         {
157             mInDefaultUniform = isUniform;
158         }
159         else if (visit == PostVisit)
160         {
161             if (isUniform)
162             {
163                 (*mSink) << ";\n";
164 
165                 // Remove the uniform declaration from the tree so it isn't parsed again.
166                 TIntermSequence emptyReplacement;
167                 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
168                                                 std::move(emptyReplacement));
169             }
170 
171             mInDefaultUniform = false;
172         }
173         return true;
174     }
175 
visitSymbol(TIntermSymbol * symbol)176     void visitSymbol(TIntermSymbol *symbol) override
177     {
178         if (mInDefaultUniform)
179         {
180             const ImmutableString &name = symbol->variable().name();
181             ASSERT(!gl::IsBuiltInName(name.data()));
182             (*mSink) << HashName(&symbol->variable(), mHashFunction, mNameMap)
183                      << ArrayString(symbol->getType());
184         }
185     }
186 
187   private:
188     TInfoSinkBase *mSink;
189     ShHashFunction64 mHashFunction;
190     NameMap *mNameMap;
191     bool mInDefaultUniform;
192 };
193 
194 // Declares a new variable to replace gl_DepthRange, its values are fed from a driver uniform.
ReplaceGLDepthRangeWithDriverUniform(TCompiler * compiler,TIntermBlock * root,const DriverUniformMetal * driverUniforms,TSymbolTable * symbolTable)195 ANGLE_NO_DISCARD bool ReplaceGLDepthRangeWithDriverUniform(TCompiler *compiler,
196                                                            TIntermBlock *root,
197                                                            const DriverUniformMetal *driverUniforms,
198                                                            TSymbolTable *symbolTable)
199 {
200     // Create a symbol reference to "gl_DepthRange"
201     const TVariable *depthRangeVar = static_cast<const TVariable *>(
202         symbolTable->findBuiltIn(ImmutableString("gl_DepthRange"), 0));
203 
204     // ANGLEUniforms.depthRange
205     TIntermTyped *angleEmulatedDepthRangeRef = driverUniforms->getDepthRangeRef();
206 
207     // Use this variable instead of gl_DepthRange everywhere.
208     return ReplaceVariableWithTyped(compiler, root, depthRangeVar, angleEmulatedDepthRangeRef);
209 }
210 
GetMainSequence(TIntermBlock * root)211 TIntermSequence *GetMainSequence(TIntermBlock *root)
212 {
213     TIntermFunctionDefinition *main = FindMain(root);
214     return main->getBody()->getSequence();
215 }
216 
217 // Replaces a builtin variable with a version that is rotated and corrects the X and Y coordinates.
RotateAndFlipBuiltinVariable(TCompiler * compiler,TIntermBlock * root,TIntermSequence * insertSequence,TIntermTyped * flipXY,TSymbolTable * symbolTable,const TVariable * builtin,const Name & flippedVariableName,TIntermTyped * pivot,TIntermTyped * fragRotation)218 ANGLE_NO_DISCARD bool RotateAndFlipBuiltinVariable(TCompiler *compiler,
219                                                    TIntermBlock *root,
220                                                    TIntermSequence *insertSequence,
221                                                    TIntermTyped *flipXY,
222                                                    TSymbolTable *symbolTable,
223                                                    const TVariable *builtin,
224                                                    const Name &flippedVariableName,
225                                                    TIntermTyped *pivot,
226                                                    TIntermTyped *fragRotation)
227 {
228     // Create a symbol reference to 'builtin'.
229     TIntermSymbol *builtinRef = new TIntermSymbol(builtin);
230 
231     // Create a swizzle to "builtin.xy"
232     TVector<int> swizzleOffsetXY = {0, 1};
233     TIntermSwizzle *builtinXY    = new TIntermSwizzle(builtinRef, swizzleOffsetXY);
234 
235     // Create a symbol reference to our new variable that will hold the modified builtin.
236     const TType *type =
237         StaticType::GetForVec<EbtFloat, EbpHigh>(EvqGlobal, builtin->getType().getNominalSize());
238     TVariable *replacementVar =
239         new TVariable(symbolTable, flippedVariableName.rawName(), type, SymbolType::AngleInternal);
240     DeclareGlobalVariable(root, replacementVar);
241     TIntermSymbol *flippedBuiltinRef = new TIntermSymbol(replacementVar);
242 
243     // Use this new variable instead of 'builtin' everywhere.
244     if (!ReplaceVariable(compiler, root, builtin, replacementVar))
245     {
246         return false;
247     }
248 
249     // Create the expression "(builtin.xy * fragRotation)"
250     TIntermTyped *rotatedXY;
251     if (fragRotation)
252     {
253         rotatedXY = new TIntermBinary(EOpMatrixTimesVector, fragRotation, builtinXY);
254     }
255     else
256     {
257         // No rotation applied, use original variable.
258         rotatedXY = builtinXY;
259     }
260 
261     // Create the expression "(builtin.xy - pivot) * flipXY + pivot
262     TIntermBinary *removePivot = new TIntermBinary(EOpSub, rotatedXY, pivot);
263     TIntermBinary *inverseXY   = new TIntermBinary(EOpMul, removePivot, flipXY);
264     TIntermBinary *plusPivot   = new TIntermBinary(EOpAdd, inverseXY, pivot->deepCopy());
265 
266     // Create the corrected variable and copy the value of the original builtin.
267     TIntermSequence sequence;
268     sequence.push_back(builtinRef->deepCopy());
269     TIntermAggregate *aggregate =
270         TIntermAggregate::CreateConstructor(builtin->getType(), &sequence);
271     TIntermBinary *assignment = new TIntermBinary(EOpInitialize, flippedBuiltinRef, aggregate);
272 
273     // Create an assignment to the replaced variable's .xy.
274     TIntermSwizzle *correctedXY =
275         new TIntermSwizzle(flippedBuiltinRef->deepCopy(), swizzleOffsetXY);
276     TIntermBinary *assignToY = new TIntermBinary(EOpAssign, correctedXY, plusPivot);
277 
278     // Add this assigment at the beginning of the main function
279     insertSequence->insert(insertSequence->begin(), assignToY);
280     insertSequence->insert(insertSequence->begin(), assignment);
281 
282     return compiler->validateAST(root);
283 }
284 
InsertFragCoordCorrection(TCompiler * compiler,ShCompileOptions compileOptions,TIntermBlock * root,TIntermSequence * insertSequence,TSymbolTable * symbolTable,SpecConst * specConst,const DriverUniformMetal * driverUniforms)285 ANGLE_NO_DISCARD bool InsertFragCoordCorrection(TCompiler *compiler,
286                                                 ShCompileOptions compileOptions,
287                                                 TIntermBlock *root,
288                                                 TIntermSequence *insertSequence,
289                                                 TSymbolTable *symbolTable,
290                                                 SpecConst *specConst,
291                                                 const DriverUniformMetal *driverUniforms)
292 {
293     TIntermTyped *flipXY = specConst->getFlipXY();
294     if (!flipXY)
295     {
296         flipXY = driverUniforms->getFlipXYRef();
297     }
298 
299     TIntermTyped *pivot = specConst->getHalfRenderArea();
300     if (!pivot)
301     {
302         pivot = driverUniforms->getHalfRenderAreaRef();
303     }
304 
305     TIntermTyped *fragRotation = nullptr;
306     if ((compileOptions & SH_ADD_PRE_ROTATION) != 0)
307     {
308         fragRotation = specConst->getFragRotationMatrix();
309         if (!fragRotation)
310         {
311             fragRotation = driverUniforms->getFragRotationMatrixRef();
312         }
313     }
314 
315     const TVariable *fragCoord = static_cast<const TVariable *>(
316         symbolTable->findBuiltIn(ImmutableString("gl_FragCoord"), compiler->getShaderVersion()));
317     return RotateAndFlipBuiltinVariable(compiler, root, insertSequence, flipXY, symbolTable,
318                                         fragCoord, kFlippedFragCoordName, pivot, fragRotation);
319 }
320 
DeclareRightBeforeMain(TIntermBlock & root,const TVariable & var)321 void DeclareRightBeforeMain(TIntermBlock &root, const TVariable &var)
322 {
323     root.insertChildNodes(FindMainIndex(&root), {new TIntermDeclaration{&var}});
324 }
325 
AddFragColorDeclaration(TIntermBlock & root,TSymbolTable & symbolTable)326 void AddFragColorDeclaration(TIntermBlock &root, TSymbolTable &symbolTable)
327 {
328     root.insertChildNodes(FindMainIndex(&root),
329                           TIntermSequence{new TIntermDeclaration{BuiltInVariable::gl_FragColor()}});
330 }
331 
AddFragDepthDeclaration(TIntermBlock & root,TSymbolTable & symbolTable)332 void AddFragDepthDeclaration(TIntermBlock &root, TSymbolTable &symbolTable)
333 {
334     root.insertChildNodes(FindMainIndex(&root),
335                           TIntermSequence{new TIntermDeclaration{BuiltInVariable::gl_FragDepth()}});
336 }
337 
AddFragDepthEXTDeclaration(TCompiler & compiler,TIntermBlock & root,TSymbolTable & symbolTable)338 void AddFragDepthEXTDeclaration(TCompiler &compiler, TIntermBlock &root, TSymbolTable &symbolTable)
339 {
340     const TIntermSymbol *glFragDepthExt = FindSymbolNode(&root, ImmutableString("gl_FragDepthEXT"));
341     ASSERT(glFragDepthExt);
342     // Replace gl_FragData with our globally defined fragdata.
343     if (!ReplaceVariable(&compiler, &root, &(glFragDepthExt->variable()),
344                          BuiltInVariable::gl_FragDepth()))
345     {
346         return;
347     }
348     AddFragDepthDeclaration(root, symbolTable);
349 }
AddSampleMaskDeclaration(TIntermBlock & root,TSymbolTable & symbolTable)350 void AddSampleMaskDeclaration(TIntermBlock &root, TSymbolTable &symbolTable)
351 {
352     TType *gl_SampleMaskType = new TType(EbtUInt, EbpHigh, EvqSampleMask, 1, 1);
353     const TVariable *gl_SampleMask =
354         new TVariable(&symbolTable, ImmutableString("gl_SampleMask"), gl_SampleMaskType,
355                       SymbolType::BuiltIn, TExtension::UNDEFINED);
356     root.insertChildNodes(FindMainIndex(&root),
357                           TIntermSequence{new TIntermDeclaration{gl_SampleMask}});
358 }
359 
AddFragDataDeclaration(TCompiler & compiler,TIntermBlock & root)360 ANGLE_NO_DISCARD bool AddFragDataDeclaration(TCompiler &compiler, TIntermBlock &root)
361 {
362     TSymbolTable &symbolTable = compiler.getSymbolTable();
363     const int maxDrawBuffers  = compiler.getResources().MaxDrawBuffers;
364     TType *gl_FragDataType    = new TType(EbtFloat, EbpMedium, EvqFragData, 4, 1);
365     std::vector<const TVariable *> glFragDataSlots;
366     TIntermSequence declareGLFragdataSequence;
367 
368     // Create gl_FragData_0,1,2,3
369     for (int i = 0; i < maxDrawBuffers; i++)
370     {
371         ImmutableStringBuilder builder(strlen("gl_FragData_") + 2);
372         builder << "gl_FragData_";
373         builder.appendDecimal(i);
374         const TVariable *glFragData =
375             new TVariable(&symbolTable, builder, gl_FragDataType, SymbolType::AngleInternal,
376                           TExtension::UNDEFINED);
377         glFragDataSlots.push_back(glFragData);
378         declareGLFragdataSequence.push_back(new TIntermDeclaration{glFragData});
379     }
380     root.insertChildNodes(FindMainIndex(&root), declareGLFragdataSequence);
381 
382     // Create an internal gl_FragData array type, compatible with indexing syntax.
383     TType *gl_FragDataTypeArray = new TType(EbtFloat, EbpMedium, EvqGlobal, 4, 1);
384     gl_FragDataTypeArray->makeArray(maxDrawBuffers);
385     const TVariable *glFragDataGlobal = new TVariable(&symbolTable, ImmutableString("gl_FragData"),
386                                                       gl_FragDataTypeArray, SymbolType::BuiltIn);
387 
388     DeclareGlobalVariable(&root, glFragDataGlobal);
389     const TIntermSymbol *originalGLFragData = FindSymbolNode(&root, ImmutableString("gl_FragData"));
390     ASSERT(originalGLFragData);
391 
392     // Replace gl_FragData() with our globally defined fragdata
393     if (!ReplaceVariable(&compiler, &root, &(originalGLFragData->variable()), glFragDataGlobal))
394     {
395         return false;
396     }
397 
398     // Assign each array attribute to an output
399     TIntermBlock *insertSequence = new TIntermBlock();
400     for (int i = 0; i < maxDrawBuffers; i++)
401     {
402         TIntermTyped *glFragDataSlot         = new TIntermSymbol(glFragDataSlots[i]);
403         TIntermTyped *glFragDataGlobalSymbol = new TIntermSymbol(glFragDataGlobal);
404         auto &access                         = AccessIndex(*glFragDataGlobalSymbol, i);
405         TIntermBinary *assignment =
406             new TIntermBinary(TOperator::EOpAssign, glFragDataSlot, &access);
407         insertSequence->appendStatement(assignment);
408     }
409     return RunAtTheEndOfShader(&compiler, &root, insertSequence, &symbolTable);
410 }
411 
AppendVertexShaderTransformFeedbackOutputToMain(TCompiler & compiler,SymbolEnv & mSymbolEnv,TIntermBlock & root)412 ANGLE_NO_DISCARD bool AppendVertexShaderTransformFeedbackOutputToMain(TCompiler &compiler,
413                                                                       SymbolEnv &mSymbolEnv,
414                                                                       TIntermBlock &root)
415 {
416     TSymbolTable &symbolTable = compiler.getSymbolTable();
417 
418     // Append the assignment as a statement at the end of the shader.
419     return RunAtTheEndOfShader(&compiler, &root,
420                                &(mSymbolEnv.callFunctionOverload(Name("@@XFB-OUT@@"), *new TType(),
421                                                                  *new TIntermSequence())),
422                                &symbolTable);
423 }
424 
425 // Unlike Vulkan having auto viewport flipping extension, in Metal we have to flip gl_Position.y
426 // manually.
427 // This operation performs flipping the gl_Position.y using this expression:
428 // gl_Position.y = gl_Position.y * negViewportScaleY
AppendVertexShaderPositionYCorrectionToMain(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,TIntermTyped * negFlipY)429 ANGLE_NO_DISCARD bool AppendVertexShaderPositionYCorrectionToMain(TCompiler *compiler,
430                                                                   TIntermBlock *root,
431                                                                   TSymbolTable *symbolTable,
432                                                                   TIntermTyped *negFlipY)
433 {
434     // Create a symbol reference to "gl_Position"
435     const TVariable *position  = BuiltInVariable::gl_Position();
436     TIntermSymbol *positionRef = new TIntermSymbol(position);
437 
438     // Create a swizzle to "gl_Position.y"
439     TVector<int> swizzleOffsetY;
440     swizzleOffsetY.push_back(1);
441     TIntermSwizzle *positionY = new TIntermSwizzle(positionRef, swizzleOffsetY);
442 
443     // Create the expression "gl_Position.y * negFlipY"
444     TIntermBinary *inverseY = new TIntermBinary(EOpMul, positionY->deepCopy(), negFlipY);
445 
446     // Create the assignment "gl_Position.y = gl_Position.y * negViewportScaleY
447     TIntermTyped *positionYLHS = positionY->deepCopy();
448     TIntermBinary *assignment  = new TIntermBinary(TOperator::EOpAssign, positionYLHS, inverseY);
449 
450     // Append the assignment as a statement at the end of the shader.
451     return RunAtTheEndOfShader(compiler, root, assignment, symbolTable);
452 }
453 }  // namespace
454 
455 namespace mtl
456 {
getTranslatorMetalReflection(const TCompiler * compiler)457 TranslatorMetalReflection *getTranslatorMetalReflection(const TCompiler *compiler)
458 {
459     return ((TranslatorMetalDirect *)compiler)->getTranslatorMetalReflection();
460 }
461 }  // namespace mtl
TranslatorMetalDirect(sh::GLenum type,ShShaderSpec spec,ShShaderOutput output)462 TranslatorMetalDirect::TranslatorMetalDirect(sh::GLenum type,
463                                              ShShaderSpec spec,
464                                              ShShaderOutput output)
465     : TCompiler(type, spec, output)
466 {}
467 
468 // Add sample_mask writing to main, guarded by the function constant
469 // kCoverageMaskEnabledName
insertSampleMaskWritingLogic(TIntermBlock & root,DriverUniformMetal & driverUniforms)470 ANGLE_NO_DISCARD bool TranslatorMetalDirect::insertSampleMaskWritingLogic(
471     TIntermBlock &root,
472     DriverUniformMetal &driverUniforms)
473 {
474     // This transformation leaves the tree in an inconsistent state by using a variable that's
475     // defined in text, outside of the knowledge of the AST.
476     mValidateASTOptions.validateVariableReferences = false;
477     // It also uses a function call (ANGLE_writeSampleMask) that's unknown to the AST.
478     mValidateASTOptions.validateFunctionCall = false;
479 
480     TSymbolTable *symbolTable = &getSymbolTable();
481 
482     // Create kCoverageMaskEnabled and kSampleMaskWriteFuncName variable references.
483     TType *boolType = new TType(EbtBool);
484     boolType->setQualifier(EvqConst);
485     TVariable *coverageMaskEnabledVar =
486         new TVariable(symbolTable, sh::ImmutableString(sh::mtl::kCoverageMaskEnabledConstName),
487                       boolType, SymbolType::AngleInternal);
488 
489     TFunction *sampleMaskWriteFunc = new TFunction(
490         symbolTable, kSampleMaskWriteFuncName.rawName(), kSampleMaskWriteFuncName.symbolType(),
491         StaticType::GetBasic<EbtVoid, EbpUndefined>(), false);
492 
493     TType *uintType = new TType(EbtUInt);
494     TVariable *maskArg =
495         new TVariable(symbolTable, ImmutableString("mask"), uintType, SymbolType::AngleInternal);
496     sampleMaskWriteFunc->addParameter(maskArg);
497 
498     TVariable *gl_SampleMaskArg = new TVariable(symbolTable, ImmutableString("gl_SampleMask"),
499                                                 uintType, SymbolType::AngleInternal);
500     sampleMaskWriteFunc->addParameter(gl_SampleMaskArg);
501 
502     // Insert this MSL code to the end of main() in the shader
503     // if (ANGLECoverageMaskEnabled)
504     // {
505     //      ANGLE_writeSampleMask(ANGLE_angleUniforms.coverageMask,
506     //      ANGLE_fragmentOut.gl_SampleMask);
507     // }
508     TIntermSequence *args      = new TIntermSequence;
509     TIntermTyped *coverageMask = driverUniforms.getCoverageMaskFieldRef();
510     args->push_back(coverageMask);
511     const TIntermSymbol *gl_SampleMask = FindSymbolNode(&root, ImmutableString("gl_SampleMask"));
512     args->push_back(gl_SampleMask->deepCopy());
513 
514     TIntermAggregate *callSampleMaskWriteFunc =
515         TIntermAggregate::CreateFunctionCall(*sampleMaskWriteFunc, args);
516     TIntermBlock *callBlock = new TIntermBlock;
517     callBlock->appendStatement(callSampleMaskWriteFunc);
518 
519     TIntermSymbol *coverageMaskEnabled = new TIntermSymbol(coverageMaskEnabledVar);
520     TIntermIfElse *ifCall              = new TIntermIfElse(coverageMaskEnabled, callBlock, nullptr);
521     return RunAtTheEndOfShader(this, &root, ifCall, symbolTable);
522 }
523 
insertRasterizationDiscardLogic(TIntermBlock & root)524 ANGLE_NO_DISCARD bool TranslatorMetalDirect::insertRasterizationDiscardLogic(TIntermBlock &root)
525 {
526     // This transformation leaves the tree in an inconsistent state by using a variable that's
527     // defined in text, outside of the knowledge of the AST.
528     mValidateASTOptions.validateVariableReferences = false;
529 
530     TSymbolTable *symbolTable = &getSymbolTable();
531 
532     TType *boolType = new TType(EbtBool);
533     boolType->setQualifier(EvqConst);
534     TVariable *discardEnabledVar =
535         new TVariable(symbolTable, sh::ImmutableString(sh::mtl::kRasterizerDiscardEnabledConstName),
536                       boolType, SymbolType::AngleInternal);
537 
538     const TVariable *position  = BuiltInVariable::gl_Position();
539     TIntermSymbol *positionRef = new TIntermSymbol(position);
540 
541     // Create vec4(-3, -3, -3, 1):
542     auto vec4Type             = new TType(EbtFloat, 4);
543     TIntermSequence *vec4Args = new TIntermSequence();
544     vec4Args->push_back(CreateFloatNode(-3.0f, EbpMedium));
545     vec4Args->push_back(CreateFloatNode(-3.0f, EbpMedium));
546     vec4Args->push_back(CreateFloatNode(-3.0f, EbpMedium));
547     vec4Args->push_back(CreateFloatNode(1.0f, EbpMedium));
548     TIntermAggregate *constVarConstructor =
549         TIntermAggregate::CreateConstructor(*vec4Type, vec4Args);
550 
551     // Create the assignment "gl_Position = vec4(-3, -3, -3, 1)"
552     TIntermBinary *assignment =
553         new TIntermBinary(TOperator::EOpAssign, positionRef->deepCopy(), constVarConstructor);
554 
555     TIntermBlock *discardBlock = new TIntermBlock;
556     discardBlock->appendStatement(assignment);
557 
558     TIntermSymbol *discardEnabled = new TIntermSymbol(discardEnabledVar);
559     TIntermIfElse *ifCall         = new TIntermIfElse(discardEnabled, discardBlock, nullptr);
560 
561     return RunAtTheEndOfShader(this, &root, ifCall, symbolTable);
562 }
563 
564 // Metal needs to inverse the depth if depthRange is is reverse order, i.e. depth near > depth far
565 // This is achieved by multiply the depth value with scale value stored in
566 // driver uniform's depthRange.reserved
transformDepthBeforeCorrection(TIntermBlock * root,const DriverUniformMetal * driverUniforms)567 bool TranslatorMetalDirect::transformDepthBeforeCorrection(TIntermBlock *root,
568                                                            const DriverUniformMetal *driverUniforms)
569 {
570     // Create a symbol reference to "gl_Position"
571     const TVariable *position  = BuiltInVariable::gl_Position();
572     TIntermSymbol *positionRef = new TIntermSymbol(position);
573 
574     // Create a swizzle to "gl_Position.z"
575     TVector<int> swizzleOffsetZ = {2};
576     TIntermSwizzle *positionZ   = new TIntermSwizzle(positionRef, swizzleOffsetZ);
577 
578     // Create a ref to "depthRange.reserved"
579     TIntermTyped *viewportZScale = driverUniforms->getDepthRangeReservedFieldRef();
580 
581     // Create the expression "gl_Position.z * depthRange.reserved".
582     TIntermBinary *zScale = new TIntermBinary(EOpMul, positionZ->deepCopy(), viewportZScale);
583 
584     // Create the assignment "gl_Position.z = gl_Position.z * depthRange.reserved"
585     TIntermTyped *positionZLHS = positionZ->deepCopy();
586     TIntermBinary *assignment  = new TIntermBinary(TOperator::EOpAssign, positionZLHS, zScale);
587 
588     // Append the assignment as a statement at the end of the shader.
589     return RunAtTheEndOfShader(this, root, assignment, &getSymbolTable());
590 }
591 
592 // This operation performs the viewport depth translation needed by Metal. GL uses a
593 // clip space z range of -1 to +1 where as Metal uses 0 to 1. The translation becomes
594 // this expression
595 //
596 //     z_metal = 0.5 * (w_gl + z_gl)
597 //
598 // where z_metal is the depth output of a Metal vertex shader and z_gl is the same for GL.
appendVertexShaderDepthCorrectionToMain(TIntermBlock * root)599 bool TranslatorMetalDirect::appendVertexShaderDepthCorrectionToMain(TIntermBlock *root)
600 {
601     const TVariable *position  = BuiltInVariable::gl_Position();
602     TIntermSymbol *positionRef = new TIntermSymbol(position);
603 
604     TVector<int> swizzleOffsetZ = {2};
605     TIntermSwizzle *positionZ   = new TIntermSwizzle(positionRef, swizzleOffsetZ);
606 
607     TIntermConstantUnion *oneHalf = CreateFloatNode(0.5f, EbpMedium);
608 
609     TVector<int> swizzleOffsetW = {3};
610     TIntermSwizzle *positionW   = new TIntermSwizzle(positionRef->deepCopy(), swizzleOffsetW);
611 
612     // Create the expression "(gl_Position.z + gl_Position.w) * 0.5".
613     TIntermBinary *zPlusW = new TIntermBinary(EOpAdd, positionZ->deepCopy(), positionW->deepCopy());
614     TIntermBinary *halfZPlusW = new TIntermBinary(EOpMul, zPlusW, oneHalf->deepCopy());
615 
616     // Create the assignment "gl_Position.z = (gl_Position.z + gl_Position.w) * 0.5"
617     TIntermTyped *positionZLHS = positionZ->deepCopy();
618     TIntermBinary *assignment  = new TIntermBinary(TOperator::EOpAssign, positionZLHS, halfZPlusW);
619 
620     // Append the assignment as a statement at the end of the shader.
621     return RunAtTheEndOfShader(this, root, assignment, &getSymbolTable());
622 }
623 
metalShaderTypeFromGLSL(sh::GLenum shaderType)624 static inline MetalShaderType metalShaderTypeFromGLSL(sh::GLenum shaderType)
625 {
626     switch (shaderType)
627     {
628         case GL_VERTEX_SHADER:
629             return MetalShaderType::Vertex;
630         case GL_FRAGMENT_SHADER:
631             return MetalShaderType::Fragment;
632         case GL_COMPUTE_SHADER:
633             ASSERT(0 && "compute shaders not currently supported");
634             return MetalShaderType::Compute;
635         default:
636             ASSERT(0 && "Invalid shader type.");
637             return MetalShaderType::None;
638     }
639 }
640 
translateImpl(TInfoSinkBase & sink,TIntermBlock * root,ShCompileOptions compileOptions,PerformanceDiagnostics *,SpecConst * specConst,DriverUniformMetal * driverUniforms)641 bool TranslatorMetalDirect::translateImpl(TInfoSinkBase &sink,
642                                           TIntermBlock *root,
643                                           ShCompileOptions compileOptions,
644                                           PerformanceDiagnostics * /*perfDiagnostics*/,
645                                           SpecConst *specConst,
646                                           DriverUniformMetal *driverUniforms)
647 {
648     TSymbolTable &symbolTable = getSymbolTable();
649     IdGen idGen;
650     ProgramPreludeConfig ppc(metalShaderTypeFromGLSL(getShaderType()));
651 
652     if (!WrapMain(*this, idGen, *root))
653     {
654         return false;
655     }
656 
657     // Remove declarations of inactive shader interface variables so glslang wrapper doesn't need to
658     // replace them.  Note: this is done before extracting samplers from structs, as removing such
659     // inactive samplers is not yet supported.  Note also that currently, CollectVariables marks
660     // every field of an active uniform that's of struct type as active, i.e. no extracted sampler
661     // is inactive.
662     if (!RemoveInactiveInterfaceVariables(this, root, &getSymbolTable(), getAttributes(),
663                                           getInputVaryings(), getOutputVariables(), getUniforms(),
664                                           getInterfaceBlocks(), false))
665     {
666         return false;
667     }
668 
669     // Write out default uniforms into a uniform block assigned to a specific set/binding.
670     int aggregateTypesUsedForUniforms = 0;
671     int atomicCounterCount            = 0;
672     for (const auto &uniform : getUniforms())
673     {
674         if (uniform.isStruct() || uniform.isArrayOfArrays())
675         {
676             ++aggregateTypesUsedForUniforms;
677         }
678 
679         if (uniform.active && gl::IsAtomicCounterType(uniform.type))
680         {
681             ++atomicCounterCount;
682         }
683     }
684 
685     // If there are any function calls that take array-of-array of opaque uniform parameters, or
686     // other opaque uniforms that need special handling in Vulkan, such as atomic counters,
687     // monomorphize the functions by removing said parameters and replacing them in the function
688     // body with the call arguments.
689     //
690     // This has a few benefits:
691     //
692     // - It dramatically simplifies future transformations w.r.t to samplers in structs, array of
693     //   arrays of opaque types, atomic counters etc.
694     // - Avoids the need for shader*ArrayDynamicIndexing Vulkan features.
695     if (!MonomorphizeUnsupportedFunctions(this, root, &getSymbolTable(), compileOptions))
696     {
697         return false;
698     }
699 
700     if (aggregateTypesUsedForUniforms > 0)
701     {
702         if (!NameEmbeddedStructUniformsMetal(this, root, &symbolTable))
703         {
704             return false;
705         }
706 
707         if (!SeparateStructFromUniformDeclarations(this, root, &getSymbolTable()))
708         {
709             return false;
710         }
711 
712         int removedUniformsCount;
713 
714         if (!RewriteStructSamplers(this, root, &getSymbolTable(), &removedUniformsCount))
715         {
716             return false;
717         }
718     }
719 
720     // Replace array of array of opaque uniforms with a flattened array.  This is run after
721     // MonomorphizeUnsupportedFunctions and RewriteStructSamplers so that it's not possible for an
722     // array of array of opaque type to be partially subscripted and passed to a function.
723     if (!RewriteArrayOfArrayOfOpaqueUniforms(this, root, &getSymbolTable()))
724     {
725         return false;
726     }
727 
728     if (compileOptions & SH_EMULATE_SEAMFUL_CUBE_MAP_SAMPLING)
729     {
730         if (!RewriteCubeMapSamplersAs2DArray(this, root, &symbolTable,
731                                              getShaderType() == GL_FRAGMENT_SHADER))
732         {
733             return false;
734         }
735     }
736 
737     if (getShaderType() == GL_COMPUTE_SHADER)
738     {
739         driverUniforms->addComputeDriverUniformsToShader(root, &getSymbolTable());
740     }
741     else
742     {
743         driverUniforms->addGraphicsDriverUniformsToShader(root, &getSymbolTable());
744     }
745 
746     if (atomicCounterCount > 0)
747     {
748         const TIntermTyped *acbBufferOffsets = driverUniforms->getAbcBufferOffsets();
749         if (!RewriteAtomicCounters(this, root, &symbolTable, acbBufferOffsets))
750         {
751             return false;
752         }
753     }
754     else if (getShaderVersion() >= 310)
755     {
756         // Vulkan doesn't support Atomic Storage as a Storage Class, but we've seen
757         // cases where builtins are using it even with no active atomic counters.
758         // This pass simply removes those builtins in that scenario.
759         if (!RemoveAtomicCounterBuiltins(this, root))
760         {
761             return false;
762         }
763     }
764 
765     if (getShaderType() != GL_COMPUTE_SHADER)
766     {
767         if (!ReplaceGLDepthRangeWithDriverUniform(this, root, driverUniforms, &getSymbolTable()))
768         {
769             return false;
770         }
771     }
772 
773     {
774         bool usesInstanceId = false;
775         bool usesVertexId   = false;
776         for (const ShaderVariable &var : mAttributes)
777         {
778             if (var.isBuiltIn())
779             {
780                 if (var.name == "gl_InstanceID")
781                 {
782                     usesInstanceId = true;
783                 }
784                 if (var.name == "gl_VertexID")
785                 {
786                     usesVertexId = true;
787                 }
788             }
789         }
790 
791         if (usesInstanceId)
792         {
793             root->insertChildNodes(
794                 FindMainIndex(root),
795                 TIntermSequence{new TIntermDeclaration{BuiltInVariable::gl_InstanceID()}});
796         }
797         if (usesVertexId)
798         {
799             if (!ReplaceVariable(this, root, BuiltInVariable::gl_VertexID(), &kgl_VertexIDMetal))
800             {
801                 return false;
802             }
803             DeclareRightBeforeMain(*root, kgl_VertexIDMetal);
804         }
805     }
806     SymbolEnv symbolEnv(*this, *root);
807     // Declare gl_FragColor and gl_FragData as webgl_FragColor and webgl_FragData
808     // if it's core profile shaders and they are used.
809     if (getShaderType() == GL_FRAGMENT_SHADER)
810     {
811         bool usesPointCoord  = false;
812         bool usesFragCoord   = false;
813         bool usesFrontFacing = false;
814         for (const ShaderVariable &inputVarying : mInputVaryings)
815         {
816             if (inputVarying.isBuiltIn())
817             {
818                 if (inputVarying.name == "gl_PointCoord")
819                 {
820                     usesPointCoord = true;
821                 }
822                 else if (inputVarying.name == "gl_FragCoord")
823                 {
824                     usesFragCoord = true;
825                 }
826                 else if (inputVarying.name == "gl_FrontFacing")
827                 {
828                     usesFrontFacing = true;
829                 }
830             }
831         }
832 
833         bool usesFragColor    = false;
834         bool usesFragData     = false;
835         bool usesFragDepth    = false;
836         bool usesFragDepthEXT = false;
837         for (const ShaderVariable &outputVarying : mOutputVariables)
838         {
839             if (outputVarying.isBuiltIn())
840             {
841                 if (outputVarying.name == "gl_FragColor")
842                 {
843                     usesFragColor = true;
844                 }
845                 else if (outputVarying.name == "gl_FragData")
846                 {
847                     usesFragData = true;
848                 }
849                 else if (outputVarying.name == "gl_FragDepth")
850                 {
851                     usesFragDepth = true;
852                 }
853                 else if (outputVarying.name == "gl_FragDepthEXT")
854                 {
855                     usesFragDepthEXT = true;
856                 }
857             }
858         }
859 
860         if (usesFragColor && usesFragData)
861         {
862             return false;
863         }
864 
865         if (usesFragColor)
866         {
867             AddFragColorDeclaration(*root, symbolTable);
868         }
869 
870         if (usesFragData)
871         {
872             if (!AddFragDataDeclaration(*this, *root))
873             {
874                 return false;
875             }
876         }
877         if (usesFragDepth)
878         {
879             AddFragDepthDeclaration(*root, symbolTable);
880         }
881         else if (usesFragDepthEXT)
882         {
883             AddFragDepthEXTDeclaration(*this, *root, symbolTable);
884         }
885 
886         // Always add sample_mask. It will be guarded by a function constant decided at runtime.
887         bool usesSampleMask = true;
888         if (usesSampleMask)
889         {
890             AddSampleMaskDeclaration(*root, symbolTable);
891         }
892 
893         if (usesPointCoord)
894         {
895             TIntermTyped *flipNegXY = specConst->getNegFlipXY();
896             if (!flipNegXY)
897             {
898                 flipNegXY = driverUniforms->getNegFlipXYRef();
899             }
900             TIntermConstantUnion *pivot = CreateFloatNode(0.5f, EbpMedium);
901             TIntermTyped *fragRotation  = nullptr;
902             if (!RotateAndFlipBuiltinVariable(this, root, GetMainSequence(root), flipNegXY,
903                                               &getSymbolTable(), BuiltInVariable::gl_PointCoord(),
904                                               kFlippedPointCoordName, pivot, fragRotation))
905             {
906                 return false;
907             }
908             DeclareRightBeforeMain(*root, *BuiltInVariable::gl_PointCoord());
909         }
910 
911         if (usesFragCoord)
912         {
913             if (!InsertFragCoordCorrection(this, compileOptions, root, GetMainSequence(root),
914                                            &getSymbolTable(), specConst, driverUniforms))
915             {
916                 return false;
917             }
918             const TVariable *fragCoord = static_cast<const TVariable *>(
919                 getSymbolTable().findBuiltIn(ImmutableString("gl_FragCoord"), getShaderVersion()));
920             DeclareRightBeforeMain(*root, *fragCoord);
921         }
922 
923         if (!RewriteDfdy(this, compileOptions, root, getSymbolTable(), getShaderVersion(),
924                          specConst, driverUniforms))
925         {
926             return false;
927         }
928 
929         if (usesFrontFacing)
930         {
931             DeclareRightBeforeMain(*root, *BuiltInVariable::gl_FrontFacing());
932         }
933 
934         EmitEarlyFragmentTestsGLSL(*this, sink);
935     }
936     else if (getShaderType() == GL_VERTEX_SHADER)
937     {
938         DeclareRightBeforeMain(*root, *BuiltInVariable::gl_Position());
939 
940         if (FindSymbolNode(root, BuiltInVariable::gl_PointSize()->name()))
941         {
942             const TVariable *pointSize = static_cast<const TVariable *>(
943                 getSymbolTable().findBuiltIn(ImmutableString("gl_PointSize"), getShaderVersion()));
944             DeclareRightBeforeMain(*root, *pointSize);
945         }
946 
947         if (FindSymbolNode(root, BuiltInVariable::gl_VertexIndex()->name()))
948         {
949             if (!ReplaceVariable(this, root, BuiltInVariable::gl_VertexIndex(), &kgl_VertexIDMetal))
950             {
951                 return false;
952             }
953             DeclareRightBeforeMain(*root, kgl_VertexIDMetal);
954         }
955 
956         // Append a macro for transform feedback substitution prior to modifying depth.
957         if (!AppendVertexShaderTransformFeedbackOutputToMain(*this, symbolEnv, *root))
958         {
959             return false;
960         }
961 
962         // Search for the gl_ClipDistance usage, if its used, we need to do some replacements.
963         bool useClipDistance = false;
964         for (const ShaderVariable &outputVarying : mOutputVaryings)
965         {
966             if (outputVarying.name == "gl_ClipDistance")
967             {
968                 useClipDistance = true;
969                 break;
970             }
971         }
972 
973         if (useClipDistance &&
974             !ReplaceClipDistanceAssignments(this, root, &getSymbolTable(), getShaderType(),
975                                             driverUniforms->getClipDistancesEnabled()))
976         {
977             return false;
978         }
979 
980         if (!transformDepthBeforeCorrection(root, driverUniforms))
981         {
982             return false;
983         }
984 
985         if (!appendVertexShaderDepthCorrectionToMain(root))
986         {
987             return false;
988         }
989     }
990     else if (getShaderType() == GL_GEOMETRY_SHADER)
991     {
992         WriteGeometryShaderLayoutQualifiers(
993             sink, getGeometryShaderInputPrimitiveType(), getGeometryShaderInvocations(),
994             getGeometryShaderOutputPrimitiveType(), getGeometryShaderMaxVertices());
995     }
996     else
997     {
998         ASSERT(getShaderType() == GL_COMPUTE_SHADER);
999         EmitWorkGroupSizeGLSL(*this, sink);
1000     }
1001 
1002     if (getShaderType() == GL_VERTEX_SHADER)
1003     {
1004         TIntermTyped *negFlipY = driverUniforms->getNegFlipYRef();
1005 
1006         if (!AppendVertexShaderPositionYCorrectionToMain(this, root, &getSymbolTable(), negFlipY))
1007         {
1008             return false;
1009         }
1010         if (!insertRasterizationDiscardLogic(*root))
1011         {
1012             return false;
1013         }
1014     }
1015     else if (getShaderType() == GL_FRAGMENT_SHADER)
1016     {
1017         if (!insertSampleMaskWritingLogic(*root, *driverUniforms))
1018         {
1019             return false;
1020         }
1021     }
1022 
1023     if (!validateAST(root))
1024     {
1025         return false;
1026     }
1027 
1028     // This is the largest size required to pass all the tests in
1029     // (dEQP-GLES3.functional.shaders.large_constant_arrays)
1030     // This value could in principle be smaller.
1031     const size_t hoistThresholdSize = 256;
1032     if (!HoistConstants(*this, *root, idGen, hoistThresholdSize))
1033     {
1034         return false;
1035     }
1036 
1037     if (!ConvertUnsupportedConstructorsToFunctionCalls(*this, *root))
1038     {
1039         return false;
1040     }
1041 
1042     const bool needsExplicitBoolCasts = (compileOptions & SH_ADD_EXPLICIT_BOOL_CASTS) != 0;
1043     if (!AddExplicitTypeCasts(*this, *root, symbolEnv, needsExplicitBoolCasts))
1044     {
1045         return false;
1046     }
1047 
1048     if (!SeparateCompoundExpressions(*this, symbolEnv, idGen, *root))
1049     {
1050         return false;
1051     }
1052 
1053     if ((compileOptions & SH_REWRITE_ROW_MAJOR_MATRICES) != 0 && getShaderVersion() >= 300)
1054     {
1055         // "Make sure every uniform buffer variable has a name.  The following transformation
1056         // relies on this." This pass was removed in e196bc85ac2dda0e9f6664cfc2eca0029e33d2d1,
1057         // but currently finding it still necessary for MSL.
1058         if (!NameNamelessUniformBuffers(this, root, &getSymbolTable()))
1059         {
1060             return false;
1061         }
1062         // Note: RewriteRowMajorMatrices can create temporaries moved above
1063         // the statement they are used in. As such it must come after
1064         // SeparateCompoundExpressions since it is not aware of short circuits
1065         // and side effects.
1066         if (!RewriteRowMajorMatrices(this, root, &getSymbolTable()))
1067         {
1068             return false;
1069         }
1070     }
1071 
1072     // Note: ReduceInterfaceBlocks removes row_major matrix layout specifiers
1073     // so it must come after RewriteRowMajorMatrices.
1074     if (!ReduceInterfaceBlocks(*this, *root, idGen, &getSymbolTable()))
1075     {
1076         return false;
1077     }
1078 
1079     if (!SeparateCompoundStructDeclarations(*this, idGen, *root, &getSymbolTable()))
1080     {
1081         return false;
1082     }
1083 
1084     // The RewritePipelines phase leaves the tree in an inconsistent state by inserting
1085     // references to structures like "ANGLE_TextureEnv<metal::texture2d<float>>" which are
1086     // defined in text (in ProgramPrelude), outside of the knowledge of the AST.
1087     mValidateASTOptions.validateStructUsage = false;
1088     // The RewritePipelines phase also generates incoming arguments to synthesized
1089     // functions that use are missing qualifiers - for example, angleUniforms isn't marked
1090     // as an incoming argument.
1091     mValidateASTOptions.validateQualifiers = false;
1092 
1093     PipelineStructs pipelineStructs;
1094     if (!RewritePipelines(*this, *root, getInputVaryings(), getOutputVaryings(), idGen,
1095                           *driverUniforms, symbolEnv, pipelineStructs))
1096     {
1097         return false;
1098     }
1099     if (getShaderType() == GL_VERTEX_SHADER)
1100     {
1101         // This has to happen after RewritePipelines.
1102         if (!IntroduceVertexAndInstanceIndex(*this, *root))
1103         {
1104             return false;
1105         }
1106     }
1107 
1108     if (!RewriteCaseDeclarations(*this, *root))
1109     {
1110         return false;
1111     }
1112 
1113     if (!RewriteUnaddressableReferences(*this, *root, symbolEnv))
1114     {
1115         return false;
1116     }
1117 
1118     if (!RewriteOutArgs(*this, *root, symbolEnv))
1119     {
1120         return false;
1121     }
1122     if (!FixTypeConstructors(*this, symbolEnv, *root))
1123     {
1124         return false;
1125     }
1126     if (!ToposortStructs(*this, symbolEnv, *root, ppc))
1127     {
1128         return false;
1129     }
1130     if (!EmitMetal(*this, *root, idGen, pipelineStructs, symbolEnv, ppc, &getSymbolTable()))
1131     {
1132         return false;
1133     }
1134 
1135     ASSERT(validateAST(root));
1136 
1137     return true;
1138 }
1139 
translate(TIntermBlock * root,ShCompileOptions compileOptions,PerformanceDiagnostics * perfDiagnostics)1140 bool TranslatorMetalDirect::translate(TIntermBlock *root,
1141                                       ShCompileOptions compileOptions,
1142                                       PerformanceDiagnostics *perfDiagnostics)
1143 {
1144     if (!root)
1145     {
1146         return false;
1147     }
1148 
1149     // TODO: refactor the code in TranslatorMetalDirect to not issue raw function calls.
1150     // http://anglebug.com/6059#c2
1151     mValidateASTOptions.validateNoRawFunctionCalls = false;
1152     // A validation error is generated in this backend due to bool uniforms.
1153     mValidateASTOptions.validatePrecision = false;
1154 
1155     TInfoSinkBase &sink = getInfoSink().obj;
1156     SpecConst specConst(&getSymbolTable(), compileOptions, getShaderType());
1157     DriverUniformMetal driverUniforms(DriverUniformMode::Structure);
1158     if (!translateImpl(sink, root, compileOptions, perfDiagnostics, &specConst, &driverUniforms))
1159     {
1160         return false;
1161     }
1162 
1163     return true;
1164 }
shouldFlattenPragmaStdglInvariantAll()1165 bool TranslatorMetalDirect::shouldFlattenPragmaStdglInvariantAll()
1166 {
1167     // Not neccesary for MSL transformation.
1168     return false;
1169 }
1170 
1171 }  // namespace sh
1172