• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/TranslatorMetalDirect.h"
8 
9 #include "angle_gl.h"
10 #include "common/utilities.h"
11 #include "compiler/translator/BuiltinsWorkaroundGLSL.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/OutputVulkanGLSL.h"
14 #include "compiler/translator/StaticType.h"
15 #include "compiler/translator/TranslatorMetalDirect/AddExplicitTypeCasts.h"
16 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
17 #include "compiler/translator/TranslatorMetalDirect/EmitMetal.h"
18 #include "compiler/translator/TranslatorMetalDirect/FixTypeConstructors.h"
19 #include "compiler/translator/TranslatorMetalDirect/HoistConstants.h"
20 #include "compiler/translator/TranslatorMetalDirect/IntroduceVertexIndexID.h"
21 #include "compiler/translator/TranslatorMetalDirect/Name.h"
22 #include "compiler/translator/TranslatorMetalDirect/NameEmbeddedUniformStructsMetal.h"
23 #include "compiler/translator/TranslatorMetalDirect/ReduceInterfaceBlocks.h"
24 #include "compiler/translator/TranslatorMetalDirect/RewriteCaseDeclarations.h"
25 #include "compiler/translator/TranslatorMetalDirect/RewriteGlobalQualifierDecls.h"
26 #include "compiler/translator/TranslatorMetalDirect/RewriteKeywords.h"
27 #include "compiler/translator/TranslatorMetalDirect/RewriteOutArgs.h"
28 #include "compiler/translator/TranslatorMetalDirect/RewritePipelines.h"
29 #include "compiler/translator/TranslatorMetalDirect/RewriteUnaddressableReferences.h"
30 #include "compiler/translator/TranslatorMetalDirect/SeparateCompoundExpressions.h"
31 #include "compiler/translator/TranslatorMetalDirect/SeparateCompoundStructDeclarations.h"
32 #include "compiler/translator/TranslatorMetalDirect/SymbolEnv.h"
33 #include "compiler/translator/TranslatorMetalDirect/ToposortStructs.h"
34 #include "compiler/translator/TranslatorMetalDirect/TranslatorMetalUtils.h"
35 #include "compiler/translator/TranslatorMetalDirect/WrapMain.h"
36 #include "compiler/translator/tree_ops/InitializeVariables.h"
37 #include "compiler/translator/tree_ops/NameNamelessUniformBuffers.h"
38 #include "compiler/translator/tree_ops/RemoveAtomicCounterBuiltins.h"
39 #include "compiler/translator/tree_ops/RemoveInactiveInterfaceVariables.h"
40 #include "compiler/translator/tree_ops/RewriteAtomicCounters.h"
41 #include "compiler/translator/tree_ops/RewriteCubeMapSamplersAs2DArray.h"
42 #include "compiler/translator/tree_ops/RewriteDfdy.h"
43 #include "compiler/translator/tree_ops/RewriteRowMajorMatrices.h"
44 #include "compiler/translator/tree_ops/RewriteStructSamplers.h"
45 #include "compiler/translator/tree_util/BuiltIn.h"
46 #include "compiler/translator/tree_util/DriverUniform.h"
47 #include "compiler/translator/tree_util/FindFunction.h"
48 #include "compiler/translator/tree_util/FindMain.h"
49 #include "compiler/translator/tree_util/FindSymbolNode.h"
50 #include "compiler/translator/tree_util/IntermNode_util.h"
51 #include "compiler/translator/tree_util/ReplaceClipCullDistanceVariable.h"
52 #include "compiler/translator/tree_util/ReplaceVariable.h"
53 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
54 #include "compiler/translator/tree_util/SpecializationConstant.h"
55 #include "compiler/translator/util.h"
56 
57 namespace sh
58 {
59 
60 namespace
61 {
62 
63 constexpr Name kSampleMaskWriteFuncName("writeSampleMask", SymbolType::AngleInternal);
64 constexpr Name kFlippedPointCoordName("flippedPointCoord", SymbolType::UserDefined);
65 constexpr Name kFlippedFragCoordName("flippedFragCoord", SymbolType::UserDefined);
66 
67 constexpr const TVariable kgl_VertexIDMetal(BuiltInId::gl_VertexID,
68                                             ImmutableString("gl_VertexID"),
69                                             SymbolType::BuiltIn,
70                                             TExtension::UNDEFINED,
71                                             StaticType::Get<EbtUInt, EbpHigh, EvqVertexID, 1, 1>());
72 
73 class DeclareStructTypesTraverser : public TIntermTraverser
74 {
75   public:
DeclareStructTypesTraverser(TOutputMSL * outputMSL)76     explicit DeclareStructTypesTraverser(TOutputMSL *outputMSL)
77         : TIntermTraverser(true, false, false), mOutputMSL(outputMSL)
78     {}
79 
visitDeclaration(Visit visit,TIntermDeclaration * node)80     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
81     {
82         ASSERT(visit == PreVisit);
83         if (!mInGlobalScope)
84         {
85             return false;
86         }
87 
88         const TIntermSequence &sequence = *(node->getSequence());
89         TIntermTyped *declarator        = sequence.front()->getAsTyped();
90         const TType &type               = declarator->getType();
91 
92         if (type.isStructSpecifier())
93         {
94             const TStructure *structure = type.getStruct();
95 
96             // Embedded structs should be parsed away by now.
97             ASSERT(structure->symbolType() != SymbolType::Empty);
98             // outputMSL->writeStructType(structure);
99 
100             TIntermSymbol *symbolNode = declarator->getAsSymbolNode();
101             if (symbolNode && symbolNode->variable().symbolType() == SymbolType::Empty)
102             {
103                 // Remove the struct specifier declaration from the tree so it isn't parsed again.
104                 TIntermSequence emptyReplacement;
105                 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
106                                                 std::move(emptyReplacement));
107             }
108         }
109         // TODO: REMOVE, used to remove 'unsued' warning
110         mOutputMSL = nullptr;
111 
112         return false;
113     }
114 
115   private:
116     TOutputMSL *mOutputMSL;
117 };
118 
119 class DeclareDefaultUniformsTraverser : public TIntermTraverser
120 {
121   public:
DeclareDefaultUniformsTraverser(TInfoSinkBase * sink,ShHashFunction64 hashFunction,NameMap * nameMap)122     DeclareDefaultUniformsTraverser(TInfoSinkBase *sink,
123                                     ShHashFunction64 hashFunction,
124                                     NameMap *nameMap)
125         : TIntermTraverser(true, true, true),
126           mSink(sink),
127           mHashFunction(hashFunction),
128           mNameMap(nameMap),
129           mInDefaultUniform(false)
130     {}
131 
visitDeclaration(Visit visit,TIntermDeclaration * node)132     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
133     {
134         const TIntermSequence &sequence = *(node->getSequence());
135 
136         // TODO(jmadill): Compound declarations.
137         ASSERT(sequence.size() == 1);
138 
139         TIntermTyped *variable = sequence.front()->getAsTyped();
140         const TType &type      = variable->getType();
141         bool isUniform         = type.getQualifier() == EvqUniform && !type.isInterfaceBlock() &&
142                          !IsOpaqueType(type.getBasicType());
143 
144         if (visit == PreVisit)
145         {
146             if (isUniform)
147             {
148                 (*mSink) << "    " << GetTypeName(type, mHashFunction, mNameMap) << " ";
149                 mInDefaultUniform = true;
150             }
151         }
152         else if (visit == InVisit)
153         {
154             mInDefaultUniform = isUniform;
155         }
156         else if (visit == PostVisit)
157         {
158             if (isUniform)
159             {
160                 (*mSink) << ";\n";
161 
162                 // Remove the uniform declaration from the tree so it isn't parsed again.
163                 TIntermSequence emptyReplacement;
164                 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
165                                                 std::move(emptyReplacement));
166             }
167 
168             mInDefaultUniform = false;
169         }
170         return true;
171     }
172 
visitSymbol(TIntermSymbol * symbol)173     void visitSymbol(TIntermSymbol *symbol) override
174     {
175         if (mInDefaultUniform)
176         {
177             const ImmutableString &name = symbol->variable().name();
178             ASSERT(!name.beginsWith("gl_"));
179             (*mSink) << HashName(&symbol->variable(), mHashFunction, mNameMap)
180                      << ArrayString(symbol->getType());
181         }
182     }
183 
184   private:
185     TInfoSinkBase *mSink;
186     ShHashFunction64 mHashFunction;
187     NameMap *mNameMap;
188     bool mInDefaultUniform;
189 };
190 
191 // Declares a new variable to replace gl_DepthRange, its values are fed from a driver uniform.
ReplaceGLDepthRangeWithDriverUniform(TCompiler * compiler,TIntermBlock * root,const DriverUniform * driverUniforms,TSymbolTable * symbolTable)192 ANGLE_NO_DISCARD bool ReplaceGLDepthRangeWithDriverUniform(TCompiler *compiler,
193                                                            TIntermBlock *root,
194                                                            const DriverUniform *driverUniforms,
195                                                            TSymbolTable *symbolTable)
196 {
197     // Create a symbol reference to "gl_DepthRange"
198     const TVariable *depthRangeVar = static_cast<const TVariable *>(
199         symbolTable->findBuiltIn(ImmutableString("gl_DepthRange"), 0));
200 
201     // ANGLEUniforms.depthRange
202     TIntermBinary *angleEmulatedDepthRangeRef = driverUniforms->getDepthRangeRef();
203 
204     // Use this variable instead of gl_DepthRange everywhere.
205     return ReplaceVariableWithTyped(compiler, root, depthRangeVar, angleEmulatedDepthRangeRef);
206 }
207 
GetMainSequence(TIntermBlock * root)208 TIntermSequence *GetMainSequence(TIntermBlock *root)
209 {
210     TIntermFunctionDefinition *main = FindMain(root);
211     return main->getBody()->getSequence();
212 }
213 
214 // This operation performs Android pre-rotation and y-flip.  For Android (and potentially other
215 // platforms), the device may rotate, such that the orientation of the application is rotated
216 // relative to the native orientation of the device.  This is corrected in part by multiplying
217 // gl_Position by a mat2.
218 // The equations reduce to an expression:
219 //
220 //     gl_Position.xy = gl_Position.xy * preRotation
AppendPreRotation(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,SpecConst * specConst,const DriverUniform * driverUniforms)221 ANGLE_NO_DISCARD bool AppendPreRotation(TCompiler *compiler,
222                                         TIntermBlock *root,
223                                         TSymbolTable *symbolTable,
224                                         SpecConst *specConst,
225                                         const DriverUniform *driverUniforms)
226 {
227     TIntermTyped *preRotationRef = specConst->getPreRotationMatrix();
228     if (!preRotationRef)
229     {
230         preRotationRef = driverUniforms->getPreRotationMatrixRef();
231     }
232     TIntermSymbol *glPos         = new TIntermSymbol(BuiltInVariable::gl_Position());
233     TVector<int> swizzleOffsetXY = {0, 1};
234     TIntermSwizzle *glPosXY      = new TIntermSwizzle(glPos, swizzleOffsetXY);
235 
236     // Create the expression "(gl_Position.xy * preRotation)"
237     TIntermBinary *zRotated = new TIntermBinary(EOpMatrixTimesVector, preRotationRef, glPosXY);
238 
239     // Create the assignment "gl_Position.xy = (gl_Position.xy * preRotation)"
240     TIntermBinary *assignment =
241         new TIntermBinary(TOperator::EOpAssign, glPosXY->deepCopy(), zRotated);
242 
243     // Append the assignment as a statement at the end of the shader.
244     return RunAtTheEndOfShader(compiler, root, assignment, symbolTable);
245 }
246 
247 // Replaces a builtin variable with a version that is rotated and corrects the X and Y coordinates.
RotateAndFlipBuiltinVariable(TCompiler * compiler,TIntermBlock * root,TIntermSequence * insertSequence,TIntermTyped * flipXY,TSymbolTable * symbolTable,const TVariable * builtin,const Name & flippedVariableName,TIntermTyped * pivot,TIntermTyped * fragRotation)248 ANGLE_NO_DISCARD bool RotateAndFlipBuiltinVariable(TCompiler *compiler,
249                                                    TIntermBlock *root,
250                                                    TIntermSequence *insertSequence,
251                                                    TIntermTyped *flipXY,
252                                                    TSymbolTable *symbolTable,
253                                                    const TVariable *builtin,
254                                                    const Name &flippedVariableName,
255                                                    TIntermTyped *pivot,
256                                                    TIntermTyped *fragRotation)
257 {
258     // Create a symbol reference to 'builtin'.
259     TIntermSymbol *builtinRef = new TIntermSymbol(builtin);
260 
261     // Create a swizzle to "builtin.xy"
262     TVector<int> swizzleOffsetXY = {0, 1};
263     TIntermSwizzle *builtinXY    = new TIntermSwizzle(builtinRef, swizzleOffsetXY);
264 
265     // Create a symbol reference to our new variable that will hold the modified builtin.
266     const TType *type = StaticType::GetForVec<EbtFloat>(
267         EvqGlobal, static_cast<unsigned char>(builtin->getType().getNominalSize()));
268     TVariable *replacementVar =
269         new TVariable(symbolTable, flippedVariableName.rawName(), type, SymbolType::AngleInternal);
270     DeclareGlobalVariable(root, replacementVar);
271     TIntermSymbol *flippedBuiltinRef = new TIntermSymbol(replacementVar);
272 
273     // Use this new variable instead of 'builtin' everywhere.
274     if (!ReplaceVariable(compiler, root, builtin, replacementVar))
275     {
276         return false;
277     }
278 
279     // Create the expression "(builtin.xy * fragRotation)"
280     TIntermTyped *rotatedXY;
281     if (fragRotation)
282     {
283         rotatedXY = new TIntermBinary(EOpMatrixTimesVector, fragRotation, builtinXY);
284     }
285     else
286     {
287         // No rotation applied, use original variable.
288         rotatedXY = builtinXY;
289     }
290 
291     // Create the expression "(builtin.xy - pivot) * flipXY + pivot
292     TIntermBinary *removePivot = new TIntermBinary(EOpSub, rotatedXY, pivot);
293     TIntermBinary *inverseXY   = new TIntermBinary(EOpMul, removePivot, flipXY);
294     TIntermBinary *plusPivot   = new TIntermBinary(EOpAdd, inverseXY, pivot->deepCopy());
295 
296     // Create the corrected variable and copy the value of the original builtin.
297     TIntermSequence sequence;
298     sequence.push_back(builtinRef->deepCopy());
299     TIntermAggregate *aggregate =
300         TIntermAggregate::CreateConstructor(builtin->getType(), &sequence);
301     TIntermBinary *assignment = new TIntermBinary(EOpInitialize, flippedBuiltinRef, aggregate);
302 
303     // Create an assignment to the replaced variable's .xy.
304     TIntermSwizzle *correctedXY =
305         new TIntermSwizzle(flippedBuiltinRef->deepCopy(), swizzleOffsetXY);
306     TIntermBinary *assignToY = new TIntermBinary(EOpAssign, correctedXY, plusPivot);
307 
308     // Add this assigment at the beginning of the main function
309     insertSequence->insert(insertSequence->begin(), assignToY);
310     insertSequence->insert(insertSequence->begin(), assignment);
311 
312     return compiler->validateAST(root);
313 }
314 
InsertFragCoordCorrection(TCompiler * compiler,ShCompileOptions compileOptions,TIntermBlock * root,TIntermSequence * insertSequence,TSymbolTable * symbolTable,SpecConst * specConst,const DriverUniform * driverUniforms)315 ANGLE_NO_DISCARD bool InsertFragCoordCorrection(TCompiler *compiler,
316                                                 ShCompileOptions compileOptions,
317                                                 TIntermBlock *root,
318                                                 TIntermSequence *insertSequence,
319                                                 TSymbolTable *symbolTable,
320                                                 SpecConst *specConst,
321                                                 const DriverUniform *driverUniforms)
322 {
323     TIntermTyped *flipXY = specConst->getFlipXY();
324     if (!flipXY)
325     {
326         flipXY = driverUniforms->getFlipXYRef();
327     }
328 
329     TIntermBinary *pivot = specConst->getHalfRenderArea();
330     if (!pivot)
331     {
332         pivot = driverUniforms->getHalfRenderAreaRef();
333     }
334 
335     TIntermTyped *fragRotation = nullptr;
336     if ((compileOptions & SH_ADD_PRE_ROTATION) != 0)
337     {
338         fragRotation = specConst->getFragRotationMatrix();
339         if (!fragRotation)
340         {
341             fragRotation = driverUniforms->getFragRotationMatrixRef();
342         }
343     }
344     return RotateAndFlipBuiltinVariable(compiler, root, insertSequence, flipXY, symbolTable,
345                                         BuiltInVariable::gl_FragCoord(), kFlippedFragCoordName,
346                                         pivot, fragRotation);
347 }
348 
DeclareRightBeforeMain(TIntermBlock & root,const TVariable & var)349 void DeclareRightBeforeMain(TIntermBlock &root, const TVariable &var)
350 {
351     root.insertChildNodes(FindMainIndex(&root), {new TIntermDeclaration{&var}});
352 }
353 
AddFragColorDeclaration(TIntermBlock & root,TSymbolTable & symbolTable)354 void AddFragColorDeclaration(TIntermBlock &root, TSymbolTable &symbolTable)
355 {
356     root.insertChildNodes(FindMainIndex(&root),
357                           TIntermSequence{new TIntermDeclaration{BuiltInVariable::gl_FragColor()}});
358 }
359 
AddFragDepthDeclaration(TIntermBlock & root,TSymbolTable & symbolTable)360 void AddFragDepthDeclaration(TIntermBlock &root, TSymbolTable &symbolTable)
361 {
362     root.insertChildNodes(FindMainIndex(&root),
363                           TIntermSequence{new TIntermDeclaration{BuiltInVariable::gl_FragDepth()}});
364 }
365 
AddFragDepthEXTDeclaration(TCompiler & compiler,TIntermBlock & root,TSymbolTable & symbolTable)366 void AddFragDepthEXTDeclaration(TCompiler &compiler, TIntermBlock &root, TSymbolTable &symbolTable)
367 {
368     const TIntermSymbol *glFragDepthExt = FindSymbolNode(&root, ImmutableString("gl_FragDepthEXT"));
369     ASSERT(glFragDepthExt);
370     // Replace gl_FragData with our globally defined fragdata.
371     if (!ReplaceVariable(&compiler, &root, &(glFragDepthExt->variable()),
372                          BuiltInVariable::gl_FragDepth()))
373     {
374         return;
375     }
376     AddFragDepthDeclaration(root, symbolTable);
377 }
AddSampleMaskDeclaration(TIntermBlock & root,TSymbolTable & symbolTable)378 void AddSampleMaskDeclaration(TIntermBlock &root, TSymbolTable &symbolTable)
379 {
380     TType *gl_SampleMaskType = new TType(EbtUInt, EbpHigh, EvqSampleMask, 1, 1);
381     const TVariable *gl_SampleMask =
382         new TVariable(&symbolTable, ImmutableString("gl_SampleMask"), gl_SampleMaskType,
383                       SymbolType::BuiltIn, TExtension::UNDEFINED);
384     root.insertChildNodes(FindMainIndex(&root),
385                           TIntermSequence{new TIntermDeclaration{gl_SampleMask}});
386 }
387 
AddFragDataDeclaration(TCompiler & compiler,TIntermBlock & root)388 ANGLE_NO_DISCARD bool AddFragDataDeclaration(TCompiler &compiler, TIntermBlock &root)
389 {
390     TSymbolTable &symbolTable = compiler.getSymbolTable();
391     const int maxDrawBuffers  = compiler.getResources().MaxDrawBuffers;
392     TType *gl_FragDataType    = new TType(EbtFloat, EbpMedium, EvqFragData, 4, 1);
393     std::vector<const TVariable *> glFragDataSlots;
394     TIntermSequence declareGLFragdataSequence;
395 
396     // Create gl_FragData_0,1,2,3
397     for (int i = 0; i < maxDrawBuffers; i++)
398     {
399         ImmutableStringBuilder builder(strlen("gl_FragData_") + 2);
400         builder << "gl_FragData_";
401         builder.appendDecimal(i);
402         const TVariable *glFragData =
403             new TVariable(&symbolTable, builder, gl_FragDataType, SymbolType::AngleInternal,
404                           TExtension::UNDEFINED);
405         glFragDataSlots.push_back(glFragData);
406         declareGLFragdataSequence.push_back(new TIntermDeclaration{glFragData});
407     }
408     root.insertChildNodes(FindMainIndex(&root), declareGLFragdataSequence);
409 
410     // Create an internal gl_FragData array type, compatible with indexing syntax.
411     TType *gl_FragDataTypeArray = new TType(EbtFloat, EbpMedium, EvqGlobal, 4, 1);
412     gl_FragDataTypeArray->makeArray(maxDrawBuffers);
413     const TVariable *glFragDataGlobal = new TVariable(&symbolTable, ImmutableString("gl_FragData"),
414                                                       gl_FragDataTypeArray, SymbolType::BuiltIn);
415 
416     DeclareGlobalVariable(&root, glFragDataGlobal);
417     const TIntermSymbol *originalGLFragData = FindSymbolNode(&root, ImmutableString("gl_FragData"));
418     ASSERT(originalGLFragData);
419 
420     // Replace gl_FragData() with our globally defined fragdata
421     if (!ReplaceVariable(&compiler, &root, &(originalGLFragData->variable()), glFragDataGlobal))
422     {
423         return false;
424     }
425 
426     // Assign each array attribute to an output
427     TIntermBlock *insertSequence = new TIntermBlock();
428     for (int i = 0; i < maxDrawBuffers; i++)
429     {
430         TIntermTyped *glFragDataSlot         = new TIntermSymbol(glFragDataSlots[i]);
431         TIntermTyped *glFragDataGlobalSymbol = new TIntermSymbol(glFragDataGlobal);
432         auto &access                         = AccessIndex(*glFragDataGlobalSymbol, i);
433         TIntermBinary *assignment =
434             new TIntermBinary(TOperator::EOpAssign, glFragDataSlot, &access);
435         insertSequence->appendStatement(assignment);
436     }
437     return RunAtTheEndOfShader(&compiler, &root, insertSequence, &symbolTable);
438 }
439 
EmulateInstanceID(TCompiler & compiler,TIntermBlock & root,DriverUniform & driverUniforms)440 ANGLE_NO_DISCARD bool EmulateInstanceID(TCompiler &compiler,
441                                         TIntermBlock &root,
442                                         DriverUniform &driverUniforms)
443 {
444     TIntermBinary *emuInstanceID = driverUniforms.getEmulatedInstanceId();
445     const TVariable *instanceID  = BuiltInVariable::gl_InstanceIndex();
446     return ReplaceVariableWithTyped(&compiler, &root, instanceID, emuInstanceID);
447 }
448 
449 // Unlike Vulkan having auto viewport flipping extension, in Metal we have to flip gl_Position.y
450 // manually.
451 // This operation performs flipping the gl_Position.y using this expression:
452 // gl_Position.y = gl_Position.y * negViewportScaleY
AppendVertexShaderPositionYCorrectionToMain(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,TIntermTyped * negFlipY)453 ANGLE_NO_DISCARD bool AppendVertexShaderPositionYCorrectionToMain(TCompiler *compiler,
454                                                                   TIntermBlock *root,
455                                                                   TSymbolTable *symbolTable,
456                                                                   TIntermTyped *negFlipY)
457 {
458     // Create a symbol reference to "gl_Position"
459     const TVariable *position  = BuiltInVariable::gl_Position();
460     TIntermSymbol *positionRef = new TIntermSymbol(position);
461 
462     // Create a swizzle to "gl_Position.y"
463     TVector<int> swizzleOffsetY;
464     swizzleOffsetY.push_back(1);
465     TIntermSwizzle *positionY = new TIntermSwizzle(positionRef, swizzleOffsetY);
466 
467     // Create the expression "gl_Position.y * negFlipY"
468     TIntermBinary *inverseY = new TIntermBinary(EOpMul, positionY->deepCopy(), negFlipY);
469 
470     // Create the assignment "gl_Position.y = gl_Position.y * negViewportScaleY
471     TIntermTyped *positionYLHS = positionY->deepCopy();
472     TIntermBinary *assignment  = new TIntermBinary(TOperator::EOpAssign, positionYLHS, inverseY);
473 
474     // Append the assignment as a statement at the end of the shader.
475     return RunAtTheEndOfShader(compiler, root, assignment, symbolTable);
476 }
477 }  // namespace
478 
TranslatorMetalDirect(sh::GLenum type,ShShaderSpec spec,ShShaderOutput output)479 TranslatorMetalDirect::TranslatorMetalDirect(sh::GLenum type,
480                                              ShShaderSpec spec,
481                                              ShShaderOutput output)
482     : TCompiler(type, spec, output)
483 {}
484 
485 // Add sample_mask writing to main, guarded by the function constant
486 // kCoverageMaskEnabledName
insertSampleMaskWritingLogic(TIntermBlock & root,DriverUniform & driverUniforms)487 ANGLE_NO_DISCARD bool TranslatorMetalDirect::insertSampleMaskWritingLogic(
488     TIntermBlock &root,
489     DriverUniform &driverUniforms)
490 {
491     // This transformation leaves the tree in an inconsistent state by using a variable that's
492     // defined in text, outside of the knowledge of the AST.
493     mValidateASTOptions.validateVariableReferences = false;
494 
495     TSymbolTable *symbolTable = &getSymbolTable();
496 
497     // Create kCoverageMaskEnabled and kSampleMaskWriteFuncName variable references.
498     TType *boolType = new TType(EbtBool);
499     boolType->setQualifier(EvqConst);
500     TVariable *coverageMaskEnabledVar =
501         new TVariable(symbolTable, sh::ImmutableString(sh::mtl::kCoverageMaskEnabledConstName),
502                       boolType, SymbolType::AngleInternal);
503 
504     TFunction *sampleMaskWriteFunc = new TFunction(symbolTable, kSampleMaskWriteFuncName.rawName(),
505                                                    kSampleMaskWriteFuncName.symbolType(),
506                                                    StaticType::GetBasic<EbtVoid>(), false);
507 
508     TType *uintType = new TType(EbtUInt);
509     TVariable *maskArg =
510         new TVariable(symbolTable, ImmutableString("mask"), uintType, SymbolType::AngleInternal);
511     sampleMaskWriteFunc->addParameter(maskArg);
512 
513     TVariable *gl_SampleMaskArg = new TVariable(symbolTable, ImmutableString("gl_SampleMask"),
514                                                 uintType, SymbolType::AngleInternal);
515     sampleMaskWriteFunc->addParameter(gl_SampleMaskArg);
516 
517     // Insert this MSL code to the end of main() in the shader
518     // if (ANGLECoverageMaskEnabled)
519     // {
520     //      ANGLE_writeSampleMask(ANGLE_angleUniforms.coverageMask,
521     //      ANGLE_fragmentOut.gl_SampleMask);
522     // }
523     TIntermSequence *args       = new TIntermSequence;
524     TIntermBinary *coverageMask = driverUniforms.getCoverageMask();
525     args->push_back(coverageMask);
526     const TIntermSymbol *gl_SampleMask = FindSymbolNode(&root, ImmutableString("gl_SampleMask"));
527     args->push_back(gl_SampleMask->deepCopy());
528 
529     TIntermAggregate *callSampleMaskWriteFunc =
530         TIntermAggregate::CreateFunctionCall(*sampleMaskWriteFunc, args);
531     TIntermBlock *callBlock = new TIntermBlock;
532     callBlock->appendStatement(callSampleMaskWriteFunc);
533 
534     TIntermSymbol *coverageMaskEnabled = new TIntermSymbol(coverageMaskEnabledVar);
535     TIntermIfElse *ifCall              = new TIntermIfElse(coverageMaskEnabled, callBlock, nullptr);
536     return RunAtTheEndOfShader(this, &root, ifCall, symbolTable);
537 }
538 
insertRasterizationDiscardLogic(TIntermBlock & root)539 ANGLE_NO_DISCARD bool TranslatorMetalDirect::insertRasterizationDiscardLogic(TIntermBlock &root)
540 {
541     // This transformation leaves the tree in an inconsistent state by using a variable that's
542     // defined in text, outside of the knowledge of the AST.
543     mValidateASTOptions.validateVariableReferences = false;
544 
545     TSymbolTable *symbolTable = &getSymbolTable();
546 
547     TType *boolType = new TType(EbtBool);
548     boolType->setQualifier(EvqConst);
549     TVariable *discardEnabledVar =
550         new TVariable(symbolTable, sh::ImmutableString(sh::mtl::kRasterizerDiscardEnabledConstName),
551                       boolType, SymbolType::AngleInternal);
552 
553     const TVariable *position  = BuiltInVariable::gl_Position();
554     TIntermSymbol *positionRef = new TIntermSymbol(position);
555 
556     // Create vec4(-3, -3, -3, 1):
557     auto vec4Type             = new TType(EbtFloat, 4);
558     TIntermSequence *vec4Args = new TIntermSequence();
559     vec4Args->push_back(CreateFloatNode(-3.0f));
560     vec4Args->push_back(CreateFloatNode(-3.0f));
561     vec4Args->push_back(CreateFloatNode(-3.0f));
562     vec4Args->push_back(CreateFloatNode(1.0f));
563     TIntermAggregate *constVarConstructor =
564         TIntermAggregate::CreateConstructor(*vec4Type, vec4Args);
565 
566     // Create the assignment "gl_Position = vec4(-3, -3, -3, 1)"
567     TIntermBinary *assignment =
568         new TIntermBinary(TOperator::EOpAssign, positionRef->deepCopy(), constVarConstructor);
569 
570     TIntermBlock *discardBlock = new TIntermBlock;
571     discardBlock->appendStatement(assignment);
572 
573     TIntermSymbol *discardEnabled = new TIntermSymbol(discardEnabledVar);
574     TIntermIfElse *ifCall         = new TIntermIfElse(discardEnabled, discardBlock, nullptr);
575 
576     return RunAtTheEndOfShader(this, &root, ifCall, symbolTable);
577 }
578 
579 // Metal needs to inverse the depth if depthRange is is reverse order, i.e. depth near > depth far
580 // This is achieved by multiply the depth value with scale value stored in
581 // driver uniform's depthRange.reserved
transformDepthBeforeCorrection(TIntermBlock * root,const DriverUniform * driverUniforms)582 bool TranslatorMetalDirect::transformDepthBeforeCorrection(TIntermBlock *root,
583                                                            const DriverUniform *driverUniforms)
584 {
585     // Create a symbol reference to "gl_Position"
586     const TVariable *position  = BuiltInVariable::gl_Position();
587     TIntermSymbol *positionRef = new TIntermSymbol(position);
588 
589     // Create a swizzle to "gl_Position.z"
590     TVector<int> swizzleOffsetZ = {2};
591     TIntermSwizzle *positionZ   = new TIntermSwizzle(positionRef, swizzleOffsetZ);
592 
593     // Create a ref to "depthRange.reserved"
594     TIntermBinary *viewportZScale = driverUniforms->getDepthRangeReservedFieldRef();
595 
596     // Create the expression "gl_Position.z * depthRange.reserved".
597     TIntermBinary *zScale = new TIntermBinary(EOpMul, positionZ->deepCopy(), viewportZScale);
598 
599     // Create the assignment "gl_Position.z = gl_Position.z * depthRange.reserved"
600     TIntermTyped *positionZLHS = positionZ->deepCopy();
601     TIntermBinary *assignment  = new TIntermBinary(TOperator::EOpAssign, positionZLHS, zScale);
602 
603     // Append the assignment as a statement at the end of the shader.
604     return RunAtTheEndOfShader(this, root, assignment, &getSymbolTable());
605 }
606 
GetMslKeywords()607 static std::set<ImmutableString> GetMslKeywords()
608 {
609     std::set<ImmutableString> keywords;
610 
611     keywords.emplace("alignas");
612     keywords.emplace("alignof");
613     keywords.emplace("as_type");
614     keywords.emplace("auto");
615     keywords.emplace("catch");
616     keywords.emplace("char");
617     keywords.emplace("class");
618     keywords.emplace("const_cast");
619     keywords.emplace("constant");
620     keywords.emplace("constexpr");
621     keywords.emplace("decltype");
622     keywords.emplace("delete");
623     keywords.emplace("device");
624     keywords.emplace("dynamic_cast");
625     keywords.emplace("enum");
626     keywords.emplace("explicit");
627     keywords.emplace("export");
628     keywords.emplace("extern");
629     keywords.emplace("fragment");
630     keywords.emplace("friend");
631     keywords.emplace("goto");
632     keywords.emplace("half");
633     keywords.emplace("inline");
634     keywords.emplace("int16_t");
635     keywords.emplace("int32_t");
636     keywords.emplace("int64_t");
637     keywords.emplace("int8_t");
638     keywords.emplace("kernel");
639     keywords.emplace("long");
640     keywords.emplace("main0");
641     keywords.emplace("metal");
642     keywords.emplace("mutable");
643     keywords.emplace("namespace");
644     keywords.emplace("new");
645     keywords.emplace("noexcept");
646     keywords.emplace("nullptr_t");
647     keywords.emplace("nullptr");
648     keywords.emplace("operator");
649     keywords.emplace("override");
650     keywords.emplace("private");
651     keywords.emplace("protected");
652     keywords.emplace("ptrdiff_t");
653     keywords.emplace("public");
654     keywords.emplace("ray_data");
655     keywords.emplace("register");
656     keywords.emplace("short");
657     keywords.emplace("signed");
658     keywords.emplace("size_t");
659     keywords.emplace("sizeof");
660     keywords.emplace("stage_in");
661     keywords.emplace("static_assert");
662     keywords.emplace("static_cast");
663     keywords.emplace("static");
664     keywords.emplace("template");
665     keywords.emplace("this");
666     keywords.emplace("thread_local");
667     keywords.emplace("thread");
668     keywords.emplace("threadgroup_imageblock");
669     keywords.emplace("threadgroup");
670     keywords.emplace("throw");
671     keywords.emplace("try");
672     keywords.emplace("typedef");
673     keywords.emplace("typeid");
674     keywords.emplace("typename");
675     keywords.emplace("uchar");
676     keywords.emplace("uint16_t");
677     keywords.emplace("uint32_t");
678     keywords.emplace("uint64_t");
679     keywords.emplace("uint8_t");
680     keywords.emplace("union");
681     keywords.emplace("unsigned");
682     keywords.emplace("ushort");
683     keywords.emplace("using");
684     keywords.emplace("vertex");
685     keywords.emplace("virtual");
686     keywords.emplace("volatile");
687     keywords.emplace("wchar_t");
688 
689     return keywords;
690 }
691 
metalShaderTypeFromGLSL(sh::GLenum shaderType)692 static inline MetalShaderType metalShaderTypeFromGLSL(sh::GLenum shaderType)
693 {
694     switch (shaderType)
695     {
696         case GL_VERTEX_SHADER:
697             return MetalShaderType::Vertex;
698         case GL_FRAGMENT_SHADER:
699             return MetalShaderType::Fragment;
700         case GL_COMPUTE_SHADER:
701             ASSERT(0 && "compute shaders not currently supported");
702             return MetalShaderType::Compute;
703         default:
704             ASSERT(0 && "Invalid shader type.");
705             return MetalShaderType::None;
706     }
707 }
708 
// Runs the Metal-specific translation pipeline over |root|.
//
// Order of work (every step returns false immediately on failure):
//   1. WrapMain, then RemoveInactiveInterfaceVariables.
//   2. Uniform bookkeeping. When any uniform is a struct or an array of arrays, embedded struct
//      uniforms are named and samplers are pulled out of structs (RewriteStructSamplers).
//      Cube-map samplers are optionally rewritten as 2D arrays.
//   3. Driver uniforms are added to the shader (compute or graphics variant).
//      Atomic counters are rewritten, or atomic-counter builtins are removed for ESSL >= 310.
//      gl_DepthRange is replaced by a driver uniform for non-compute stages.
//   4. Builtin handling: gl_InstanceID / gl_VertexID, then per-stage work for fragment, vertex,
//      geometry and compute shaders.
//   5. Metal-specific rewrites (keywords, interface blocks, constant hoisting, global qualifiers,
//      explicit casts, pipelines, compound expressions, case declarations, unaddressable
//      references, out-args, type constructors, struct ordering), then EmitMetal().
//
// sink:            written here only through the per-stage helpers that take it
//                  (EmitEarlyFragmentTestsGLSL, WriteGeometryShaderLayoutQualifiers,
//                  EmitWorkGroupSizeGLSL).
// root:            the shader AST, rewritten in place.
// compileOptions:  tested directly for SH_EMULATE_SEAMFUL_CUBE_MAP_SAMPLING and
//                  SH_ADD_PRE_ROTATION, and also forwarded to some passes.
// perfDiagnostics: unused.
// specConst:       source of specialization-constant expressions; the corresponding
//                  driver-uniform reference is used when it returns null.
// driverUniforms:  driver uniform block, added to the shader here and referenced by later passes.
translateImpl(TInfoSinkBase & sink,TIntermBlock * root,ShCompileOptions compileOptions,PerformanceDiagnostics *,SpecConst * specConst,DriverUniform * driverUniforms)709 bool TranslatorMetalDirect::translateImpl(TInfoSinkBase &sink,
710                                           TIntermBlock *root,
711                                           ShCompileOptions compileOptions,
712                                           PerformanceDiagnostics * /*perfDiagnostics*/,
713                                           SpecConst *specConst,
714                                           DriverUniform *driverUniforms)
715 {
716     TSymbolTable &symbolTable = getSymbolTable();
717     IdGen idGen;
    // ppc is passed to ToposortStructs and EmitMetal at the end of this function.
718     ProgramPreludeConfig ppc(metalShaderTypeFromGLSL(getShaderType()));
719
720     if (!WrapMain(*this, idGen, *root))
721     {
722         return false;
723     }
724
725     // Remove declarations of inactive shader interface variables so glslang wrapper doesn't need to
726     // replace them.  Note: this is done before extracting samplers from structs, as removing such
727     // inactive samplers is not yet supported.  Note also that currently, CollectVariables marks
728     // every field of an active uniform that's of struct type as active, i.e. no extracted sampler
729     // is inactive.
730     if (!RemoveInactiveInterfaceVariables(this, root, &getSymbolTable(), getAttributes(),
731                                           getInputVaryings(), getOutputVariables(), getUniforms(),
732                                           getInterfaceBlocks()))
733     {
734         return false;
735     }
736
737     // Write out default uniforms into a uniform block assigned to a specific set/binding.
    // NOTE(review): within this function, defaultUniformCount is only computed and adjusted; it is
    // never read afterwards. The other two counters gate the struct-sampler pass and the
    // atomic-counter pass below.
738     int defaultUniformCount           = 0;
739     int aggregateTypesUsedForUniforms = 0;
740     int atomicCounterCount            = 0;
741     for (const auto &uniform : getUniforms())
742     {
743         if (!uniform.isBuiltIn() && uniform.active && !gl::IsOpaqueType(uniform.type))
744         {
745             ++defaultUniformCount;
746         }
747
748         if (uniform.isStruct() || uniform.isArrayOfArrays())
749         {
750             ++aggregateTypesUsedForUniforms;
751         }
752
753         if (uniform.active && gl::IsAtomicCounterType(uniform.type))
754         {
755             ++atomicCounterCount;
756         }
757     }
758
759     if (aggregateTypesUsedForUniforms > 0)
760     {
761         if (!NameEmbeddedStructUniformsMetal(this, root, &symbolTable))
762         {
763             return false;
764         }
765
        // NOTE(review): left uninitialized; this relies on RewriteStructSamplers writing it
        // whenever it returns true.
766         int removedUniformsCount;
767         bool rewriteStructSamplersResult =
768             RewriteStructSamplers(this, root, &symbolTable, &removedUniformsCount);
769
770         if (!rewriteStructSamplersResult)
771         {
772             return false;
773         }
774         defaultUniformCount -= removedUniformsCount;
775     }
776
777     if (compileOptions & SH_EMULATE_SEAMFUL_CUBE_MAP_SAMPLING)
778     {
779         if (!RewriteCubeMapSamplersAs2DArray(this, root, &symbolTable,
780                                              getShaderType() == GL_FRAGMENT_SHADER))
781         {
782             return false;
783         }
784     }
785
    // Add the driver uniform block. This happens before the passes below that query
    // driverUniforms for field references (getAbcBufferOffsets, getNegFlipXYRef, getNegFlipYRef,
    // getClipDistancesEnabled, ...).
786     if (getShaderType() == GL_COMPUTE_SHADER)
787     {
788         driverUniforms->addComputeDriverUniformsToShader(root, &getSymbolTable());
789     }
790     else
791     {
792         driverUniforms->addGraphicsDriverUniformsToShader(root, &getSymbolTable());
793     }
794
795     if (atomicCounterCount > 0)
796     {
797         const TIntermTyped *acbBufferOffsets = driverUniforms->getAbcBufferOffsets();
798         if (!RewriteAtomicCounters(this, root, &symbolTable, acbBufferOffsets))
799         {
800             return false;
801         }
802     }
803     else if (getShaderVersion() >= 310)
804     {
805         // Vulkan doesn't support Atomic Storage as a Storage Class, but we've seen
806         // cases where builtins are using it even with no active atomic counters.
807         // This pass simply removes those builtins in that scenario.
        // NOTE(review): the rationale above refers to Vulkan; confirm it also holds for Metal.
808         if (!RemoveAtomicCounterBuiltins(this, root))
809         {
810             return false;
811         }
812     }
813
814     if (getShaderType() != GL_COMPUTE_SHADER)
815     {
816         if (!ReplaceGLDepthRangeWithDriverUniform(this, root, driverUniforms, &getSymbolTable()))
817         {
818             return false;
819         }
820     }
821
    // Scan the active builtin attributes:
    //   - gl_InstanceID gets a declaration inserted at main's index in the root sequence.
    //   - gl_VertexID is replaced by kgl_VertexIDMetal, which is declared right before main.
822     {
823         bool usesInstanceId = false;
824         bool usesVertexId   = false;
825         for (const ShaderVariable &var : mAttributes)
826         {
827             if (var.isBuiltIn())
828             {
829                 if (var.name == "gl_InstanceID")
830                 {
831                     usesInstanceId = true;
832                 }
833                 if (var.name == "gl_VertexID")
834                 {
835                     usesVertexId = true;
836                 }
837             }
838         }
839
840         if (usesInstanceId)
841         {
842             root->insertChildNodes(
843                 FindMainIndex(root),
844                 TIntermSequence{new TIntermDeclaration{BuiltInVariable::gl_InstanceID()}});
845         }
846         if (usesVertexId)
847         {
848             if (!ReplaceVariable(this, root, BuiltInVariable::gl_VertexID(), &kgl_VertexIDMetal))
849             {
850                 return false;
851             }
852             DeclareRightBeforeMain(*root, kgl_VertexIDMetal);
853         }
854     }
    // symbolEnv is used by AddExplicitTypeCasts, RewritePipelines and the later rewrite passes,
    // and by EmitMetal.
    // NOTE(review): it is constructed before the per-stage passes below modify the AST further;
    // confirm SymbolEnv tolerates those changes.
855     SymbolEnv symbolEnv(*this, *root);
856     // Declare gl_FragColor and gl_FragData as webgl_FragColor and webgl_FragData
857     // if it's core profile shaders and they are used.
    // NOTE(review): the renaming described above is done inside AddFragColorDeclaration /
    // AddFragDataDeclaration, if at all; it is not visible here — verify before relying on it.
858     if (getShaderType() == GL_FRAGMENT_SHADER)
859     {
        // Detect which fragment-stage builtin inputs are active.
860         bool usesPointCoord  = false;
861         bool usesFragCoord   = false;
862         bool usesFrontFacing = false;
863         for (const ShaderVariable &inputVarying : mInputVaryings)
864         {
865             if (inputVarying.isBuiltIn())
866             {
867                 if (inputVarying.name == "gl_PointCoord")
868                 {
869                     usesPointCoord = true;
870                 }
871                 else if (inputVarying.name == "gl_FragCoord")
872                 {
873                     usesFragCoord = true;
874                 }
875                 else if (inputVarying.name == "gl_FrontFacing")
876                 {
877                     usesFrontFacing = true;
878                 }
879             }
880         }
881
        // Detect which fragment-stage builtin outputs are active.
882         bool usesFragColor    = false;
883         bool usesFragData     = false;
884         bool usesFragDepth    = false;
885         bool usesFragDepthEXT = false;
886         for (const ShaderVariable &outputVarying : mOutputVariables)
887         {
888             if (outputVarying.isBuiltIn())
889             {
890                 if (outputVarying.name == "gl_FragColor")
891                 {
892                     usesFragColor = true;
893                 }
894                 else if (outputVarying.name == "gl_FragData")
895                 {
896                     usesFragData = true;
897                 }
898                 else if (outputVarying.name == "gl_FragDepth")
899                 {
900                     usesFragDepth = true;
901                 }
902                 else if (outputVarying.name == "gl_FragDepthEXT")
903                 {
904                     usesFragDepthEXT = true;
905                 }
906             }
907         }
908
        // ESSL forbids a shader from statically writing both gl_FragColor and gl_FragData.
909         if (usesFragColor && usesFragData)
910         {
911             return false;
912         }
913
914         if (usesFragColor)
915         {
916             AddFragColorDeclaration(*root, symbolTable);
917         }
918
919         if (usesFragData)
920         {
921             if (!AddFragDataDeclaration(*this, *root))
922             {
923                 return false;
924             }
925         }
        // gl_FragDepth takes precedence; gl_FragDepthEXT is declared only when gl_FragDepth is not
        // used.
926         if (usesFragDepth)
927         {
928             AddFragDepthDeclaration(*root, symbolTable);
929         }
930         else if (usesFragDepthEXT)
931         {
932             AddFragDepthEXTDeclaration(*this, *root, symbolTable);
933         }
934
935         // Always add sample_mask. It will be guarded by a function constant decided at runtime.
        // NOTE(review): usesSampleMask is a constant true, so the branch below always runs.
936         bool usesSampleMask = true;
937         if (usesSampleMask)
938         {
939             AddSampleMaskDeclaration(*root, symbolTable);
940         }
941
        // usePreRotation is read only by the gl_PointCoord handling below. Other passes receive
        // compileOptions directly.
942         bool usePreRotation = compileOptions & SH_ADD_PRE_ROTATION;
943
        // Flip gl_PointCoord about a pivot of 0.5, and also rotate it when pre-rotation is enabled.
        // The result is named kFlippedPointCoordName.
944         if (usesPointCoord)
945         {
946             TIntermTyped *flipNegXY = specConst->getNegFlipXY();
947             if (!flipNegXY)
948             {
949                 flipNegXY = driverUniforms->getNegFlipXYRef();
950             }
951             TIntermConstantUnion *pivot = CreateFloatNode(0.5f);
952             TIntermTyped *fragRotation  = nullptr;
953             if (usePreRotation)
954             {
955                 fragRotation = specConst->getFragRotationMatrix();
956                 if (!fragRotation)
957                 {
958                     fragRotation = driverUniforms->getFragRotationMatrixRef();
959                 }
960             }
961             if (!RotateAndFlipBuiltinVariable(this, root, GetMainSequence(root), flipNegXY,
962                                               &getSymbolTable(), BuiltInVariable::gl_PointCoord(),
963                                               kFlippedPointCoordName, pivot, fragRotation))
964             {
965                 return false;
966             }
967             DeclareRightBeforeMain(*root, *BuiltInVariable::gl_PointCoord());
968         }
969
970         if (usesFragCoord)
971         {
972             if (!InsertFragCoordCorrection(this, compileOptions, root, GetMainSequence(root),
973                                            &getSymbolTable(), specConst, driverUniforms))
974             {
975                 return false;
976             }
977             DeclareRightBeforeMain(*root, *BuiltInVariable::gl_FragCoord());
978         }
979
980         if (!RewriteDfdy(this, compileOptions, root, getSymbolTable(), getShaderVersion(),
981                          specConst, driverUniforms))
982         {
983             return false;
984         }
985
986         if (usesFrontFacing)
987         {
988             DeclareRightBeforeMain(*root, *BuiltInVariable::gl_FrontFacing());
989         }
990
        // NOTE(review): this helper writes to sink directly, and its name suggests GLSL-style output;
        // confirm this is what the MSL backend wants.
991         EmitEarlyFragmentTestsGLSL(*this, sink);
992     }
993     else if (getShaderType() == GL_VERTEX_SHADER)
994     {
995         DeclareRightBeforeMain(*root, *BuiltInVariable::gl_Position());
996
997         if (FindSymbolNode(root, BuiltInVariable::gl_PointSize()->name()))
998         {
999             DeclareRightBeforeMain(*root, *BuiltInVariable::gl_PointSize());
1000         }
1001
1002         if (FindSymbolNode(root, BuiltInVariable::gl_VertexIndex()->name()))
1003         {
1004             if (!ReplaceVariable(this, root, BuiltInVariable::gl_VertexIndex(), &kgl_VertexIDMetal))
1005             {
1006                 return false;
1007             }
1008             DeclareRightBeforeMain(*root, kgl_VertexIDMetal);
1009         }
1010
1011         // Search for gl_ClipDistance usage; if it's used, its assignments must be replaced.
1012         bool useClipDistance = false;
1013         for (const ShaderVariable &outputVarying : mOutputVaryings)
1014         {
1015             if (outputVarying.name == "gl_ClipDistance")
1016             {
1017                 useClipDistance = true;
1018                 break;
1019             }
1020         }
1021
1022         if (useClipDistance &&
1023             !ReplaceClipDistanceAssignments(this, root, &getSymbolTable(), getShaderType(),
1024                                             driverUniforms->getClipDistancesEnabled()))
1025         {
1026             return false;
1027         }
1028
        // Appends "gl_Position.z = gl_Position.z * depthRange.reserved" at the end of the shader.
1029         if (!transformDepthBeforeCorrection(root, driverUniforms))
1030         {
1031             return false;
1032         }
1033
1034         if ((compileOptions & SH_ADD_PRE_ROTATION) != 0 &&
1035             !AppendPreRotation(this, root, &getSymbolTable(), specConst, driverUniforms))
1036         {
1037             return false;
1038         }
1039     }
1040     else if (getShaderType() == GL_GEOMETRY_SHADER)
1041     {
        // NOTE(review): geometry shaders get no AST rewriting here; only layout qualifiers are
        // written to sink.
1042         WriteGeometryShaderLayoutQualifiers(
1043             sink, getGeometryShaderInputPrimitiveType(), getGeometryShaderInvocations(),
1044             getGeometryShaderOutputPrimitiveType(), getGeometryShaderMaxVertices());
1045     }
1046     else
1047     {
        // NOTE(review): metalShaderTypeFromGLSL asserts that compute is not currently supported.
1048         ASSERT(getShaderType() == GL_COMPUTE_SHADER);
1049         EmitWorkGroupSizeGLSL(*this, sink);
1050     }
1051
    // Second per-stage pass, run after all the builtin handling above.
1052     if (getShaderType() == GL_VERTEX_SHADER)
1053     {
        // negFlipY is fetched before EmulateInstanceID runs.
1054         auto negFlipY = driverUniforms->getNegFlipYRef();
1055
1056         if (mEmulatedInstanceID)
1057         {
1058             if (!EmulateInstanceID(*this, *root, *driverUniforms))
1059             {
1060                 return false;
1061             }
1062         }
1063
1064         if (!AppendVertexShaderPositionYCorrectionToMain(this, root, &getSymbolTable(), negFlipY))
1065         {
1066             return false;
1067         }
1068
1069         if (!insertRasterizationDiscardLogic(*root))
1070         {
1071             return false;
1072         }
1073     }
1074     else if (getShaderType() == GL_FRAGMENT_SHADER)
1075     {
1076         if (!insertSampleMaskWritingLogic(*root, *driverUniforms))
1077         {
1078             return false;
1079         }
1080     }
1081
    // Validate once before the Metal-specific rewrites. Struct-usage validation is disabled further
    // down, ahead of RewritePipelines.
1082     if (!validateAST(root))
1083     {
1084         return false;
1085     }
1086
1087     if (!RewriteKeywords(*this, *root, idGen, GetMslKeywords()))
1088     {
1089         return false;
1090     }
1091
1092     if (!ReduceInterfaceBlocks(*this, *root, idGen))
1093     {
1094         return false;
1095     }
1096
1097     if (!SeparateCompoundStructDeclarations(*this, idGen, *root))
1098     {
1099         return false;
1100     }
1101     // This is the largest size required to pass all the tests in
1102     // (dEQP-GLES3.functional.shaders.large_constant_arrays)
1103     // This value could in principle be smaller.
1104     const size_t hoistThresholdSize = 256;
1105     if (!HoistConstants(*this, *root, idGen, hoistThresholdSize))
1106     {
1107         return false;
1108     }
1109
    // invariants is filled here and later passed to RewritePipelines and EmitMetal.
1110     Invariants invariants;
1111     if (!RewriteGlobalQualifierDecls(*this, *root, invariants))
1112     {
1113         return false;
1114     }
1115
1116     if (!AddExplicitTypeCasts(*this, *root, symbolEnv))
1117     {
1118         return false;
1119     }
1120
1121     // The RewritePipelines phase leaves the tree in an inconsistent state by inserting
1122     // references to structures like "ANGLE_TextureEnv<metal::texture2d<float>>" which are
1123     // defined in text (in ProgramPrelude), outside of the knowledge of the AST.
1124     mValidateASTOptions.validateStructUsage = false;
1125
1126     PipelineStructs pipelineStructs;
1127     if (!RewritePipelines(*this, *root, idGen, *driverUniforms, symbolEnv, invariants,
1128                           pipelineStructs))
1129     {
1130         return false;
1131     }
1132     if (getShaderType() == GL_VERTEX_SHADER)
1133     {
1134         if (!IntroduceVertexAndInstanceIndex(*this, *root))
1135         {
1136             return false;
1137         }
1138     }
1139     if (!SeparateCompoundExpressions(*this, symbolEnv, idGen, *root))
1140     {
1141         return false;
1142     }
1143
1144     if (!RewriteCaseDeclarations(*this, *root))
1145     {
1146         return false;
1147     }
1148
1149     if (!RewriteUnaddressableReferences(*this, *root, symbolEnv))
1150     {
1151         return false;
1152     }
1153
1154     if (!RewriteOutArgs(*this, *root, symbolEnv))
1155     {
1156         return false;
1157     }
1158     if (!FixTypeConstructors(*this, symbolEnv, *root))
1159     {
1160         return false;
1161     }
1162     if (!ToposortStructs(*this, symbolEnv, *root, ppc))
1163     {
1164         return false;
1165     }
1166     if (!EmitMetal(*this, *root, idGen, pipelineStructs, invariants, symbolEnv, ppc,
1167                    &getSymbolTable()))
1168     {
1169         return false;
1170     }
1171
    // Presumably evaluated only in builds where ASSERT is compiled in.
1172     ASSERT(validateAST(root));
1173
1174     return true;
1175 }
1176 
translate(TIntermBlock * root,ShCompileOptions compileOptions,PerformanceDiagnostics * perfDiagnostics)1177 bool TranslatorMetalDirect::translate(TIntermBlock *root,
1178                                       ShCompileOptions compileOptions,
1179                                       PerformanceDiagnostics *perfDiagnostics)
1180 {
1181     if (!root)
1182     {
1183         return false;
1184     }
1185 
1186     TInfoSinkBase &sink = getInfoSink().obj;
1187 
1188     if ((compileOptions & SH_REWRITE_ROW_MAJOR_MATRICES) != 0 && getShaderVersion() >= 300)
1189     {
1190         // "Make sure every uniform buffer variable has a name.  The following transformation relies
1191         // on this." This pass was removed in e196bc85ac2dda0e9f6664cfc2eca0029e33d2d1, but
1192         // currently finding it still necessary for MSL.
1193         // TODO(jcunningham): Look into removing the NameNamelessUniformBuffers and fixing the root
1194         // cause in RewriteRowMajorMatrices
1195         if (!NameNamelessUniformBuffers(this, root, &getSymbolTable()))
1196         {
1197             return false;
1198         }
1199         if (!RewriteRowMajorMatrices(this, root, &getSymbolTable()))
1200         {
1201             return false;
1202         }
1203     }
1204 
1205     bool precisionEmulation = false;
1206     if (!emulatePrecisionIfNeeded(root, sink, &precisionEmulation, SH_SPIRV_VULKAN_OUTPUT))
1207     {
1208         return false;
1209     }
1210 
1211     SpecConst specConst(&getSymbolTable(), compileOptions, getShaderType());
1212     DriverUniformExtended driverUniforms(DriverUniformMode::Structure);
1213     if (!translateImpl(sink, root, compileOptions, perfDiagnostics, &specConst, &driverUniforms))
1214     {
1215         return false;
1216     }
1217 
1218     return true;
1219 }
shouldFlattenPragmaStdglInvariantAll()1220 bool TranslatorMetalDirect::shouldFlattenPragmaStdglInvariantAll()
1221 {
1222     // Not neccesary for MSL transformation.
1223     return false;
1224 }
1225 
1226 }  // namespace sh
1227