• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <cstring>
8 #include <unordered_map>
9 #include <unordered_set>
10 
11 #include "compiler/translator/TranslatorMetalDirect.h"
12 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
13 #include "compiler/translator/TranslatorMetalDirect/DiscoverDependentFunctions.h"
14 #include "compiler/translator/TranslatorMetalDirect/IdGen.h"
15 #include "compiler/translator/TranslatorMetalDirect/MapSymbols.h"
16 #include "compiler/translator/TranslatorMetalDirect/Pipeline.h"
17 #include "compiler/translator/TranslatorMetalDirect/RewritePipelines.h"
18 #include "compiler/translator/TranslatorMetalDirect/SymbolEnv.h"
19 #include "compiler/translator/tree_ops/PruneNoOps.h"
20 #include "compiler/translator/tree_util/DriverUniform.h"
21 #include "compiler/translator/tree_util/FindMain.h"
22 #include "compiler/translator/tree_util/IntermRebuild.h"
23 #include "compiler/translator/tree_util/IntermTraverse.h"
24 using namespace sh;
25 
26 ////////////////////////////////////////////////////////////////////////////////
27 
28 namespace
29 {
30 
IsVariableInvariant(const std::vector<sh::ShaderVariable> & mVars,const ImmutableString & name)31 bool IsVariableInvariant(const std::vector<sh::ShaderVariable> &mVars, const ImmutableString &name)
32 {
33     for (const auto &var : mVars)
34     {
35         if (name == var.name)
36         {
37             return var.isInvariant;
38         }
39     }
40     // TODO(kpidington): this should be UNREACHABLE() but isn't because the translator generates
41     // declarations to unused built-in variables.
42     return false;
43 }
44 
45 using VariableSet  = std::unordered_set<const TVariable *>;
46 using VariableList = std::vector<const TVariable *>;
47 
48 ////////////////////////////////////////////////////////////////////////////////
49 
50 struct PipelineStructInfo
51 {
52     VariableSet pipelineVariables;
53     PipelineScoped<TStructure> pipelineStruct;
54     const TFunction *funcOriginalToModified = nullptr;
55     const TFunction *funcModifiedToOriginal = nullptr;
56 
isEmpty__anon940233fe0111::PipelineStructInfo57     bool isEmpty() const
58     {
59         if (pipelineStruct.isTotallyEmpty())
60         {
61             ASSERT(pipelineVariables.empty());
62             return true;
63         }
64         else
65         {
66             ASSERT(pipelineStruct.isTotallyFull());
67             ASSERT(!pipelineVariables.empty());
68             return false;
69         }
70     }
71 };
72 
73 class GeneratePipelineStruct : private TIntermRebuild
74 {
75   private:
76     const Pipeline &mPipeline;
77     SymbolEnv &mSymbolEnv;
78     const std::vector<sh::ShaderVariable> *mVariableInfos;
79     VariableList mPipelineVariableList;
80     IdGen &mIdGen;
81     PipelineStructInfo mInfo;
82 
83   public:
Exec(PipelineStructInfo & out,TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,const std::vector<sh::ShaderVariable> * variableInfos)84     static bool Exec(PipelineStructInfo &out,
85                      TCompiler &compiler,
86                      TIntermBlock &root,
87                      IdGen &idGen,
88                      const Pipeline &pipeline,
89                      SymbolEnv &symbolEnv,
90                      const std::vector<sh::ShaderVariable> *variableInfos)
91     {
92         GeneratePipelineStruct self(compiler, idGen, pipeline, symbolEnv, variableInfos);
93         if (!self.exec(root))
94         {
95             return false;
96         }
97         out = self.mInfo;
98         return true;
99     }
100 
101   private:
GeneratePipelineStruct(TCompiler & compiler,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,const std::vector<sh::ShaderVariable> * variableInfos)102     GeneratePipelineStruct(TCompiler &compiler,
103                            IdGen &idGen,
104                            const Pipeline &pipeline,
105                            SymbolEnv &symbolEnv,
106                            const std::vector<sh::ShaderVariable> *variableInfos)
107         : TIntermRebuild(compiler, true, true),
108           mPipeline(pipeline),
109           mSymbolEnv(symbolEnv),
110           mVariableInfos(variableInfos),
111           mIdGen(idGen)
112     {}
113 
exec(TIntermBlock & root)114     bool exec(TIntermBlock &root)
115     {
116         if (!rebuildRoot(root))
117         {
118             return false;
119         }
120 
121         if (mInfo.pipelineVariables.empty())
122         {
123             return true;
124         }
125 
126         TIntermSequence seq;
127 
128         const TStructure &pipelineStruct = [&]() -> const TStructure & {
129             if (mPipeline.globalInstanceVar)
130             {
131                 return *mPipeline.globalInstanceVar->getType().getStruct();
132             }
133             else
134             {
135                 return createInternalPipelineStruct(root, seq);
136             }
137         }();
138 
139         ModifiedStructMachineries modifiedMachineries;
140         const bool isUBO    = mPipeline.type == Pipeline::Type::UniformBuffer;
141         const bool modified = TryCreateModifiedStruct(
142             mCompiler, mSymbolEnv, mIdGen, mPipeline.externalStructModifyConfig(), pipelineStruct,
143             mPipeline.getStructTypeName(Pipeline::Variant::Modified), modifiedMachineries, isUBO,
144             !isUBO);
145 
146         if (modified)
147         {
148             ASSERT(mPipeline.type != Pipeline::Type::Texture);
149             ASSERT(mPipeline.type == Pipeline::Type::AngleUniforms ||
150                    !mPipeline.globalInstanceVar);  // This shouldn't happen by construction.
151 
152             auto getFunction = [](sh::TIntermFunctionDefinition *funcDecl) {
153                 return funcDecl ? funcDecl->getFunction() : nullptr;
154             };
155 
156             const size_t size = modifiedMachineries.size();
157             ASSERT(size > 0);
158             for (size_t i = 0; i < size; ++i)
159             {
160                 const ModifiedStructMachinery &machinery = modifiedMachineries.at(i);
161                 ASSERT(machinery.modifiedStruct);
162 
163                 seq.push_back(new TIntermDeclaration{
164                     &CreateStructTypeVariable(mSymbolTable, *machinery.modifiedStruct)});
165 
166                 if (mPipeline.isPipelineOut())
167                 {
168                     ASSERT(machinery.funcOriginalToModified);
169                     ASSERT(!machinery.funcModifiedToOriginal);
170                     seq.push_back(machinery.funcOriginalToModified);
171                 }
172                 else
173                 {
174                     ASSERT(machinery.funcModifiedToOriginal);
175                     ASSERT(!machinery.funcOriginalToModified);
176                     seq.push_back(machinery.funcModifiedToOriginal);
177                 }
178 
179                 if (i == size - 1)
180                 {
181                     mInfo.funcOriginalToModified = getFunction(machinery.funcOriginalToModified);
182                     mInfo.funcModifiedToOriginal = getFunction(machinery.funcModifiedToOriginal);
183 
184                     mInfo.pipelineStruct.internal = &pipelineStruct;
185                     mInfo.pipelineStruct.external =
186                         modified ? machinery.modifiedStruct : &pipelineStruct;
187                 }
188             }
189         }
190         else
191         {
192             mInfo.pipelineStruct.internal = &pipelineStruct;
193             mInfo.pipelineStruct.external = &pipelineStruct;
194         }
195 
196         root.insertChildNodes(FindMainIndex(&root), seq);
197 
198         return true;
199     }
200 
201   private:
visitFunctionDefinitionPre(TIntermFunctionDefinition & node)202     PreResult visitFunctionDefinitionPre(TIntermFunctionDefinition &node) override
203     {
204         return {node, VisitBits::Neither};
205     }
visitDeclarationPost(TIntermDeclaration & declNode)206     PostResult visitDeclarationPost(TIntermDeclaration &declNode) override
207     {
208         Declaration decl     = ViewDeclaration(declNode);
209         const TVariable &var = decl.symbol.variable();
210         if (mPipeline.uses(var))
211         {
212             ASSERT(mInfo.pipelineVariables.find(&var) == mInfo.pipelineVariables.end());
213             mInfo.pipelineVariables.insert(&var);
214             mPipelineVariableList.push_back(&var);
215             return nullptr;
216         }
217 
218         return declNode;
219     }
220 
createInternalPipelineStruct(TIntermBlock & root,TIntermSequence & outDeclSeq)221     const TStructure &createInternalPipelineStruct(TIntermBlock &root, TIntermSequence &outDeclSeq)
222     {
223         auto &fields = *new TFieldList();
224 
225         switch (mPipeline.type)
226         {
227             case Pipeline::Type::Texture:
228             {
229                 for (const TVariable *var : mPipelineVariableList)
230                 {
231                     const TType &varType         = var->getType();
232                     const TBasicType samplerType = varType.getBasicType();
233 
234                     const TStructure &textureEnv = mSymbolEnv.getTextureEnv(samplerType);
235                     auto *textureEnvType         = new TType(&textureEnv, false);
236                     if (varType.isArray())
237                     {
238                         textureEnvType->makeArrays(varType.getArraySizes());
239                     }
240 
241                     fields.push_back(
242                         new TField(textureEnvType, var->name(), kNoSourceLoc, var->symbolType()));
243                 }
244             }
245             break;
246 
247             case Pipeline::Type::UniformBuffer:
248             {
249                 for (const TVariable *var : mPipelineVariableList)
250                 {
251                     auto &type  = CloneType(var->getType());
252                     auto *field = new TField(&type, var->name(), kNoSourceLoc, var->symbolType());
253                     mSymbolEnv.markAsPointer(*field, AddressSpace::Constant);
254                     mSymbolEnv.markAsUBO(*field);
255                     mSymbolEnv.markAsPointer(*var, AddressSpace::Constant);
256                     fields.push_back(field);
257                 }
258             }
259             break;
260             default:
261             {
262                 for (const TVariable *var : mPipelineVariableList)
263                 {
264                     auto &type = CloneType(var->getType());
265                     if (mVariableInfos && IsVariableInvariant(*mVariableInfos, var->name()))
266                     {
267                         type.setInvariant(true);
268                     }
269                     auto *field = new TField(&type, var->name(), kNoSourceLoc, var->symbolType());
270                     fields.push_back(field);
271                 }
272             }
273             break;
274         }
275 
276         Name pipelineStructName = mPipeline.getStructTypeName(Pipeline::Variant::Original);
277         auto &s = *new TStructure(&mSymbolTable, pipelineStructName.rawName(), &fields,
278                                   pipelineStructName.symbolType());
279 
280         outDeclSeq.push_back(new TIntermDeclaration{&CreateStructTypeVariable(mSymbolTable, s)});
281 
282         return s;
283     }
284 };
285 
286 ////////////////////////////////////////////////////////////////////////////////
287 
CreatePipelineMainLocalVar(TSymbolTable & symbolTable,const Pipeline & pipeline,PipelineScoped<TStructure> pipelineStruct)288 PipelineScoped<TVariable> CreatePipelineMainLocalVar(TSymbolTable &symbolTable,
289                                                      const Pipeline &pipeline,
290                                                      PipelineScoped<TStructure> pipelineStruct)
291 {
292     ASSERT(pipelineStruct.isTotallyFull());
293 
294     PipelineScoped<TVariable> pipelineMainLocalVar;
295 
296     auto populateExternalMainLocalVar = [&]() {
297         ASSERT(!pipelineMainLocalVar.external);
298         pipelineMainLocalVar.external = &CreateInstanceVariable(
299             symbolTable, *pipelineStruct.external,
300             pipeline.getStructInstanceName(pipelineStruct.isUniform()
301                                                ? Pipeline::Variant::Original
302                                                : Pipeline::Variant::Modified));
303     };
304 
305     auto populateDistinctInternalMainLocalVar = [&]() {
306         ASSERT(!pipelineMainLocalVar.internal);
307         pipelineMainLocalVar.internal =
308             &CreateInstanceVariable(symbolTable, *pipelineStruct.internal,
309                                     pipeline.getStructInstanceName(Pipeline::Variant::Original));
310     };
311 
312     if (pipeline.type == Pipeline::Type::InstanceId)
313     {
314         populateDistinctInternalMainLocalVar();
315     }
316     else if (pipeline.alwaysRequiresLocalVariableDeclarationInMain())
317     {
318         populateExternalMainLocalVar();
319 
320         if (pipelineStruct.isUniform())
321         {
322             pipelineMainLocalVar.internal = pipelineMainLocalVar.external;
323         }
324         else
325         {
326             populateDistinctInternalMainLocalVar();
327         }
328     }
329     else if (!pipelineStruct.isUniform())
330     {
331         populateDistinctInternalMainLocalVar();
332     }
333 
334     return pipelineMainLocalVar;
335 }
336 
337 class PipelineFunctionEnv
338 {
339   private:
340     TCompiler &mCompiler;
341     SymbolEnv &mSymbolEnv;
342     TSymbolTable &mSymbolTable;
343     IdGen &mIdGen;
344     const Pipeline &mPipeline;
345     const std::unordered_set<const TFunction *> &mPipelineFunctions;
346     const PipelineScoped<TStructure> mPipelineStruct;
347     PipelineScoped<TVariable> &mPipelineMainLocalVar;
348 
349     std::unordered_map<const TFunction *, const TFunction *> mFuncMap;
350 
351   public:
PipelineFunctionEnv(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen,const Pipeline & pipeline,const std::unordered_set<const TFunction * > & pipelineFunctions,PipelineScoped<TStructure> pipelineStruct,PipelineScoped<TVariable> & pipelineMainLocalVar)352     PipelineFunctionEnv(TCompiler &compiler,
353                         SymbolEnv &symbolEnv,
354                         IdGen &idGen,
355                         const Pipeline &pipeline,
356                         const std::unordered_set<const TFunction *> &pipelineFunctions,
357                         PipelineScoped<TStructure> pipelineStruct,
358                         PipelineScoped<TVariable> &pipelineMainLocalVar)
359         : mCompiler(compiler),
360           mSymbolEnv(symbolEnv),
361           mSymbolTable(symbolEnv.symbolTable()),
362           mIdGen(idGen),
363           mPipeline(pipeline),
364           mPipelineFunctions(pipelineFunctions),
365           mPipelineStruct(pipelineStruct),
366           mPipelineMainLocalVar(pipelineMainLocalVar)
367     {}
368 
isOriginalPipelineFunction(const TFunction & func) const369     bool isOriginalPipelineFunction(const TFunction &func) const
370     {
371         return mPipelineFunctions.find(&func) != mPipelineFunctions.end();
372     }
373 
isUpdatedPipelineFunction(const TFunction & func) const374     bool isUpdatedPipelineFunction(const TFunction &func) const
375     {
376         auto it = mFuncMap.find(&func);
377         if (it == mFuncMap.end())
378         {
379             return false;
380         }
381         return &func == it->second;
382     }
383 
getUpdatedFunction(const TFunction & func)384     const TFunction &getUpdatedFunction(const TFunction &func)
385     {
386         ASSERT(isOriginalPipelineFunction(func) || isUpdatedPipelineFunction(func));
387 
388         const TFunction *newFunc;
389 
390         auto it = mFuncMap.find(&func);
391         if (it == mFuncMap.end())
392         {
393             const bool isMain = func.isMain();
394 
395             if (isMain && mPipeline.isPipelineOut())
396             {
397                 ASSERT(func.getReturnType().getBasicType() == TBasicType::EbtVoid);
398                 newFunc = &CloneFunctionAndChangeReturnType(mSymbolTable, nullptr, func,
399                                                             *mPipelineStruct.external);
400             }
401             else if (isMain && (mPipeline.type == Pipeline::Type::InvocationVertexGlobals ||
402                                 mPipeline.type == Pipeline::Type::InvocationFragmentGlobals))
403             {
404                 std::vector<const TVariable *> variables;
405                 for (const TField *field : mPipelineStruct.external->fields())
406                 {
407                     variables.push_back(new TVariable(&mSymbolTable, field->name(), field->type(),
408                                                       field->symbolType()));
409                 }
410                 newFunc = &CloneFunctionAndAppendParams(mSymbolTable, nullptr, func, variables);
411             }
412             else if (isMain && mPipeline.type == Pipeline::Type::Texture)
413             {
414                 std::vector<const TVariable *> variables;
415                 TranslatorMetalReflection *reflection =
416                     mtl::getTranslatorMetalReflection(&mCompiler);
417                 for (const TField *field : mPipelineStruct.external->fields())
418                 {
419                     const TStructure *textureEnv = field->type()->getStruct();
420                     ASSERT(textureEnv && textureEnv->fields().size() == 2);
421                     for (const TField *subfield : textureEnv->fields())
422                     {
423                         const Name name = mIdGen.createNewName({field->name(), subfield->name()});
424                         TType &type     = *new TType(*subfield->type());
425                         ASSERT(!type.isArray());
426                         type.makeArrays(field->type()->getArraySizes());
427                         auto *var =
428                             new TVariable(&mSymbolTable, name.rawName(), &type, name.symbolType());
429                         variables.push_back(var);
430                         reflection->addOriginalName(var->uniqueId().get(), field->name().data());
431                     }
432                 }
433                 newFunc = &CloneFunctionAndAppendParams(mSymbolTable, nullptr, func, variables);
434             }
435             else if (isMain && mPipeline.type == Pipeline::Type::InstanceId)
436             {
437                 Name name = mPipeline.getStructInstanceName(Pipeline::Variant::Modified);
438                 auto *var = new TVariable(&mSymbolTable, name.rawName(),
439                                           new TType(TBasicType::EbtUInt), name.symbolType());
440                 newFunc   = &CloneFunctionAndPrependParam(mSymbolTable, nullptr, func, *var);
441                 mPipelineMainLocalVar.external = var;
442             }
443             else if (isMain && mPipeline.alwaysRequiresLocalVariableDeclarationInMain())
444             {
445                 ASSERT(mPipelineMainLocalVar.isTotallyFull());
446                 newFunc = &func;
447             }
448             else
449             {
450                 const TVariable *var;
451                 AddressSpace addressSpace;
452 
453                 if (isMain && !mPipelineMainLocalVar.isUniform())
454                 {
455                     var = &CreateInstanceVariable(
456                         mSymbolTable, *mPipelineStruct.external,
457                         mPipeline.getStructInstanceName(Pipeline::Variant::Modified));
458                     addressSpace = mPipeline.externalAddressSpace();
459                 }
460                 else
461                 {
462                     if (mPipeline.type == Pipeline::Type::UniformBuffer)
463                     {
464                         TranslatorMetalReflection *reflection =
465                             ((sh::TranslatorMetalDirect *)&mCompiler)
466                                 ->getTranslatorMetalReflection();
467                         // TODO: need more checks to make sure they line up? Could be reordered?
468                         ASSERT(mPipelineStruct.external->fields().size() ==
469                                mPipelineStruct.internal->fields().size());
470                         for (size_t i = 0; i < mPipelineStruct.external->fields().size(); i++)
471                         {
472                             const TField *externalField = mPipelineStruct.external->fields()[i];
473                             const TField *internalField = mPipelineStruct.internal->fields()[i];
474                             const TType &externalType   = *externalField->type();
475                             const TType &internalType   = *internalField->type();
476                             ASSERT(externalType.getBasicType() == internalType.getBasicType());
477                             if (externalType.getBasicType() == TBasicType::EbtStruct)
478                             {
479                                 const TStructure *externalEnv = externalType.getStruct();
480                                 const TStructure *internalEnv = internalType.getStruct();
481                                 const std::string internalName =
482                                     reflection->getOriginalName(internalEnv->uniqueId().get());
483                                 reflection->addOriginalName(externalEnv->uniqueId().get(),
484                                                             internalName);
485                             }
486                         }
487                     }
488                     var = &CreateInstanceVariable(
489                         mSymbolTable, *mPipelineStruct.internal,
490                         mPipeline.getStructInstanceName(Pipeline::Variant::Original));
491                     addressSpace = mPipelineMainLocalVar.isUniform()
492                                        ? mPipeline.externalAddressSpace()
493                                        : AddressSpace::Thread;
494                 }
495 
496                 bool markAsReference = true;
497                 if (isMain)
498                 {
499                     switch (mPipeline.type)
500                     {
501                         case Pipeline::Type::VertexIn:
502                         case Pipeline::Type::FragmentIn:
503                             markAsReference = false;
504                             break;
505 
506                         default:
507                             break;
508                     }
509                 }
510 
511                 if (markAsReference)
512                 {
513                     mSymbolEnv.markAsReference(*var, addressSpace);
514                 }
515 
516                 newFunc = &CloneFunctionAndPrependParam(mSymbolTable, nullptr, func, *var);
517             }
518 
519             mFuncMap[&func]   = newFunc;
520             mFuncMap[newFunc] = newFunc;
521         }
522         else
523         {
524             newFunc = it->second;
525         }
526 
527         return *newFunc;
528     }
529 
createUpdatedFunctionPrototype(TIntermFunctionPrototype & funcProtoNode)530     TIntermFunctionPrototype *createUpdatedFunctionPrototype(
531         TIntermFunctionPrototype &funcProtoNode)
532     {
533         const TFunction &func = *funcProtoNode.getFunction();
534         if (!isOriginalPipelineFunction(func) && !isUpdatedPipelineFunction(func))
535         {
536             return nullptr;
537         }
538         const TFunction &newFunc = getUpdatedFunction(func);
539         return new TIntermFunctionPrototype(&newFunc);
540     }
541 };
542 
543 class UpdatePipelineFunctions : private TIntermRebuild
544 {
545   private:
546     const Pipeline &mPipeline;
547     const PipelineScoped<TStructure> mPipelineStruct;
548     PipelineScoped<TVariable> &mPipelineMainLocalVar;
549     SymbolEnv &mSymbolEnv;
550     PipelineFunctionEnv mEnv;
551     const TFunction *mFuncOriginalToModified;
552     const TFunction *mFuncModifiedToOriginal;
553 
554   public:
ThreadPipeline(TCompiler & compiler,TIntermBlock & root,const Pipeline & pipeline,const std::unordered_set<const TFunction * > & pipelineFunctions,PipelineScoped<TStructure> pipelineStruct,PipelineScoped<TVariable> & pipelineMainLocalVar,IdGen & idGen,SymbolEnv & symbolEnv,const TFunction * funcOriginalToModified,const TFunction * funcModifiedToOriginal)555     static bool ThreadPipeline(TCompiler &compiler,
556                                TIntermBlock &root,
557                                const Pipeline &pipeline,
558                                const std::unordered_set<const TFunction *> &pipelineFunctions,
559                                PipelineScoped<TStructure> pipelineStruct,
560                                PipelineScoped<TVariable> &pipelineMainLocalVar,
561                                IdGen &idGen,
562                                SymbolEnv &symbolEnv,
563                                const TFunction *funcOriginalToModified,
564                                const TFunction *funcModifiedToOriginal)
565     {
566         UpdatePipelineFunctions self(compiler, pipeline, pipelineFunctions, pipelineStruct,
567                                      pipelineMainLocalVar, idGen, symbolEnv, funcOriginalToModified,
568                                      funcModifiedToOriginal);
569         if (!self.rebuildRoot(root))
570         {
571             return false;
572         }
573         return true;
574     }
575 
576   private:
UpdatePipelineFunctions(TCompiler & compiler,const Pipeline & pipeline,const std::unordered_set<const TFunction * > & pipelineFunctions,PipelineScoped<TStructure> pipelineStruct,PipelineScoped<TVariable> & pipelineMainLocalVar,IdGen & idGen,SymbolEnv & symbolEnv,const TFunction * funcOriginalToModified,const TFunction * funcModifiedToOriginal)577     UpdatePipelineFunctions(TCompiler &compiler,
578                             const Pipeline &pipeline,
579                             const std::unordered_set<const TFunction *> &pipelineFunctions,
580                             PipelineScoped<TStructure> pipelineStruct,
581                             PipelineScoped<TVariable> &pipelineMainLocalVar,
582                             IdGen &idGen,
583                             SymbolEnv &symbolEnv,
584                             const TFunction *funcOriginalToModified,
585                             const TFunction *funcModifiedToOriginal)
586         : TIntermRebuild(compiler, false, true),
587           mPipeline(pipeline),
588           mPipelineStruct(pipelineStruct),
589           mPipelineMainLocalVar(pipelineMainLocalVar),
590           mSymbolEnv(symbolEnv),
591           mEnv(compiler,
592                symbolEnv,
593                idGen,
594                pipeline,
595                pipelineFunctions,
596                pipelineStruct,
597                mPipelineMainLocalVar),
598           mFuncOriginalToModified(funcOriginalToModified),
599           mFuncModifiedToOriginal(funcModifiedToOriginal)
600     {
601         ASSERT(mPipelineStruct.isTotallyFull());
602     }
603 
getInternalPipelineVariable(const TFunction & pipelineFunc)604     const TVariable &getInternalPipelineVariable(const TFunction &pipelineFunc)
605     {
606         if (pipelineFunc.isMain() && (mPipeline.alwaysRequiresLocalVariableDeclarationInMain() ||
607                                       !mPipelineMainLocalVar.isUniform()))
608         {
609             ASSERT(mPipelineMainLocalVar.internal);
610             return *mPipelineMainLocalVar.internal;
611         }
612         else
613         {
614             ASSERT(pipelineFunc.getParamCount() > 0);
615             return *pipelineFunc.getParam(0);
616         }
617     }
618 
getExternalPipelineVariable(const TFunction & mainFunc)619     const TVariable &getExternalPipelineVariable(const TFunction &mainFunc)
620     {
621         ASSERT(mainFunc.isMain());
622         if (mPipelineMainLocalVar.external)
623         {
624             return *mPipelineMainLocalVar.external;
625         }
626         else
627         {
628             ASSERT(mainFunc.getParamCount() > 0);
629             return *mainFunc.getParam(0);
630         }
631     }
632 
visitAggregatePost(TIntermAggregate & callNode)633     PostResult visitAggregatePost(TIntermAggregate &callNode) override
634     {
635         if (callNode.isConstructor())
636         {
637             return callNode;
638         }
639         else
640         {
641             const TFunction &oldCalledFunc = *callNode.getFunction();
642             if (!mEnv.isOriginalPipelineFunction(oldCalledFunc))
643             {
644                 return callNode;
645             }
646             const TFunction &newCalledFunc = mEnv.getUpdatedFunction(oldCalledFunc);
647 
648             const TFunction *oldOwnerFunc = getParentFunction();
649             ASSERT(oldOwnerFunc);
650             const TFunction &newOwnerFunc = mEnv.getUpdatedFunction(*oldOwnerFunc);
651 
652             return *TIntermAggregate::CreateFunctionCall(
653                 newCalledFunc, &CloneSequenceAndPrepend(
654                                    *callNode.getSequence(),
655                                    *new TIntermSymbol(&getInternalPipelineVariable(newOwnerFunc))));
656         }
657     }
658 
visitFunctionPrototypePost(TIntermFunctionPrototype & funcProtoNode)659     PostResult visitFunctionPrototypePost(TIntermFunctionPrototype &funcProtoNode) override
660     {
661         TIntermFunctionPrototype *newFuncProtoNode =
662             mEnv.createUpdatedFunctionPrototype(funcProtoNode);
663         if (newFuncProtoNode == nullptr)
664         {
665             return funcProtoNode;
666         }
667         return *newFuncProtoNode;
668     }
669 
visitFunctionDefinitionPost(TIntermFunctionDefinition & funcDefNode)670     PostResult visitFunctionDefinitionPost(TIntermFunctionDefinition &funcDefNode) override
671     {
672         if (funcDefNode.getFunction()->isMain())
673         {
674             return visitMain(funcDefNode);
675         }
676         else
677         {
678             return visitNonMain(funcDefNode);
679         }
680     }
681 
visitNonMain(TIntermFunctionDefinition & funcDefNode)682     TIntermNode &visitNonMain(TIntermFunctionDefinition &funcDefNode)
683     {
684         TIntermFunctionPrototype &funcProtoNode = *funcDefNode.getFunctionPrototype();
685         ASSERT(!funcProtoNode.getFunction()->isMain());
686 
687         TIntermFunctionPrototype *newFuncProtoNode =
688             mEnv.createUpdatedFunctionPrototype(funcProtoNode);
689         if (newFuncProtoNode == nullptr)
690         {
691             return funcDefNode;
692         }
693 
694         const TFunction &func = *newFuncProtoNode->getFunction();
695         ASSERT(!func.isMain());
696 
697         TIntermBlock *body = funcDefNode.getBody();
698 
699         return *new TIntermFunctionDefinition(newFuncProtoNode, body);
700     }
701 
visitMain(TIntermFunctionDefinition & funcDefNode)702     TIntermNode &visitMain(TIntermFunctionDefinition &funcDefNode)
703     {
704         TIntermFunctionPrototype &funcProtoNode = *funcDefNode.getFunctionPrototype();
705         ASSERT(funcProtoNode.getFunction()->isMain());
706 
707         TIntermFunctionPrototype *newFuncProtoNode =
708             mEnv.createUpdatedFunctionPrototype(funcProtoNode);
709         if (newFuncProtoNode == nullptr)
710         {
711             return funcDefNode;
712         }
713 
714         const TFunction &func = *newFuncProtoNode->getFunction();
715         ASSERT(func.isMain());
716 
717         auto callModifiedToOriginal = [&](TIntermBlock &body) {
718             ASSERT(mPipelineMainLocalVar.internal);
719             if (!mPipeline.isPipelineOut())
720             {
721                 ASSERT(mFuncModifiedToOriginal);
722                 auto *m = new TIntermSymbol(&getExternalPipelineVariable(func));
723                 auto *o = new TIntermSymbol(mPipelineMainLocalVar.internal);
724                 body.appendStatement(TIntermAggregate::CreateFunctionCall(
725                     *mFuncModifiedToOriginal, new TIntermSequence{m, o}));
726             }
727         };
728 
729         auto callOriginalToModified = [&](TIntermBlock &body) {
730             ASSERT(mPipelineMainLocalVar.internal);
731             if (mPipeline.isPipelineOut())
732             {
733                 ASSERT(mFuncOriginalToModified);
734                 auto *o = new TIntermSymbol(mPipelineMainLocalVar.internal);
735                 auto *m = new TIntermSymbol(&getExternalPipelineVariable(func));
736                 body.appendStatement(TIntermAggregate::CreateFunctionCall(
737                     *mFuncOriginalToModified, new TIntermSequence{o, m}));
738             }
739         };
740 
741         TIntermBlock *body = funcDefNode.getBody();
742 
743         if (mPipeline.alwaysRequiresLocalVariableDeclarationInMain())
744         {
745             ASSERT(mPipelineMainLocalVar.isTotallyFull());
746 
747             auto *newBody = new TIntermBlock();
748             newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.internal});
749 
750             if (mPipeline.type == Pipeline::Type::InvocationVertexGlobals ||
751                 mPipeline.type == Pipeline::Type::InvocationFragmentGlobals)
752             {
753                 // Populate struct instance with references to global pipeline variables.
754                 for (const TField *field : mPipelineStruct.external->fields())
755                 {
756                     auto *var        = new TVariable(&mSymbolTable, field->name(), field->type(),
757                                               field->symbolType());
758                     auto *symbol     = new TIntermSymbol(var);
759                     auto &accessNode = AccessField(*mPipelineMainLocalVar.internal, var->name());
760                     auto *assignNode = new TIntermBinary(TOperator::EOpAssign, &accessNode, symbol);
761                     newBody->appendStatement(assignNode);
762                 }
763             }
764             else if (mPipeline.type == Pipeline::Type::Texture)
765             {
766                 const TFieldList &fields = mPipelineStruct.external->fields();
767 
768                 ASSERT(func.getParamCount() >= 2 * fields.size());
769                 size_t paramIndex = func.getParamCount() - 2 * fields.size();
770 
771                 for (const TField *field : fields)
772                 {
773                     const TVariable &textureParam = *func.getParam(paramIndex++);
774                     const TVariable &samplerParam = *func.getParam(paramIndex++);
775 
776                     auto go = [&](TIntermTyped &env, const int *index) {
777                         TIntermTyped &textureField = AccessField(
778                             AccessIndex(*env.deepCopy(), index), ImmutableString("texture"));
779                         TIntermTyped &samplerField = AccessField(
780                             AccessIndex(*env.deepCopy(), index), ImmutableString("sampler"));
781 
782                         auto mkAssign = [&](TIntermTyped &field, const TVariable &param) {
783                             return new TIntermBinary(TOperator::EOpAssign, &field,
784                                                      &mSymbolEnv.callFunctionOverload(
785                                                          Name("addressof"), field.getType(),
786                                                          *new TIntermSequence{&AccessIndex(
787                                                              *new TIntermSymbol(&param), index)}));
788                         };
789 
790                         newBody->appendStatement(mkAssign(textureField, textureParam));
791                         newBody->appendStatement(mkAssign(samplerField, samplerParam));
792                     };
793 
794                     TIntermTyped &env = AccessField(*mPipelineMainLocalVar.internal, field->name());
795                     const TType &envType = env.getType();
796 
797                     if (envType.isArray())
798                     {
799                         ASSERT(!envType.isArrayOfArrays());
800                         const auto n = static_cast<int>(envType.getArraySizeProduct());
801                         for (int i = 0; i < n; ++i)
802                         {
803                             go(env, &i);
804                         }
805                     }
806                     else
807                     {
808                         go(env, nullptr);
809                     }
810                 }
811             }
812             else if (mPipeline.type == Pipeline::Type::InstanceId)
813             {
814                 newBody->appendStatement(new TIntermBinary(
815                     TOperator::EOpAssign,
816                     &AccessFieldByIndex(*new TIntermSymbol(&getInternalPipelineVariable(func)), 0),
817                     &AsType(mSymbolEnv, *new TType(TBasicType::EbtInt),
818                             *new TIntermSymbol(&getExternalPipelineVariable(func)))));
819             }
820             else if (!mPipelineMainLocalVar.isUniform())
821             {
822                 newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.external});
823                 callModifiedToOriginal(*newBody);
824             }
825 
826             newBody->appendStatement(body);
827 
828             if (!mPipelineMainLocalVar.isUniform())
829             {
830                 callOriginalToModified(*newBody);
831             }
832 
833             if (mPipeline.isPipelineOut())
834             {
835                 newBody->appendStatement(new TIntermBranch(
836                     TOperator::EOpReturn, new TIntermSymbol(mPipelineMainLocalVar.external)));
837             }
838 
839             body = newBody;
840         }
841         else if (!mPipelineMainLocalVar.isUniform())
842         {
843             ASSERT(!mPipelineMainLocalVar.external);
844             ASSERT(mPipelineMainLocalVar.internal);
845 
846             auto *newBody = new TIntermBlock();
847             newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.internal});
848             callModifiedToOriginal(*newBody);
849             newBody->appendStatement(body);
850             callOriginalToModified(*newBody);
851             body = newBody;
852         }
853 
854         return *new TIntermFunctionDefinition(newFuncProtoNode, body);
855     }
856 };
857 
858 ////////////////////////////////////////////////////////////////////////////////
859 
UpdatePipelineSymbols(Pipeline::Type pipelineType,TCompiler & compiler,TIntermBlock & root,SymbolEnv & symbolEnv,const VariableSet & pipelineVariables,PipelineScoped<TVariable> pipelineMainLocalVar)860 bool UpdatePipelineSymbols(Pipeline::Type pipelineType,
861                            TCompiler &compiler,
862                            TIntermBlock &root,
863                            SymbolEnv &symbolEnv,
864                            const VariableSet &pipelineVariables,
865                            PipelineScoped<TVariable> pipelineMainLocalVar)
866 {
867     auto map = [&](const TFunction *owner, TIntermSymbol &symbol) -> TIntermNode & {
868         if (!owner)
869             return symbol;
870         const TVariable &var = symbol.variable();
871         if (pipelineVariables.find(&var) == pipelineVariables.end())
872         {
873             return symbol;
874         }
875         const TVariable *structInstanceVar;
876         if (owner->isMain())
877         {
878             ASSERT(pipelineMainLocalVar.internal);
879             structInstanceVar = pipelineMainLocalVar.internal;
880         }
881         else
882         {
883             ASSERT(owner->getParamCount() > 0);
884             structInstanceVar = owner->getParam(0);
885         }
886         ASSERT(structInstanceVar);
887         return AccessField(*structInstanceVar, var.name());
888     };
889     return MapSymbols(compiler, root, map);
890 }
891 
892 ////////////////////////////////////////////////////////////////////////////////
893 
RewritePipeline(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,const std::vector<sh::ShaderVariable> * variableInfo,PipelineScoped<TStructure> & outStruct)894 bool RewritePipeline(TCompiler &compiler,
895                      TIntermBlock &root,
896                      IdGen &idGen,
897                      const Pipeline &pipeline,
898                      SymbolEnv &symbolEnv,
899                      const std::vector<sh::ShaderVariable> *variableInfo,
900                      PipelineScoped<TStructure> &outStruct)
901 {
902     ASSERT(outStruct.isTotallyEmpty());
903 
904     TSymbolTable &symbolTable = compiler.getSymbolTable();
905 
906     PipelineStructInfo psi;
907     if (!GeneratePipelineStruct::Exec(psi, compiler, root, idGen, pipeline, symbolEnv,
908                                       variableInfo))
909     {
910         return false;
911     }
912 
913     if (psi.isEmpty())
914     {
915         return true;
916     }
917 
918     const auto pipelineFunctions = DiscoverDependentFunctions(root, [&](const TVariable &var) {
919         return psi.pipelineVariables.find(&var) != psi.pipelineVariables.end();
920     });
921 
922     auto pipelineMainLocalVar =
923         CreatePipelineMainLocalVar(symbolTable, pipeline, psi.pipelineStruct);
924 
925     if (!UpdatePipelineFunctions::ThreadPipeline(
926             compiler, root, pipeline, pipelineFunctions, psi.pipelineStruct, pipelineMainLocalVar,
927             idGen, symbolEnv, psi.funcOriginalToModified, psi.funcModifiedToOriginal))
928     {
929         return false;
930     }
931 
932     if (!pipeline.globalInstanceVar)
933     {
934         if (!UpdatePipelineSymbols(pipeline.type, compiler, root, symbolEnv, psi.pipelineVariables,
935                                    pipelineMainLocalVar))
936         {
937             return false;
938         }
939     }
940 
941     if (!PruneNoOps(&compiler, &root, &compiler.getSymbolTable()))
942     {
943         return false;
944     }
945 
946     outStruct = psi.pipelineStruct;
947     return true;
948 }
949 
950 }  // anonymous namespace
951 
RewritePipelines(TCompiler & compiler,TIntermBlock & root,const std::vector<sh::ShaderVariable> & inputVaryings,const std::vector<sh::ShaderVariable> & outputVaryings,IdGen & idGen,DriverUniform & angleUniformsGlobalInstanceVar,SymbolEnv & symbolEnv,PipelineStructs & outStructs)952 bool sh::RewritePipelines(TCompiler &compiler,
953                           TIntermBlock &root,
954                           const std::vector<sh::ShaderVariable> &inputVaryings,
955                           const std::vector<sh::ShaderVariable> &outputVaryings,
956                           IdGen &idGen,
957                           DriverUniform &angleUniformsGlobalInstanceVar,
958                           SymbolEnv &symbolEnv,
959                           PipelineStructs &outStructs)
960 {
961     struct Info
962     {
963         Pipeline::Type pipelineType;
964         PipelineScoped<TStructure> &outStruct;
965         const TVariable *globalInstanceVar;
966         const std::vector<sh::ShaderVariable> *variableInfo;
967     };
968 
969     Info infos[] = {
970         {Pipeline::Type::InstanceId, outStructs.instanceId, nullptr, nullptr},
971         {Pipeline::Type::Texture, outStructs.texture, nullptr, nullptr},
972         {Pipeline::Type::NonConstantGlobals, outStructs.nonConstantGlobals, nullptr, nullptr},
973         {Pipeline::Type::AngleUniforms, outStructs.angleUniforms,
974          angleUniformsGlobalInstanceVar.getDriverUniformsVariable(), nullptr},
975         {Pipeline::Type::UserUniforms, outStructs.userUniforms, nullptr, nullptr},
976         {Pipeline::Type::VertexIn, outStructs.vertexIn, nullptr, &inputVaryings},
977         {Pipeline::Type::VertexOut, outStructs.vertexOut, nullptr, &outputVaryings},
978         {Pipeline::Type::FragmentIn, outStructs.fragmentIn, nullptr, &inputVaryings},
979         {Pipeline::Type::FragmentOut, outStructs.fragmentOut, nullptr, &outputVaryings},
980         {Pipeline::Type::InvocationVertexGlobals, outStructs.invocationVertexGlobals, nullptr,
981          nullptr},
982         {Pipeline::Type::InvocationFragmentGlobals, outStructs.invocationFragmentGlobals, nullptr,
983          &inputVaryings},
984         {Pipeline::Type::UniformBuffer, outStructs.uniformBuffers, nullptr, nullptr},
985     };
986 
987     for (Info &info : infos)
988     {
989         Pipeline pipeline{info.pipelineType, info.globalInstanceVar};
990         if (!RewritePipeline(compiler, root, idGen, pipeline, symbolEnv, info.variableInfo,
991                              info.outStruct))
992         {
993             return false;
994         }
995     }
996 
997     return true;
998 }
999