• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <cstring>
8 #include <unordered_map>
9 #include <unordered_set>
10 
11 #include "compiler/translator/TranslatorMetalDirect.h"
12 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
13 #include "compiler/translator/TranslatorMetalDirect/DiscoverDependentFunctions.h"
14 #include "compiler/translator/TranslatorMetalDirect/IdGen.h"
15 #include "compiler/translator/TranslatorMetalDirect/MapSymbols.h"
16 #include "compiler/translator/TranslatorMetalDirect/Pipeline.h"
17 #include "compiler/translator/TranslatorMetalDirect/RewritePipelines.h"
18 #include "compiler/translator/TranslatorMetalDirect/SymbolEnv.h"
19 #include "compiler/translator/tree_ops/PruneNoOps.h"
20 #include "compiler/translator/tree_util/DriverUniform.h"
21 #include "compiler/translator/tree_util/FindMain.h"
22 #include "compiler/translator/tree_util/IntermRebuild.h"
23 #include "compiler/translator/tree_util/IntermTraverse.h"
24 using namespace sh;
25 
26 ////////////////////////////////////////////////////////////////////////////////
27 
28 namespace
29 {
30 
IsVariableInvariant(const std::vector<sh::ShaderVariable> & mVars,const ImmutableString & name)31 bool IsVariableInvariant(const std::vector<sh::ShaderVariable> &mVars, const ImmutableString &name)
32 {
33     for (const auto &var : mVars)
34     {
35         if (name == var.name)
36         {
37             return var.isInvariant;
38         }
39     }
40     // TODO(kpidington): this should be UNREACHABLE() but isn't because the translator generates
41     // declarations to unused built-in variables.
42     return false;
43 }
44 
45 using VariableSet  = std::unordered_set<const TVariable *>;
46 using VariableList = std::vector<const TVariable *>;
47 
48 ////////////////////////////////////////////////////////////////////////////////
49 
50 struct PipelineStructInfo
51 {
52     VariableSet pipelineVariables;
53     PipelineScoped<TStructure> pipelineStruct;
54     const TFunction *funcOriginalToModified = nullptr;
55     const TFunction *funcModifiedToOriginal = nullptr;
56 
isEmpty__anon986796d00111::PipelineStructInfo57     bool isEmpty() const
58     {
59         if (pipelineStruct.isTotallyEmpty())
60         {
61             ASSERT(pipelineVariables.empty());
62             return true;
63         }
64         else
65         {
66             ASSERT(pipelineStruct.isTotallyFull());
67             ASSERT(!pipelineVariables.empty());
68             return false;
69         }
70     }
71 };
72 
73 class GeneratePipelineStruct : private TIntermRebuild
74 {
75   private:
76     const Pipeline &mPipeline;
77     SymbolEnv &mSymbolEnv;
78     const std::vector<sh::ShaderVariable> *mVariableInfos;
79     VariableList mPipelineVariableList;
80     IdGen &mIdGen;
81     PipelineStructInfo mInfo;
82 
83   public:
Exec(PipelineStructInfo & out,TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,const std::vector<sh::ShaderVariable> * variableInfos)84     static bool Exec(PipelineStructInfo &out,
85                      TCompiler &compiler,
86                      TIntermBlock &root,
87                      IdGen &idGen,
88                      const Pipeline &pipeline,
89                      SymbolEnv &symbolEnv,
90                      const std::vector<sh::ShaderVariable> *variableInfos)
91     {
92         GeneratePipelineStruct self(compiler, idGen, pipeline, symbolEnv, variableInfos);
93         if (!self.exec(root))
94         {
95             return false;
96         }
97         out = self.mInfo;
98         return true;
99     }
100 
101   private:
GeneratePipelineStruct(TCompiler & compiler,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,const std::vector<sh::ShaderVariable> * variableInfos)102     GeneratePipelineStruct(TCompiler &compiler,
103                            IdGen &idGen,
104                            const Pipeline &pipeline,
105                            SymbolEnv &symbolEnv,
106                            const std::vector<sh::ShaderVariable> *variableInfos)
107         : TIntermRebuild(compiler, true, true),
108           mPipeline(pipeline),
109           mSymbolEnv(symbolEnv),
110           mVariableInfos(variableInfos),
111           mIdGen(idGen)
112     {}
113 
exec(TIntermBlock & root)114     bool exec(TIntermBlock &root)
115     {
116         if (!rebuildRoot(root))
117         {
118             return false;
119         }
120 
121         if (mInfo.pipelineVariables.empty())
122         {
123             return true;
124         }
125 
126         TIntermSequence seq;
127 
128         const TStructure &pipelineStruct = [&]() -> const TStructure & {
129             if (mPipeline.globalInstanceVar)
130             {
131                 return *mPipeline.globalInstanceVar->getType().getStruct();
132             }
133             else
134             {
135                 return createInternalPipelineStruct(root, seq);
136             }
137         }();
138 
139         ModifiedStructMachineries modifiedMachineries;
140         const bool isUBO    = mPipeline.type == Pipeline::Type::UniformBuffer;
141         const bool modified = TryCreateModifiedStruct(
142             mCompiler, mSymbolEnv, mIdGen, mPipeline.externalStructModifyConfig(), pipelineStruct,
143             mPipeline.getStructTypeName(Pipeline::Variant::Modified), modifiedMachineries, isUBO,
144             !isUBO);
145 
146         if (modified)
147         {
148             ASSERT(mPipeline.type != Pipeline::Type::Texture);
149             ASSERT(mPipeline.type == Pipeline::Type::AngleUniforms ||
150                    !mPipeline.globalInstanceVar);  // This shouldn't happen by construction.
151 
152             auto getFunction = [](sh::TIntermFunctionDefinition *funcDecl) {
153                 return funcDecl ? funcDecl->getFunction() : nullptr;
154             };
155 
156             const size_t size = modifiedMachineries.size();
157             ASSERT(size > 0);
158             for (size_t i = 0; i < size; ++i)
159             {
160                 const ModifiedStructMachinery &machinery = modifiedMachineries.at(i);
161                 ASSERT(machinery.modifiedStruct);
162 
163                 seq.push_back(new TIntermDeclaration{
164                     &CreateStructTypeVariable(mSymbolTable, *machinery.modifiedStruct)});
165 
166                 if (mPipeline.isPipelineOut())
167                 {
168                     ASSERT(machinery.funcOriginalToModified);
169                     ASSERT(!machinery.funcModifiedToOriginal);
170                     seq.push_back(machinery.funcOriginalToModified);
171                 }
172                 else
173                 {
174                     ASSERT(machinery.funcModifiedToOriginal);
175                     ASSERT(!machinery.funcOriginalToModified);
176                     seq.push_back(machinery.funcModifiedToOriginal);
177                 }
178 
179                 if (i == size - 1)
180                 {
181                     mInfo.funcOriginalToModified = getFunction(machinery.funcOriginalToModified);
182                     mInfo.funcModifiedToOriginal = getFunction(machinery.funcModifiedToOriginal);
183 
184                     mInfo.pipelineStruct.internal = &pipelineStruct;
185                     mInfo.pipelineStruct.external =
186                         modified ? machinery.modifiedStruct : &pipelineStruct;
187                 }
188             }
189         }
190         else
191         {
192             mInfo.pipelineStruct.internal = &pipelineStruct;
193             mInfo.pipelineStruct.external = &pipelineStruct;
194         }
195 
196         root.insertChildNodes(FindMainIndex(&root), seq);
197 
198         return true;
199     }
200 
201   private:
visitFunctionDefinitionPre(TIntermFunctionDefinition & node)202     PreResult visitFunctionDefinitionPre(TIntermFunctionDefinition &node) override
203     {
204         return {node, VisitBits::Neither};
205     }
visitDeclarationPost(TIntermDeclaration & declNode)206     PostResult visitDeclarationPost(TIntermDeclaration &declNode) override
207     {
208         Declaration decl     = ViewDeclaration(declNode);
209         const TVariable &var = decl.symbol.variable();
210         if (mPipeline.uses(var))
211         {
212             ASSERT(mInfo.pipelineVariables.find(&var) == mInfo.pipelineVariables.end());
213             mInfo.pipelineVariables.insert(&var);
214             mPipelineVariableList.push_back(&var);
215             return nullptr;
216         }
217 
218         return declNode;
219     }
220 
createInternalPipelineStruct(TIntermBlock & root,TIntermSequence & outDeclSeq)221     const TStructure &createInternalPipelineStruct(TIntermBlock &root, TIntermSequence &outDeclSeq)
222     {
223         auto &fields = *new TFieldList();
224 
225         switch (mPipeline.type)
226         {
227             case Pipeline::Type::Texture:
228             {
229                 for (const TVariable *var : mPipelineVariableList)
230                 {
231                     const TType &varType         = var->getType();
232                     const TBasicType samplerType = varType.getBasicType();
233 
234                     const TStructure &textureEnv = mSymbolEnv.getTextureEnv(samplerType);
235                     auto *textureEnvType         = new TType(&textureEnv, false);
236                     if (varType.isArray())
237                     {
238                         textureEnvType->makeArrays(varType.getArraySizes());
239                     }
240 
241                     fields.push_back(
242                         new TField(textureEnvType, var->name(), kNoSourceLoc, var->symbolType()));
243                 }
244             }
245             break;
246 
247             case Pipeline::Type::UniformBuffer:
248             {
249                 for (const TVariable *var : mPipelineVariableList)
250                 {
251                     auto &type  = CloneType(var->getType());
252                     auto *field = new TField(&type, var->name(), kNoSourceLoc, var->symbolType());
253                     mSymbolEnv.markAsPointer(*field, AddressSpace::Constant);
254                     mSymbolEnv.markAsUBO(*field);
255                     mSymbolEnv.markAsPointer(*var, AddressSpace::Constant);
256                     fields.push_back(field);
257                 }
258             }
259             break;
260             default:
261             {
262                 for (const TVariable *var : mPipelineVariableList)
263                 {
264                     auto &type = CloneType(var->getType());
265                     if (mVariableInfos && IsVariableInvariant(*mVariableInfos, var->name()))
266                     {
267                         type.setInvariant(true);
268                     }
269                     auto *field = new TField(&type, var->name(), kNoSourceLoc, var->symbolType());
270                     fields.push_back(field);
271                 }
272             }
273             break;
274         }
275 
276         Name pipelineStructName = mPipeline.getStructTypeName(Pipeline::Variant::Original);
277         auto &s = *new TStructure(&mSymbolTable, pipelineStructName.rawName(), &fields,
278                                   pipelineStructName.symbolType());
279 
280         outDeclSeq.push_back(new TIntermDeclaration{&CreateStructTypeVariable(mSymbolTable, s)});
281 
282         return s;
283     }
284 };
285 
286 ////////////////////////////////////////////////////////////////////////////////
287 
CreatePipelineMainLocalVar(TSymbolTable & symbolTable,const Pipeline & pipeline,PipelineScoped<TStructure> pipelineStruct)288 PipelineScoped<TVariable> CreatePipelineMainLocalVar(TSymbolTable &symbolTable,
289                                                      const Pipeline &pipeline,
290                                                      PipelineScoped<TStructure> pipelineStruct)
291 {
292     ASSERT(pipelineStruct.isTotallyFull());
293 
294     PipelineScoped<TVariable> pipelineMainLocalVar;
295 
296     auto populateExternalMainLocalVar = [&]() {
297         ASSERT(!pipelineMainLocalVar.external);
298         pipelineMainLocalVar.external = &CreateInstanceVariable(
299             symbolTable, *pipelineStruct.external,
300             pipeline.getStructInstanceName(pipelineStruct.isUniform()
301                                                ? Pipeline::Variant::Original
302                                                : Pipeline::Variant::Modified));
303     };
304 
305     auto populateDistinctInternalMainLocalVar = [&]() {
306         ASSERT(!pipelineMainLocalVar.internal);
307         pipelineMainLocalVar.internal =
308             &CreateInstanceVariable(symbolTable, *pipelineStruct.internal,
309                                     pipeline.getStructInstanceName(Pipeline::Variant::Original));
310     };
311 
312     if (pipeline.type == Pipeline::Type::InstanceId)
313     {
314         populateDistinctInternalMainLocalVar();
315     }
316     else if (pipeline.alwaysRequiresLocalVariableDeclarationInMain())
317     {
318         populateExternalMainLocalVar();
319 
320         if (pipelineStruct.isUniform())
321         {
322             pipelineMainLocalVar.internal = pipelineMainLocalVar.external;
323         }
324         else
325         {
326             populateDistinctInternalMainLocalVar();
327         }
328     }
329     else if (!pipelineStruct.isUniform())
330     {
331         populateDistinctInternalMainLocalVar();
332     }
333 
334     return pipelineMainLocalVar;
335 }
336 
337 class PipelineFunctionEnv
338 {
339   private:
340     TCompiler &mCompiler;
341     SymbolEnv &mSymbolEnv;
342     TSymbolTable &mSymbolTable;
343     IdGen &mIdGen;
344     const Pipeline &mPipeline;
345     const std::unordered_set<const TFunction *> &mPipelineFunctions;
346     const PipelineScoped<TStructure> mPipelineStruct;
347     PipelineScoped<TVariable> &mPipelineMainLocalVar;
348 
349     std::unordered_map<const TFunction *, const TFunction *> mFuncMap;
350 
351   public:
PipelineFunctionEnv(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen,const Pipeline & pipeline,const std::unordered_set<const TFunction * > & pipelineFunctions,PipelineScoped<TStructure> pipelineStruct,PipelineScoped<TVariable> & pipelineMainLocalVar)352     PipelineFunctionEnv(TCompiler &compiler,
353                         SymbolEnv &symbolEnv,
354                         IdGen &idGen,
355                         const Pipeline &pipeline,
356                         const std::unordered_set<const TFunction *> &pipelineFunctions,
357                         PipelineScoped<TStructure> pipelineStruct,
358                         PipelineScoped<TVariable> &pipelineMainLocalVar)
359         : mCompiler(compiler),
360           mSymbolEnv(symbolEnv),
361           mSymbolTable(symbolEnv.symbolTable()),
362           mIdGen(idGen),
363           mPipeline(pipeline),
364           mPipelineFunctions(pipelineFunctions),
365           mPipelineStruct(pipelineStruct),
366           mPipelineMainLocalVar(pipelineMainLocalVar)
367     {}
368 
isOriginalPipelineFunction(const TFunction & func) const369     bool isOriginalPipelineFunction(const TFunction &func) const
370     {
371         return mPipelineFunctions.find(&func) != mPipelineFunctions.end();
372     }
373 
isUpdatedPipelineFunction(const TFunction & func) const374     bool isUpdatedPipelineFunction(const TFunction &func) const
375     {
376         auto it = mFuncMap.find(&func);
377         if (it == mFuncMap.end())
378         {
379             return false;
380         }
381         return &func == it->second;
382     }
383 
getUpdatedFunction(const TFunction & func)384     const TFunction &getUpdatedFunction(const TFunction &func)
385     {
386         ASSERT(isOriginalPipelineFunction(func) || isUpdatedPipelineFunction(func));
387 
388         const TFunction *newFunc;
389 
390         auto it = mFuncMap.find(&func);
391         if (it == mFuncMap.end())
392         {
393             const bool isMain = func.isMain();
394 
395             if (isMain && mPipeline.isPipelineOut())
396             {
397                 ASSERT(func.getReturnType().getBasicType() == TBasicType::EbtVoid);
398                 newFunc = &CloneFunctionAndChangeReturnType(mSymbolTable, nullptr, func,
399                                                             *mPipelineStruct.external);
400             }
401             else if (isMain && (mPipeline.type == Pipeline::Type::InvocationVertexGlobals ||
402                                 mPipeline.type == Pipeline::Type::InvocationFragmentGlobals))
403             {
404                 std::vector<const TVariable *> variables;
405                 for (const TField *field : mPipelineStruct.external->fields())
406                 {
407                     variables.push_back(new TVariable(&mSymbolTable, field->name(), field->type(),
408                                                       field->symbolType()));
409                 }
410                 newFunc = &CloneFunctionAndAppendParams(mSymbolTable, nullptr, func, variables);
411             }
412             else if (isMain && mPipeline.type == Pipeline::Type::Texture)
413             {
414                 std::vector<const TVariable *> variables;
415                 TranslatorMetalReflection *reflection =
416                     mtl::getTranslatorMetalReflection(&mCompiler);
417                 for (const TField *field : mPipelineStruct.external->fields())
418                 {
419                     const TStructure *textureEnv = field->type()->getStruct();
420                     ASSERT(textureEnv && textureEnv->fields().size() == 2);
421                     for (const TField *subfield : textureEnv->fields())
422                     {
423                         const Name name = mIdGen.createNewName({field->name(), subfield->name()});
424                         TType &type     = *new TType(*subfield->type());
425                         ASSERT(!type.isArray());
426                         type.makeArrays(field->type()->getArraySizes());
427                         auto *var =
428                             new TVariable(&mSymbolTable, name.rawName(), &type, name.symbolType());
429                         variables.push_back(var);
430                         reflection->addOriginalName(var->uniqueId().get(), field->name().data());
431                     }
432                 }
433                 newFunc = &CloneFunctionAndAppendParams(mSymbolTable, nullptr, func, variables);
434             }
435             else if (isMain && mPipeline.type == Pipeline::Type::InstanceId)
436             {
437                 Name instanceIdName = mPipeline.getStructInstanceName(Pipeline::Variant::Modified);
438                 auto *instanceIdVar =
439                     new TVariable(&mSymbolTable, instanceIdName.rawName(),
440                                   new TType(TBasicType::EbtUInt), instanceIdName.symbolType());
441 
442                 auto *baseInstanceVar =
443                     new TVariable(&mSymbolTable, kBaseInstanceName.rawName(),
444                                   new TType(TBasicType::EbtUInt), kBaseInstanceName.symbolType());
445 
446                 newFunc = &CloneFunctionAndPrependTwoParams(mSymbolTable, nullptr, func,
447                                                             *instanceIdVar, *baseInstanceVar);
448                 mPipelineMainLocalVar.external      = instanceIdVar;
449                 mPipelineMainLocalVar.externalExtra = baseInstanceVar;
450             }
451             else if (isMain && mPipeline.alwaysRequiresLocalVariableDeclarationInMain())
452             {
453                 ASSERT(mPipelineMainLocalVar.isTotallyFull());
454                 newFunc = &func;
455             }
456             else
457             {
458                 const TVariable *var;
459                 AddressSpace addressSpace;
460 
461                 if (isMain && !mPipelineMainLocalVar.isUniform())
462                 {
463                     var = &CreateInstanceVariable(
464                         mSymbolTable, *mPipelineStruct.external,
465                         mPipeline.getStructInstanceName(Pipeline::Variant::Modified));
466                     addressSpace = mPipeline.externalAddressSpace();
467                 }
468                 else
469                 {
470                     if (mPipeline.type == Pipeline::Type::UniformBuffer)
471                     {
472                         TranslatorMetalReflection *reflection =
473                             ((sh::TranslatorMetalDirect *)&mCompiler)
474                                 ->getTranslatorMetalReflection();
475                         // TODO: need more checks to make sure they line up? Could be reordered?
476                         ASSERT(mPipelineStruct.external->fields().size() ==
477                                mPipelineStruct.internal->fields().size());
478                         for (size_t i = 0; i < mPipelineStruct.external->fields().size(); i++)
479                         {
480                             const TField *externalField = mPipelineStruct.external->fields()[i];
481                             const TField *internalField = mPipelineStruct.internal->fields()[i];
482                             const TType &externalType   = *externalField->type();
483                             const TType &internalType   = *internalField->type();
484                             ASSERT(externalType.getBasicType() == internalType.getBasicType());
485                             if (externalType.getBasicType() == TBasicType::EbtStruct)
486                             {
487                                 const TStructure *externalEnv = externalType.getStruct();
488                                 const TStructure *internalEnv = internalType.getStruct();
489                                 const std::string internalName =
490                                     reflection->getOriginalName(internalEnv->uniqueId().get());
491                                 reflection->addOriginalName(externalEnv->uniqueId().get(),
492                                                             internalName);
493                             }
494                         }
495                     }
496                     var = &CreateInstanceVariable(
497                         mSymbolTable, *mPipelineStruct.internal,
498                         mPipeline.getStructInstanceName(Pipeline::Variant::Original));
499                     addressSpace = mPipelineMainLocalVar.isUniform()
500                                        ? mPipeline.externalAddressSpace()
501                                        : AddressSpace::Thread;
502                 }
503 
504                 bool markAsReference = true;
505                 if (isMain)
506                 {
507                     switch (mPipeline.type)
508                     {
509                         case Pipeline::Type::VertexIn:
510                         case Pipeline::Type::FragmentIn:
511                             markAsReference = false;
512                             break;
513 
514                         default:
515                             break;
516                     }
517                 }
518 
519                 if (markAsReference)
520                 {
521                     mSymbolEnv.markAsReference(*var, addressSpace);
522                 }
523 
524                 newFunc = &CloneFunctionAndPrependParam(mSymbolTable, nullptr, func, *var);
525             }
526 
527             mFuncMap[&func]   = newFunc;
528             mFuncMap[newFunc] = newFunc;
529         }
530         else
531         {
532             newFunc = it->second;
533         }
534 
535         return *newFunc;
536     }
537 
createUpdatedFunctionPrototype(TIntermFunctionPrototype & funcProtoNode)538     TIntermFunctionPrototype *createUpdatedFunctionPrototype(
539         TIntermFunctionPrototype &funcProtoNode)
540     {
541         const TFunction &func = *funcProtoNode.getFunction();
542         if (!isOriginalPipelineFunction(func) && !isUpdatedPipelineFunction(func))
543         {
544             return nullptr;
545         }
546         const TFunction &newFunc = getUpdatedFunction(func);
547         return new TIntermFunctionPrototype(&newFunc);
548     }
549 };
550 
551 class UpdatePipelineFunctions : private TIntermRebuild
552 {
553   private:
554     const Pipeline &mPipeline;
555     const PipelineScoped<TStructure> mPipelineStruct;
556     PipelineScoped<TVariable> &mPipelineMainLocalVar;
557     SymbolEnv &mSymbolEnv;
558     PipelineFunctionEnv mEnv;
559     const TFunction *mFuncOriginalToModified;
560     const TFunction *mFuncModifiedToOriginal;
561 
562   public:
ThreadPipeline(TCompiler & compiler,TIntermBlock & root,const Pipeline & pipeline,const std::unordered_set<const TFunction * > & pipelineFunctions,PipelineScoped<TStructure> pipelineStruct,PipelineScoped<TVariable> & pipelineMainLocalVar,IdGen & idGen,SymbolEnv & symbolEnv,const TFunction * funcOriginalToModified,const TFunction * funcModifiedToOriginal)563     static bool ThreadPipeline(TCompiler &compiler,
564                                TIntermBlock &root,
565                                const Pipeline &pipeline,
566                                const std::unordered_set<const TFunction *> &pipelineFunctions,
567                                PipelineScoped<TStructure> pipelineStruct,
568                                PipelineScoped<TVariable> &pipelineMainLocalVar,
569                                IdGen &idGen,
570                                SymbolEnv &symbolEnv,
571                                const TFunction *funcOriginalToModified,
572                                const TFunction *funcModifiedToOriginal)
573     {
574         UpdatePipelineFunctions self(compiler, pipeline, pipelineFunctions, pipelineStruct,
575                                      pipelineMainLocalVar, idGen, symbolEnv, funcOriginalToModified,
576                                      funcModifiedToOriginal);
577         if (!self.rebuildRoot(root))
578         {
579             return false;
580         }
581         return true;
582     }
583 
584   private:
UpdatePipelineFunctions(TCompiler & compiler,const Pipeline & pipeline,const std::unordered_set<const TFunction * > & pipelineFunctions,PipelineScoped<TStructure> pipelineStruct,PipelineScoped<TVariable> & pipelineMainLocalVar,IdGen & idGen,SymbolEnv & symbolEnv,const TFunction * funcOriginalToModified,const TFunction * funcModifiedToOriginal)585     UpdatePipelineFunctions(TCompiler &compiler,
586                             const Pipeline &pipeline,
587                             const std::unordered_set<const TFunction *> &pipelineFunctions,
588                             PipelineScoped<TStructure> pipelineStruct,
589                             PipelineScoped<TVariable> &pipelineMainLocalVar,
590                             IdGen &idGen,
591                             SymbolEnv &symbolEnv,
592                             const TFunction *funcOriginalToModified,
593                             const TFunction *funcModifiedToOriginal)
594         : TIntermRebuild(compiler, false, true),
595           mPipeline(pipeline),
596           mPipelineStruct(pipelineStruct),
597           mPipelineMainLocalVar(pipelineMainLocalVar),
598           mSymbolEnv(symbolEnv),
599           mEnv(compiler,
600                symbolEnv,
601                idGen,
602                pipeline,
603                pipelineFunctions,
604                pipelineStruct,
605                mPipelineMainLocalVar),
606           mFuncOriginalToModified(funcOriginalToModified),
607           mFuncModifiedToOriginal(funcModifiedToOriginal)
608     {
609         ASSERT(mPipelineStruct.isTotallyFull());
610     }
611 
getInternalPipelineVariable(const TFunction & pipelineFunc)612     const TVariable &getInternalPipelineVariable(const TFunction &pipelineFunc)
613     {
614         if (pipelineFunc.isMain() && (mPipeline.alwaysRequiresLocalVariableDeclarationInMain() ||
615                                       !mPipelineMainLocalVar.isUniform()))
616         {
617             ASSERT(mPipelineMainLocalVar.internal);
618             return *mPipelineMainLocalVar.internal;
619         }
620         else
621         {
622             ASSERT(pipelineFunc.getParamCount() > 0);
623             return *pipelineFunc.getParam(0);
624         }
625     }
626 
getExternalPipelineVariable(const TFunction & mainFunc)627     const TVariable &getExternalPipelineVariable(const TFunction &mainFunc)
628     {
629         ASSERT(mainFunc.isMain());
630         if (mPipelineMainLocalVar.external)
631         {
632             return *mPipelineMainLocalVar.external;
633         }
634         else
635         {
636             ASSERT(mainFunc.getParamCount() > 0);
637             return *mainFunc.getParam(0);
638         }
639     }
640 
getExternalExtraPipelineVariable(const TFunction & mainFunc)641     const TVariable &getExternalExtraPipelineVariable(const TFunction &mainFunc)
642     {
643         ASSERT(mainFunc.isMain());
644         if (mPipelineMainLocalVar.externalExtra)
645         {
646             return *mPipelineMainLocalVar.externalExtra;
647         }
648         else
649         {
650             ASSERT(mainFunc.getParamCount() > 1);
651             return *mainFunc.getParam(1);
652         }
653     }
654 
visitAggregatePost(TIntermAggregate & callNode)655     PostResult visitAggregatePost(TIntermAggregate &callNode) override
656     {
657         if (callNode.isConstructor())
658         {
659             return callNode;
660         }
661         else
662         {
663             const TFunction &oldCalledFunc = *callNode.getFunction();
664             if (!mEnv.isOriginalPipelineFunction(oldCalledFunc))
665             {
666                 return callNode;
667             }
668             const TFunction &newCalledFunc = mEnv.getUpdatedFunction(oldCalledFunc);
669 
670             const TFunction *oldOwnerFunc = getParentFunction();
671             ASSERT(oldOwnerFunc);
672             const TFunction &newOwnerFunc = mEnv.getUpdatedFunction(*oldOwnerFunc);
673 
674             return *TIntermAggregate::CreateFunctionCall(
675                 newCalledFunc, &CloneSequenceAndPrepend(
676                                    *callNode.getSequence(),
677                                    *new TIntermSymbol(&getInternalPipelineVariable(newOwnerFunc))));
678         }
679     }
680 
visitFunctionPrototypePost(TIntermFunctionPrototype & funcProtoNode)681     PostResult visitFunctionPrototypePost(TIntermFunctionPrototype &funcProtoNode) override
682     {
683         TIntermFunctionPrototype *newFuncProtoNode =
684             mEnv.createUpdatedFunctionPrototype(funcProtoNode);
685         if (newFuncProtoNode == nullptr)
686         {
687             return funcProtoNode;
688         }
689         return *newFuncProtoNode;
690     }
691 
visitFunctionDefinitionPost(TIntermFunctionDefinition & funcDefNode)692     PostResult visitFunctionDefinitionPost(TIntermFunctionDefinition &funcDefNode) override
693     {
694         if (funcDefNode.getFunction()->isMain())
695         {
696             return visitMain(funcDefNode);
697         }
698         else
699         {
700             return visitNonMain(funcDefNode);
701         }
702     }
703 
visitNonMain(TIntermFunctionDefinition & funcDefNode)704     TIntermNode &visitNonMain(TIntermFunctionDefinition &funcDefNode)
705     {
706         TIntermFunctionPrototype &funcProtoNode = *funcDefNode.getFunctionPrototype();
707         ASSERT(!funcProtoNode.getFunction()->isMain());
708 
709         TIntermFunctionPrototype *newFuncProtoNode =
710             mEnv.createUpdatedFunctionPrototype(funcProtoNode);
711         if (newFuncProtoNode == nullptr)
712         {
713             return funcDefNode;
714         }
715 
716         const TFunction &func = *newFuncProtoNode->getFunction();
717         ASSERT(!func.isMain());
718 
719         TIntermBlock *body = funcDefNode.getBody();
720 
721         return *new TIntermFunctionDefinition(newFuncProtoNode, body);
722     }
723 
    // Rewrites main() for the current pipeline.
    //
    // The prototype is replaced with the one from mEnv.createUpdatedFunctionPrototype(). The
    // original body is then wrapped in a new block. Statements before the body declare and
    // populate the pipeline struct instance. Statements after it convert the struct back and,
    // for output pipelines, return it. If no updated prototype is produced (nullptr), the
    // definition is returned unchanged.
    //
    // NOTE: the order of appendStatement() calls below is the order of the emitted code.
    TIntermNode &visitMain(TIntermFunctionDefinition &funcDefNode)
    {
        TIntermFunctionPrototype &funcProtoNode = *funcDefNode.getFunctionPrototype();
        ASSERT(funcProtoNode.getFunction()->isMain());

        TIntermFunctionPrototype *newFuncProtoNode =
            mEnv.createUpdatedFunctionPrototype(funcProtoNode);
        if (newFuncProtoNode == nullptr)
        {
            return funcDefNode;
        }

        // The updated main() signature; its parameters are used below as the external
        // pipeline variable(s).
        const TFunction &func = *newFuncProtoNode->getFunction();
        ASSERT(func.isMain());

        // Appends `mFuncModifiedToOriginal(external, internal)` to `body`. This only happens
        // for pipelines that are NOT output pipelines (!isPipelineOut()); otherwise no-op.
        auto callModifiedToOriginal = [&](TIntermBlock &body) {
            ASSERT(mPipelineMainLocalVar.internal);
            if (!mPipeline.isPipelineOut())
            {
                ASSERT(mFuncModifiedToOriginal);
                auto *m = new TIntermSymbol(&getExternalPipelineVariable(func));
                auto *o = new TIntermSymbol(mPipelineMainLocalVar.internal);
                body.appendStatement(TIntermAggregate::CreateFunctionCall(
                    *mFuncModifiedToOriginal, new TIntermSequence{m, o}));
            }
        };

        // Appends `mFuncOriginalToModified(internal, external)` to `body`. This only happens
        // for output pipelines (isPipelineOut()); otherwise no-op.
        auto callOriginalToModified = [&](TIntermBlock &body) {
            ASSERT(mPipelineMainLocalVar.internal);
            if (mPipeline.isPipelineOut())
            {
                ASSERT(mFuncOriginalToModified);
                auto *o = new TIntermSymbol(mPipelineMainLocalVar.internal);
                auto *m = new TIntermSymbol(&getExternalPipelineVariable(func));
                body.appendStatement(TIntermAggregate::CreateFunctionCall(
                    *mFuncOriginalToModified, new TIntermSequence{o, m}));
            }
        };

        TIntermBlock *body = funcDefNode.getBody();

        if (mPipeline.alwaysRequiresLocalVariableDeclarationInMain())
        {
            // Both the internal and external local variables are expected to exist here.
            ASSERT(mPipelineMainLocalVar.isTotallyFull());

            // New body starts with the declaration of the internal struct instance.
            auto *newBody = new TIntermBlock();
            newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.internal});

            if (mPipeline.type == Pipeline::Type::InvocationVertexGlobals ||
                mPipeline.type == Pipeline::Type::InvocationFragmentGlobals)
            {
                // Populate struct instance with references to global pipeline variables.
                // For each field, emit `internal.<name> = <name>;`. The right-hand side is a
                // fresh TVariable with the field's name/type/symbolType (presumably resolving
                // to the global of that name when emitted).
                for (const TField *field : mPipelineStruct.external->fields())
                {
                    auto *var        = new TVariable(&mSymbolTable, field->name(), field->type(),
                                                     field->symbolType());
                    auto *symbol     = new TIntermSymbol(var);
                    auto &accessNode = AccessField(*mPipelineMainLocalVar.internal, var->name());
                    auto *assignNode = new TIntermBinary(TOperator::EOpAssign, &accessNode, symbol);
                    newBody->appendStatement(assignNode);
                }
            }
            else if (mPipeline.type == Pipeline::Type::Texture)
            {
                const TFieldList &fields = mPipelineStruct.external->fields();

                // The last 2 * fields.size() parameters of main() are consumed as
                // (texture, sampler) pairs, one pair per struct field, in field order.
                ASSERT(func.getParamCount() >= 2 * fields.size());
                size_t paramIndex = func.getParamCount() - 2 * fields.size();

                for (const TField *field : fields)
                {
                    const TVariable &textureParam = *func.getParam(paramIndex++);
                    const TVariable &samplerParam = *func.getParam(paramIndex++);

                    // Emits:
                    //   env[index].texture = addressof(textureParam[index]);
                    //   env[index].sampler = addressof(samplerParam[index]);
                    // A null `index` presumably means no subscript (non-array field) —
                    // TODO confirm against AccessIndex.
                    auto go = [&](TIntermTyped &env, const int *index) {
                        TIntermTyped &textureField = AccessField(
                            AccessIndex(*env.deepCopy(), index), ImmutableString("texture"));
                        TIntermTyped &samplerField = AccessField(
                            AccessIndex(*env.deepCopy(), index), ImmutableString("sampler"));

                        // Note: the lambda parameter `field` shadows the loop's `TField *field`.
                        auto mkAssign = [&](TIntermTyped &field, const TVariable &param) {
                            return new TIntermBinary(TOperator::EOpAssign, &field,
                                                     &mSymbolEnv.callFunctionOverload(
                                                         Name("addressof"), field.getType(),
                                                         *new TIntermSequence{&AccessIndex(
                                                             *new TIntermSymbol(&param), index)}));
                        };

                        newBody->appendStatement(mkAssign(textureField, textureParam));
                        newBody->appendStatement(mkAssign(samplerField, samplerParam));
                    };

                    TIntermTyped &env = AccessField(*mPipelineMainLocalVar.internal, field->name());
                    const TType &envType = env.getType();

                    if (envType.isArray())
                    {
                        // Only single-dimension arrays are handled: one pair of assignments
                        // per element.
                        ASSERT(!envType.isArrayOfArrays());
                        const auto n = static_cast<int>(envType.getArraySizeProduct());
                        for (int i = 0; i < n; ++i)
                        {
                            go(env, &i);
                        }
                    }
                    else
                    {
                        go(env, nullptr);
                    }
                }
            }
            else if (mPipeline.type == Pipeline::Type::InstanceId)
            {
                // Emits: internal.<field 0> = int(instanceId - baseInstance);
                auto varInstanceId   = new TIntermSymbol(&getExternalPipelineVariable(func));
                auto varBaseInstance = new TIntermSymbol(&getExternalExtraPipelineVariable(func));

                newBody->appendStatement(new TIntermBinary(
                    TOperator::EOpAssign,
                    &AccessFieldByIndex(*new TIntermSymbol(&getInternalPipelineVariable(func)), 0),
                    &AsType(
                        mSymbolEnv, *new TType(TBasicType::EbtInt),
                        *new TIntermBinary(TOperator::EOpSub, varInstanceId, varBaseInstance))));
            }
            else if (!mPipelineMainLocalVar.isUniform())
            {
                // Other non-uniform pipelines: also declare the external local, then (for
                // non-output pipelines) call mFuncModifiedToOriginal(external, internal).
                newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.external});
                callModifiedToOriginal(*newBody);
            }

            // The original main() body, nested as a block.
            newBody->appendStatement(body);

            // For output pipelines: call mFuncOriginalToModified(internal, external).
            if (!mPipelineMainLocalVar.isUniform())
            {
                callOriginalToModified(*newBody);
            }

            // Output pipelines return the external struct instance from main().
            if (mPipeline.isPipelineOut())
            {
                newBody->appendStatement(new TIntermBranch(
                    TOperator::EOpReturn, new TIntermSymbol(mPipelineMainLocalVar.external)));
            }

            body = newBody;
        }
        else if (!mPipelineMainLocalVar.isUniform())
        {
            // Only the internal local exists in this case; the external side is reached via
            // getExternalPipelineVariable(func) inside the helper lambdas.
            ASSERT(!mPipelineMainLocalVar.external);
            ASSERT(mPipelineMainLocalVar.internal);

            // Emits: declare internal; [modified->original]; <body>; [original->modified].
            auto *newBody = new TIntermBlock();
            newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.internal});
            callModifiedToOriginal(*newBody);
            newBody->appendStatement(body);
            callOriginalToModified(*newBody);
            body = newBody;
        }

        return *new TIntermFunctionDefinition(newFuncProtoNode, body);
    }
882 };
883 
884 ////////////////////////////////////////////////////////////////////////////////
885 
UpdatePipelineSymbols(Pipeline::Type pipelineType,TCompiler & compiler,TIntermBlock & root,SymbolEnv & symbolEnv,const VariableSet & pipelineVariables,PipelineScoped<TVariable> pipelineMainLocalVar)886 bool UpdatePipelineSymbols(Pipeline::Type pipelineType,
887                            TCompiler &compiler,
888                            TIntermBlock &root,
889                            SymbolEnv &symbolEnv,
890                            const VariableSet &pipelineVariables,
891                            PipelineScoped<TVariable> pipelineMainLocalVar)
892 {
893     auto map = [&](const TFunction *owner, TIntermSymbol &symbol) -> TIntermNode & {
894         if (!owner)
895             return symbol;
896         const TVariable &var = symbol.variable();
897         if (pipelineVariables.find(&var) == pipelineVariables.end())
898         {
899             return symbol;
900         }
901         const TVariable *structInstanceVar;
902         if (owner->isMain())
903         {
904             ASSERT(pipelineMainLocalVar.internal);
905             structInstanceVar = pipelineMainLocalVar.internal;
906         }
907         else
908         {
909             ASSERT(owner->getParamCount() > 0);
910             structInstanceVar = owner->getParam(0);
911         }
912         ASSERT(structInstanceVar);
913         return AccessField(*structInstanceVar, var.name());
914     };
915     return MapSymbols(compiler, root, map);
916 }
917 
918 ////////////////////////////////////////////////////////////////////////////////
919 
RewritePipeline(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,const std::vector<sh::ShaderVariable> * variableInfo,PipelineScoped<TStructure> & outStruct)920 bool RewritePipeline(TCompiler &compiler,
921                      TIntermBlock &root,
922                      IdGen &idGen,
923                      const Pipeline &pipeline,
924                      SymbolEnv &symbolEnv,
925                      const std::vector<sh::ShaderVariable> *variableInfo,
926                      PipelineScoped<TStructure> &outStruct)
927 {
928     ASSERT(outStruct.isTotallyEmpty());
929 
930     TSymbolTable &symbolTable = compiler.getSymbolTable();
931 
932     PipelineStructInfo psi;
933     if (!GeneratePipelineStruct::Exec(psi, compiler, root, idGen, pipeline, symbolEnv,
934                                       variableInfo))
935     {
936         return false;
937     }
938 
939     if (psi.isEmpty())
940     {
941         return true;
942     }
943 
944     const auto pipelineFunctions = DiscoverDependentFunctions(root, [&](const TVariable &var) {
945         return psi.pipelineVariables.find(&var) != psi.pipelineVariables.end();
946     });
947 
948     auto pipelineMainLocalVar =
949         CreatePipelineMainLocalVar(symbolTable, pipeline, psi.pipelineStruct);
950 
951     if (!UpdatePipelineFunctions::ThreadPipeline(
952             compiler, root, pipeline, pipelineFunctions, psi.pipelineStruct, pipelineMainLocalVar,
953             idGen, symbolEnv, psi.funcOriginalToModified, psi.funcModifiedToOriginal))
954     {
955         return false;
956     }
957 
958     if (!pipeline.globalInstanceVar)
959     {
960         if (!UpdatePipelineSymbols(pipeline.type, compiler, root, symbolEnv, psi.pipelineVariables,
961                                    pipelineMainLocalVar))
962         {
963             return false;
964         }
965     }
966 
967     if (!PruneNoOps(&compiler, &root, &compiler.getSymbolTable()))
968     {
969         return false;
970     }
971 
972     outStruct = psi.pipelineStruct;
973     return true;
974 }
975 
976 }  // anonymous namespace
977 
RewritePipelines(TCompiler & compiler,TIntermBlock & root,const std::vector<sh::ShaderVariable> & inputVaryings,const std::vector<sh::ShaderVariable> & outputVaryings,IdGen & idGen,DriverUniform & angleUniformsGlobalInstanceVar,SymbolEnv & symbolEnv,PipelineStructs & outStructs)978 bool sh::RewritePipelines(TCompiler &compiler,
979                           TIntermBlock &root,
980                           const std::vector<sh::ShaderVariable> &inputVaryings,
981                           const std::vector<sh::ShaderVariable> &outputVaryings,
982                           IdGen &idGen,
983                           DriverUniform &angleUniformsGlobalInstanceVar,
984                           SymbolEnv &symbolEnv,
985                           PipelineStructs &outStructs)
986 {
987     struct Info
988     {
989         Pipeline::Type pipelineType;
990         PipelineScoped<TStructure> &outStruct;
991         const TVariable *globalInstanceVar;
992         const std::vector<sh::ShaderVariable> *variableInfo;
993     };
994 
995     Info infos[] = {
996         {Pipeline::Type::InstanceId, outStructs.instanceId, nullptr, nullptr},
997         {Pipeline::Type::Texture, outStructs.texture, nullptr, nullptr},
998         {Pipeline::Type::NonConstantGlobals, outStructs.nonConstantGlobals, nullptr, nullptr},
999         {Pipeline::Type::AngleUniforms, outStructs.angleUniforms,
1000          angleUniformsGlobalInstanceVar.getDriverUniformsVariable(), nullptr},
1001         {Pipeline::Type::UserUniforms, outStructs.userUniforms, nullptr, nullptr},
1002         {Pipeline::Type::VertexIn, outStructs.vertexIn, nullptr, &inputVaryings},
1003         {Pipeline::Type::VertexOut, outStructs.vertexOut, nullptr, &outputVaryings},
1004         {Pipeline::Type::FragmentIn, outStructs.fragmentIn, nullptr, &inputVaryings},
1005         {Pipeline::Type::FragmentOut, outStructs.fragmentOut, nullptr, &outputVaryings},
1006         {Pipeline::Type::InvocationVertexGlobals, outStructs.invocationVertexGlobals, nullptr,
1007          nullptr},
1008         {Pipeline::Type::InvocationFragmentGlobals, outStructs.invocationFragmentGlobals, nullptr,
1009          &inputVaryings},
1010         {Pipeline::Type::UniformBuffer, outStructs.uniformBuffers, nullptr, nullptr},
1011     };
1012 
1013     for (Info &info : infos)
1014     {
1015         Pipeline pipeline{info.pipelineType, info.globalInstanceVar};
1016         if (!RewritePipeline(compiler, root, idGen, pipeline, symbolEnv, info.variableInfo,
1017                              info.outStruct))
1018         {
1019             return false;
1020         }
1021     }
1022 
1023     return true;
1024 }
1025