• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <cstring>
8 #include <unordered_map>
9 #include <unordered_set>
10 
11 #include "compiler/translator/msl/AstHelpers.h"
12 #include "compiler/translator/msl/DiscoverDependentFunctions.h"
13 #include "compiler/translator/msl/IdGen.h"
14 #include "compiler/translator/msl/IntermRebuild.h"
15 #include "compiler/translator/msl/MapSymbols.h"
16 #include "compiler/translator/msl/Pipeline.h"
17 #include "compiler/translator/msl/RewritePipelines.h"
18 #include "compiler/translator/msl/SymbolEnv.h"
19 #include "compiler/translator/msl/TranslatorMSL.h"
20 #include "compiler/translator/tree_ops/PruneNoOps.h"
21 #include "compiler/translator/tree_util/DriverUniform.h"
22 #include "compiler/translator/tree_util/FindMain.h"
23 #include "compiler/translator/tree_util/IntermTraverse.h"
24 using namespace sh;
25 
26 ////////////////////////////////////////////////////////////////////////////////
27 
28 namespace
29 {
30 
IsVariableInvariant(const std::vector<sh::ShaderVariable> & mVars,const ImmutableString & name)31 bool IsVariableInvariant(const std::vector<sh::ShaderVariable> &mVars, const ImmutableString &name)
32 {
33     for (const auto &var : mVars)
34     {
35         if (name == var.name)
36         {
37             return var.isInvariant;
38         }
39     }
40     // TODO(kpidington): this should be UNREACHABLE() but isn't because the translator generates
41     // declarations to unused built-in variables.
42     return false;
43 }
44 
// Hash set of variable pointers (membership is by pointer value).
45 using VariableSet  = std::unordered_set<const TVariable *>;
// Sequence of variable pointers.
46 using VariableList = std::vector<const TVariable *>;
47 
48 ////////////////////////////////////////////////////////////////////////////////
49 
50 struct PipelineStructInfo
51 {
52     VariableSet pipelineVariables;
53     PipelineScoped<TStructure> pipelineStruct;
54     const TFunction *funcOriginalToModified = nullptr;
55     const TFunction *funcModifiedToOriginal = nullptr;
56 
isEmpty__anonb33832dd0111::PipelineStructInfo57     bool isEmpty() const
58     {
59         if (pipelineStruct.isTotallyEmpty())
60         {
61             ASSERT(pipelineVariables.empty());
62             return true;
63         }
64         else
65         {
66             ASSERT(pipelineStruct.isTotallyFull());
67             ASSERT(!pipelineVariables.empty());
68             return false;
69         }
70     }
71 };
72 
73 class GeneratePipelineStruct : private TIntermRebuild
74 {
75   private:
76     const Pipeline &mPipeline;
77     SymbolEnv &mSymbolEnv;
78     const std::vector<sh::ShaderVariable> *mVariableInfos;
79     VariableList mPipelineVariableList;
80     IdGen &mIdGen;
81     PipelineStructInfo mInfo;
82 
83   public:
Exec(PipelineStructInfo & out,TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,const std::vector<sh::ShaderVariable> * variableInfos)84     static bool Exec(PipelineStructInfo &out,
85                      TCompiler &compiler,
86                      TIntermBlock &root,
87                      IdGen &idGen,
88                      const Pipeline &pipeline,
89                      SymbolEnv &symbolEnv,
90                      const std::vector<sh::ShaderVariable> *variableInfos)
91     {
92         GeneratePipelineStruct self(compiler, idGen, pipeline, symbolEnv, variableInfos);
93         if (!self.exec(root))
94         {
95             return false;
96         }
97         out = self.mInfo;
98         return true;
99     }
100 
101   private:
GeneratePipelineStruct(TCompiler & compiler,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,const std::vector<sh::ShaderVariable> * variableInfos)102     GeneratePipelineStruct(TCompiler &compiler,
103                            IdGen &idGen,
104                            const Pipeline &pipeline,
105                            SymbolEnv &symbolEnv,
106                            const std::vector<sh::ShaderVariable> *variableInfos)
107         : TIntermRebuild(compiler, true, true),
108           mPipeline(pipeline),
109           mSymbolEnv(symbolEnv),
110           mVariableInfos(variableInfos),
111           mIdGen(idGen)
112     {}
113 
exec(TIntermBlock & root)114     bool exec(TIntermBlock &root)
115     {
116         if (!rebuildRoot(root))
117         {
118             return false;
119         }
120 
121         if (mInfo.pipelineVariables.empty())
122         {
123             return true;
124         }
125 
126         TIntermSequence seq;
127 
128         const TStructure &pipelineStruct = [&]() -> const TStructure & {
129             if (mPipeline.globalInstanceVar)
130             {
131                 return *mPipeline.globalInstanceVar->getType().getStruct();
132             }
133             else
134             {
135                 return createInternalPipelineStruct(root, seq);
136             }
137         }();
138 
139         ModifiedStructMachineries modifiedMachineries;
140         const bool isUBO     = mPipeline.type == Pipeline::Type::UniformBuffer;
141         const bool isUniform = mPipeline.type == Pipeline::Type::UniformBuffer ||
142                                mPipeline.type == Pipeline::Type::UserUniforms;
143         const bool modified = TryCreateModifiedStruct(
144             mCompiler, mSymbolEnv, mIdGen, mPipeline.externalStructModifyConfig(), pipelineStruct,
145             mPipeline.getStructTypeName(Pipeline::Variant::Modified), modifiedMachineries, isUBO,
146             !isUniform);
147 
148         if (modified)
149         {
150             ASSERT(mPipeline.type != Pipeline::Type::Texture);
151             ASSERT(mPipeline.type == Pipeline::Type::AngleUniforms ||
152                    !mPipeline.globalInstanceVar);  // This shouldn't happen by construction.
153 
154             auto getFunction = [](sh::TIntermFunctionDefinition *funcDecl) {
155                 return funcDecl ? funcDecl->getFunction() : nullptr;
156             };
157 
158             const size_t size = modifiedMachineries.size();
159             ASSERT(size > 0);
160             for (size_t i = 0; i < size; ++i)
161             {
162                 const ModifiedStructMachinery &machinery = modifiedMachineries.at(i);
163                 ASSERT(machinery.modifiedStruct);
164 
165                 seq.push_back(new TIntermDeclaration{
166                     &CreateStructTypeVariable(mSymbolTable, *machinery.modifiedStruct)});
167 
168                 if (mPipeline.isPipelineOut())
169                 {
170                     ASSERT(machinery.funcOriginalToModified);
171                     ASSERT(!machinery.funcModifiedToOriginal);
172                     seq.push_back(machinery.funcOriginalToModified);
173                 }
174                 else
175                 {
176                     ASSERT(machinery.funcModifiedToOriginal);
177                     ASSERT(!machinery.funcOriginalToModified);
178                     seq.push_back(machinery.funcModifiedToOriginal);
179                 }
180 
181                 if (i == size - 1)
182                 {
183                     mInfo.funcOriginalToModified = getFunction(machinery.funcOriginalToModified);
184                     mInfo.funcModifiedToOriginal = getFunction(machinery.funcModifiedToOriginal);
185 
186                     mInfo.pipelineStruct.internal = &pipelineStruct;
187                     mInfo.pipelineStruct.external =
188                         modified ? machinery.modifiedStruct : &pipelineStruct;
189                 }
190             }
191         }
192         else
193         {
194             mInfo.pipelineStruct.internal = &pipelineStruct;
195             mInfo.pipelineStruct.external = &pipelineStruct;
196         }
197 
198         root.insertChildNodes(FindMainIndex(&root), seq);
199 
200         return true;
201     }
202 
203   private:
visitFunctionDefinitionPre(TIntermFunctionDefinition & node)204     PreResult visitFunctionDefinitionPre(TIntermFunctionDefinition &node) override
205     {
206         return {node, VisitBits::Neither};
207     }
visitDeclarationPost(TIntermDeclaration & declNode)208     PostResult visitDeclarationPost(TIntermDeclaration &declNode) override
209     {
210         Declaration decl     = ViewDeclaration(declNode);
211         const TVariable &var = decl.symbol.variable();
212         if (mPipeline.uses(var))
213         {
214             ASSERT(mInfo.pipelineVariables.find(&var) == mInfo.pipelineVariables.end());
215             mInfo.pipelineVariables.insert(&var);
216             mPipelineVariableList.push_back(&var);
217             return nullptr;
218         }
219 
220         return declNode;
221     }
222 
createInternalPipelineStruct(TIntermBlock & root,TIntermSequence & outDeclSeq)223     const TStructure &createInternalPipelineStruct(TIntermBlock &root, TIntermSequence &outDeclSeq)
224     {
225         auto &fields = *new TFieldList();
226 
227         switch (mPipeline.type)
228         {
229             case Pipeline::Type::Texture:
230             {
231                 for (const TVariable *var : mPipelineVariableList)
232                 {
233                     const TType &varType         = var->getType();
234                     const TBasicType samplerType = varType.getBasicType();
235 
236                     const TStructure &textureEnv = mSymbolEnv.getTextureEnv(samplerType);
237                     auto *textureEnvType         = new TType(&textureEnv, false);
238                     if (varType.isArray())
239                     {
240                         textureEnvType->makeArrays(varType.getArraySizes());
241                     }
242 
243                     fields.push_back(
244                         new TField(textureEnvType, var->name(), kNoSourceLoc, var->symbolType()));
245                 }
246             }
247             break;
248 
249             case Pipeline::Type::Image:
250             {
251                 for (const TVariable *var : mPipelineVariableList)
252                 {
253                     auto &type  = CloneType(var->getType());
254                     auto *field = new TField(&type, var->name(), kNoSourceLoc, var->symbolType());
255                     fields.push_back(field);
256                 }
257             }
258             break;
259 
260             case Pipeline::Type::UniformBuffer:
261             {
262                 for (const TVariable *var : mPipelineVariableList)
263                 {
264                     auto &type  = CloneType(var->getType());
265                     auto *field = new TField(&type, var->name(), kNoSourceLoc, var->symbolType());
266                     mSymbolEnv.markAsPointer(*field, AddressSpace::Constant);
267                     mSymbolEnv.markAsUBO(*field);
268                     mSymbolEnv.markAsPointer(*var, AddressSpace::Constant);
269                     fields.push_back(field);
270                 }
271             }
272             break;
273             default:
274             {
275                 for (const TVariable *var : mPipelineVariableList)
276                 {
277                     auto &type = CloneType(var->getType());
278                     if (mVariableInfos && IsVariableInvariant(*mVariableInfos, var->name()))
279                     {
280                         type.setInvariant(true);
281                     }
282                     auto *field = new TField(&type, var->name(), kNoSourceLoc, var->symbolType());
283                     fields.push_back(field);
284                 }
285             }
286             break;
287         }
288 
289         Name pipelineStructName = mPipeline.getStructTypeName(Pipeline::Variant::Original);
290         auto &s = *new TStructure(&mSymbolTable, pipelineStructName.rawName(), &fields,
291                                   pipelineStructName.symbolType());
292 
293         outDeclSeq.push_back(new TIntermDeclaration{&CreateStructTypeVariable(mSymbolTable, s)});
294 
295         return s;
296     }
297 };
298 
299 ////////////////////////////////////////////////////////////////////////////////
300 
CreatePipelineMainLocalVar(TSymbolTable & symbolTable,const Pipeline & pipeline,PipelineScoped<TStructure> pipelineStruct)301 PipelineScoped<TVariable> CreatePipelineMainLocalVar(TSymbolTable &symbolTable,
302                                                      const Pipeline &pipeline,
303                                                      PipelineScoped<TStructure> pipelineStruct)
304 {
305     ASSERT(pipelineStruct.isTotallyFull());
306 
307     PipelineScoped<TVariable> pipelineMainLocalVar;
308 
309     auto populateExternalMainLocalVar = [&]() {
310         ASSERT(!pipelineMainLocalVar.external);
311         pipelineMainLocalVar.external = &CreateInstanceVariable(
312             symbolTable, *pipelineStruct.external,
313             pipeline.getStructInstanceName(pipelineStruct.isUniform()
314                                                ? Pipeline::Variant::Original
315                                                : Pipeline::Variant::Modified));
316     };
317 
318     auto populateDistinctInternalMainLocalVar = [&]() {
319         ASSERT(!pipelineMainLocalVar.internal);
320         pipelineMainLocalVar.internal =
321             &CreateInstanceVariable(symbolTable, *pipelineStruct.internal,
322                                     pipeline.getStructInstanceName(Pipeline::Variant::Original));
323     };
324 
325     if (pipeline.type == Pipeline::Type::InstanceId)
326     {
327         populateDistinctInternalMainLocalVar();
328     }
329     else if (pipeline.alwaysRequiresLocalVariableDeclarationInMain())
330     {
331         populateExternalMainLocalVar();
332 
333         if (pipelineStruct.isUniform())
334         {
335             pipelineMainLocalVar.internal = pipelineMainLocalVar.external;
336         }
337         else
338         {
339             populateDistinctInternalMainLocalVar();
340         }
341     }
342     else if (!pipelineStruct.isUniform())
343     {
344         populateDistinctInternalMainLocalVar();
345     }
346 
347     return pipelineMainLocalVar;
348 }
349 
350 class PipelineFunctionEnv
351 {
352   private:
353     TCompiler &mCompiler;
354     SymbolEnv &mSymbolEnv;
355     TSymbolTable &mSymbolTable;
356     IdGen &mIdGen;
357     const Pipeline &mPipeline;
358     const std::unordered_set<const TFunction *> &mPipelineFunctions;
359     const PipelineScoped<TStructure> mPipelineStruct;
360     PipelineScoped<TVariable> &mPipelineMainLocalVar;
361     size_t mFirstParamIdxInMainFn = 0;
362 
363     std::unordered_map<const TFunction *, const TFunction *> mFuncMap;
364 
365     // Optional expression with which to initialize mPipelineMainLocalVar.
366     TIntermTyped *mPipelineInitExpr = nullptr;
367 
368   public:
PipelineFunctionEnv(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen,const Pipeline & pipeline,const std::unordered_set<const TFunction * > & pipelineFunctions,PipelineScoped<TStructure> pipelineStruct,PipelineScoped<TVariable> & pipelineMainLocalVar)369     PipelineFunctionEnv(TCompiler &compiler,
370                         SymbolEnv &symbolEnv,
371                         IdGen &idGen,
372                         const Pipeline &pipeline,
373                         const std::unordered_set<const TFunction *> &pipelineFunctions,
374                         PipelineScoped<TStructure> pipelineStruct,
375                         PipelineScoped<TVariable> &pipelineMainLocalVar)
376         : mCompiler(compiler),
377           mSymbolEnv(symbolEnv),
378           mSymbolTable(symbolEnv.symbolTable()),
379           mIdGen(idGen),
380           mPipeline(pipeline),
381           mPipelineFunctions(pipelineFunctions),
382           mPipelineStruct(pipelineStruct),
383           mPipelineMainLocalVar(pipelineMainLocalVar)
384     {}
385 
isOriginalPipelineFunction(const TFunction & func) const386     bool isOriginalPipelineFunction(const TFunction &func) const
387     {
388         return mPipelineFunctions.find(&func) != mPipelineFunctions.end();
389     }
390 
isUpdatedPipelineFunction(const TFunction & func) const391     bool isUpdatedPipelineFunction(const TFunction &func) const
392     {
393         auto it = mFuncMap.find(&func);
394         if (it == mFuncMap.end())
395         {
396             return false;
397         }
398         return &func == it->second;
399     }
400 
getUpdatedFunction(const TFunction & func)401     const TFunction &getUpdatedFunction(const TFunction &func)
402     {
403         ASSERT(isOriginalPipelineFunction(func) || isUpdatedPipelineFunction(func));
404 
405         const TFunction *newFunc;
406 
407         auto it = mFuncMap.find(&func);
408         if (it == mFuncMap.end())
409         {
410             const bool isMain = func.isMain();
411             if (isMain)
412             {
413                 mFirstParamIdxInMainFn = func.getParamCount();
414             }
415 
416             if (isMain && mPipeline.isPipelineOut())
417             {
418                 ASSERT(func.getReturnType().getBasicType() == TBasicType::EbtVoid);
419                 newFunc = &CloneFunctionAndChangeReturnType(mSymbolTable, nullptr, func,
420                                                             *mPipelineStruct.external);
421                 if (mPipeline.type == Pipeline::Type::FragmentOut &&
422                     mCompiler.hasPixelLocalStorageUniforms())
423                 {
424                     // Add an input argument to main() that contains the current framebuffer
425                     // attachment values, for loading pixel local storage.
426                     TType *type = new TType(mPipelineStruct.external, true);
427                     TVariable *lastFragmentOut =
428                         new TVariable(&mSymbolTable, ImmutableString("lastFragmentOut"), type,
429                                       SymbolType::AngleInternal);
430                     newFunc = &CloneFunctionAndPrependParam(mSymbolTable, nullptr, *newFunc,
431                                                             *lastFragmentOut);
432                     // Initialize the main local variable with the current framebuffer contents.
433                     mPipelineInitExpr = new TIntermSymbol(lastFragmentOut);
434                 }
435             }
436             else if (isMain && (mPipeline.type == Pipeline::Type::InvocationVertexGlobals ||
437                                 mPipeline.type == Pipeline::Type::InvocationFragmentGlobals))
438             {
439                 std::vector<const TVariable *> variables;
440                 for (const TField *field : mPipelineStruct.external->fields())
441                 {
442                     variables.push_back(new TVariable(&mSymbolTable, field->name(), field->type(),
443                                                       field->symbolType()));
444                 }
445                 newFunc = &CloneFunctionAndAppendParams(mSymbolTable, nullptr, func, variables);
446             }
447             else if (isMain && mPipeline.type == Pipeline::Type::Texture)
448             {
449                 std::vector<const TVariable *> variables;
450                 TranslatorMetalReflection *reflection =
451                     mtl::getTranslatorMetalReflection(&mCompiler);
452                 for (const TField *field : mPipelineStruct.external->fields())
453                 {
454                     const TStructure *textureEnv = field->type()->getStruct();
455                     ASSERT(textureEnv && textureEnv->fields().size() == 2);
456                     for (const TField *subfield : textureEnv->fields())
457                     {
458                         const Name name = mIdGen.createNewName({field->name(), subfield->name()});
459                         TType &type     = *new TType(*subfield->type());
460                         ASSERT(!type.isArray());
461                         type.makeArrays(field->type()->getArraySizes());
462                         auto *var =
463                             new TVariable(&mSymbolTable, name.rawName(), &type, name.symbolType());
464                         variables.push_back(var);
465                         reflection->addOriginalName(var->uniqueId().get(), field->name().data());
466                     }
467                 }
468                 newFunc = &CloneFunctionAndAppendParams(mSymbolTable, nullptr, func, variables);
469             }
470             else if (isMain && mPipeline.type == Pipeline::Type::InstanceId)
471             {
472                 Name instanceIdName = mPipeline.getStructInstanceName(Pipeline::Variant::Modified);
473                 auto *instanceIdVar =
474                     new TVariable(&mSymbolTable, instanceIdName.rawName(),
475                                   new TType(TBasicType::EbtUInt), instanceIdName.symbolType());
476 
477                 auto *baseInstanceVar =
478                     new TVariable(&mSymbolTable, kBaseInstanceName.rawName(),
479                                   new TType(TBasicType::EbtUInt), kBaseInstanceName.symbolType());
480 
481                 newFunc = &CloneFunctionAndPrependTwoParams(mSymbolTable, nullptr, func,
482                                                             *instanceIdVar, *baseInstanceVar);
483                 mPipelineMainLocalVar.external      = instanceIdVar;
484                 mPipelineMainLocalVar.externalExtra = baseInstanceVar;
485             }
486             else if (isMain && mPipeline.alwaysRequiresLocalVariableDeclarationInMain())
487             {
488                 ASSERT(mPipelineMainLocalVar.isTotallyFull());
489                 newFunc = &func;
490             }
491             else
492             {
493                 const TVariable *var;
494                 AddressSpace addressSpace;
495 
496                 if (isMain && !mPipelineMainLocalVar.isUniform())
497                 {
498                     var = &CreateInstanceVariable(
499                         mSymbolTable, *mPipelineStruct.external,
500                         mPipeline.getStructInstanceName(Pipeline::Variant::Modified));
501                     addressSpace = mPipeline.externalAddressSpace();
502                 }
503                 else
504                 {
505                     var = &CreateInstanceVariable(
506                         mSymbolTable, *mPipelineStruct.internal,
507                         mPipeline.getStructInstanceName(Pipeline::Variant::Original));
508                     addressSpace = mPipelineMainLocalVar.isUniform()
509                                        ? mPipeline.externalAddressSpace()
510                                        : AddressSpace::Thread;
511                 }
512 
513                 bool markAsReference = true;
514                 if (isMain)
515                 {
516                     switch (mPipeline.type)
517                     {
518                         case Pipeline::Type::VertexIn:
519                         case Pipeline::Type::FragmentIn:
520                         case Pipeline::Type::Image:
521                             markAsReference = false;
522                             break;
523 
524                         default:
525                             break;
526                     }
527                 }
528 
529                 if (markAsReference)
530                 {
531                     mSymbolEnv.markAsReference(*var, addressSpace);
532                 }
533 
534                 newFunc = &CloneFunctionAndPrependParam(mSymbolTable, nullptr, func, *var);
535             }
536 
537             mFuncMap[&func]   = newFunc;
538             mFuncMap[newFunc] = newFunc;
539         }
540         else
541         {
542             newFunc = it->second;
543         }
544 
545         return *newFunc;
546     }
547 
createUpdatedFunctionPrototype(TIntermFunctionPrototype & funcProtoNode)548     TIntermFunctionPrototype *createUpdatedFunctionPrototype(
549         TIntermFunctionPrototype &funcProtoNode)
550     {
551         const TFunction &func = *funcProtoNode.getFunction();
552         if (!isOriginalPipelineFunction(func) && !isUpdatedPipelineFunction(func))
553         {
554             return nullptr;
555         }
556         const TFunction &newFunc = getUpdatedFunction(func);
557         return new TIntermFunctionPrototype(&newFunc);
558     }
559 
560     // If not null, this is the value we need to initialize the pipeline main local variable with.
getOptionalPipelineInitExpr()561     TIntermTyped *getOptionalPipelineInitExpr() { return mPipelineInitExpr; }
562 
getFirstParamIdxInMainFn() const563     size_t getFirstParamIdxInMainFn() const { return mFirstParamIdxInMainFn; }
564 };
565 
566 class UpdatePipelineFunctions : private TIntermRebuild
567 {
568   private:
569     const Pipeline &mPipeline;
570     const PipelineScoped<TStructure> mPipelineStruct;
571     PipelineScoped<TVariable> &mPipelineMainLocalVar;
572     SymbolEnv &mSymbolEnv;
573     PipelineFunctionEnv mEnv;
574     const TFunction *mFuncOriginalToModified;
575     const TFunction *mFuncModifiedToOriginal;
576 
577   public:
ThreadPipeline(TCompiler & compiler,TIntermBlock & root,const Pipeline & pipeline,const std::unordered_set<const TFunction * > & pipelineFunctions,PipelineScoped<TStructure> pipelineStruct,PipelineScoped<TVariable> & pipelineMainLocalVar,IdGen & idGen,SymbolEnv & symbolEnv,const TFunction * funcOriginalToModified,const TFunction * funcModifiedToOriginal)578     static bool ThreadPipeline(TCompiler &compiler,
579                                TIntermBlock &root,
580                                const Pipeline &pipeline,
581                                const std::unordered_set<const TFunction *> &pipelineFunctions,
582                                PipelineScoped<TStructure> pipelineStruct,
583                                PipelineScoped<TVariable> &pipelineMainLocalVar,
584                                IdGen &idGen,
585                                SymbolEnv &symbolEnv,
586                                const TFunction *funcOriginalToModified,
587                                const TFunction *funcModifiedToOriginal)
588     {
589         UpdatePipelineFunctions self(compiler, pipeline, pipelineFunctions, pipelineStruct,
590                                      pipelineMainLocalVar, idGen, symbolEnv, funcOriginalToModified,
591                                      funcModifiedToOriginal);
592         if (!self.rebuildRoot(root))
593         {
594             return false;
595         }
596         return true;
597     }
598 
599   private:
UpdatePipelineFunctions(TCompiler & compiler,const Pipeline & pipeline,const std::unordered_set<const TFunction * > & pipelineFunctions,PipelineScoped<TStructure> pipelineStruct,PipelineScoped<TVariable> & pipelineMainLocalVar,IdGen & idGen,SymbolEnv & symbolEnv,const TFunction * funcOriginalToModified,const TFunction * funcModifiedToOriginal)600     UpdatePipelineFunctions(TCompiler &compiler,
601                             const Pipeline &pipeline,
602                             const std::unordered_set<const TFunction *> &pipelineFunctions,
603                             PipelineScoped<TStructure> pipelineStruct,
604                             PipelineScoped<TVariable> &pipelineMainLocalVar,
605                             IdGen &idGen,
606                             SymbolEnv &symbolEnv,
607                             const TFunction *funcOriginalToModified,
608                             const TFunction *funcModifiedToOriginal)
609         : TIntermRebuild(compiler, false, true),
610           mPipeline(pipeline),
611           mPipelineStruct(pipelineStruct),
612           mPipelineMainLocalVar(pipelineMainLocalVar),
613           mSymbolEnv(symbolEnv),
614           mEnv(compiler,
615                symbolEnv,
616                idGen,
617                pipeline,
618                pipelineFunctions,
619                pipelineStruct,
620                mPipelineMainLocalVar),
621           mFuncOriginalToModified(funcOriginalToModified),
622           mFuncModifiedToOriginal(funcModifiedToOriginal)
623     {
624         ASSERT(mPipelineStruct.isTotallyFull());
625     }
626 
getInternalPipelineVariable(const TFunction & pipelineFunc)627     const TVariable &getInternalPipelineVariable(const TFunction &pipelineFunc)
628     {
629         if (pipelineFunc.isMain() && (mPipeline.alwaysRequiresLocalVariableDeclarationInMain() ||
630                                       !mPipelineMainLocalVar.isUniform()))
631         {
632             ASSERT(mPipelineMainLocalVar.internal);
633             return *mPipelineMainLocalVar.internal;
634         }
635         else
636         {
637             ASSERT(pipelineFunc.getParamCount() > 0);
638             return *pipelineFunc.getParam(0);
639         }
640     }
641 
getExternalPipelineVariable(const TFunction & mainFunc)642     const TVariable &getExternalPipelineVariable(const TFunction &mainFunc)
643     {
644         ASSERT(mainFunc.isMain());
645         if (mPipelineMainLocalVar.external)
646         {
647             return *mPipelineMainLocalVar.external;
648         }
649         else
650         {
651             ASSERT(mainFunc.getParamCount() > 0);
652             return *mainFunc.getParam(0);
653         }
654     }
655 
getExternalExtraPipelineVariable(const TFunction & mainFunc)656     const TVariable &getExternalExtraPipelineVariable(const TFunction &mainFunc)
657     {
658         ASSERT(mainFunc.isMain());
659         if (mPipelineMainLocalVar.externalExtra)
660         {
661             return *mPipelineMainLocalVar.externalExtra;
662         }
663         else
664         {
665             ASSERT(mainFunc.getParamCount() > 1);
666             return *mainFunc.getParam(1);
667         }
668     }
669 
visitAggregatePost(TIntermAggregate & callNode)670     PostResult visitAggregatePost(TIntermAggregate &callNode) override
671     {
672         if (callNode.isConstructor())
673         {
674             return callNode;
675         }
676         else
677         {
678             const TFunction &oldCalledFunc = *callNode.getFunction();
679             if (!mEnv.isOriginalPipelineFunction(oldCalledFunc))
680             {
681                 return callNode;
682             }
683             const TFunction &newCalledFunc = mEnv.getUpdatedFunction(oldCalledFunc);
684 
685             const TFunction *oldOwnerFunc = getParentFunction();
686             ASSERT(oldOwnerFunc);
687             const TFunction &newOwnerFunc = mEnv.getUpdatedFunction(*oldOwnerFunc);
688 
689             return *TIntermAggregate::CreateFunctionCall(
690                 newCalledFunc, &CloneSequenceAndPrepend(
691                                    *callNode.getSequence(),
692                                    *new TIntermSymbol(&getInternalPipelineVariable(newOwnerFunc))));
693         }
694     }
695 
visitFunctionPrototypePost(TIntermFunctionPrototype & funcProtoNode)696     PostResult visitFunctionPrototypePost(TIntermFunctionPrototype &funcProtoNode) override
697     {
698         TIntermFunctionPrototype *newFuncProtoNode =
699             mEnv.createUpdatedFunctionPrototype(funcProtoNode);
700         if (newFuncProtoNode == nullptr)
701         {
702             return funcProtoNode;
703         }
704         return *newFuncProtoNode;
705     }
706 
visitFunctionDefinitionPost(TIntermFunctionDefinition & funcDefNode)707     PostResult visitFunctionDefinitionPost(TIntermFunctionDefinition &funcDefNode) override
708     {
709         if (funcDefNode.getFunction()->isMain())
710         {
711             return visitMain(funcDefNode);
712         }
713         else
714         {
715             return visitNonMain(funcDefNode);
716         }
717     }
718 
visitNonMain(TIntermFunctionDefinition & funcDefNode)719     TIntermNode &visitNonMain(TIntermFunctionDefinition &funcDefNode)
720     {
721         TIntermFunctionPrototype &funcProtoNode = *funcDefNode.getFunctionPrototype();
722         ASSERT(!funcProtoNode.getFunction()->isMain());
723 
724         TIntermFunctionPrototype *newFuncProtoNode =
725             mEnv.createUpdatedFunctionPrototype(funcProtoNode);
726         if (newFuncProtoNode == nullptr)
727         {
728             return funcDefNode;
729         }
730 
731         const TFunction &func = *newFuncProtoNode->getFunction();
732         ASSERT(!func.isMain());
733 
734         TIntermBlock *body = funcDefNode.getBody();
735 
736         return *new TIntermFunctionDefinition(newFuncProtoNode, body);
737     }
738 
visitMain(TIntermFunctionDefinition & funcDefNode)739     TIntermNode &visitMain(TIntermFunctionDefinition &funcDefNode)
740     {
741         TIntermFunctionPrototype &funcProtoNode = *funcDefNode.getFunctionPrototype();
742         ASSERT(funcProtoNode.getFunction()->isMain());
743 
744         TIntermFunctionPrototype *newFuncProtoNode =
745             mEnv.createUpdatedFunctionPrototype(funcProtoNode);
746         if (newFuncProtoNode == nullptr)
747         {
748             return funcDefNode;
749         }
750 
751         const TFunction &func = *newFuncProtoNode->getFunction();
752         ASSERT(func.isMain());
753 
754         auto callModifiedToOriginal = [&](TIntermBlock &body) {
755             ASSERT(mPipelineMainLocalVar.internal);
756             if (!mPipeline.isPipelineOut())
757             {
758                 ASSERT(mFuncModifiedToOriginal);
759                 auto *m = new TIntermSymbol(&getExternalPipelineVariable(func));
760                 auto *o = new TIntermSymbol(mPipelineMainLocalVar.internal);
761                 body.appendStatement(TIntermAggregate::CreateFunctionCall(
762                     *mFuncModifiedToOriginal, new TIntermSequence{m, o}));
763             }
764         };
765 
766         auto callOriginalToModified = [&](TIntermBlock &body) {
767             ASSERT(mPipelineMainLocalVar.internal);
768             if (mPipeline.isPipelineOut())
769             {
770                 ASSERT(mFuncOriginalToModified);
771                 auto *o = new TIntermSymbol(mPipelineMainLocalVar.internal);
772                 auto *m = new TIntermSymbol(&getExternalPipelineVariable(func));
773                 body.appendStatement(TIntermAggregate::CreateFunctionCall(
774                     *mFuncOriginalToModified, new TIntermSequence{o, m}));
775             }
776         };
777 
778         TIntermBlock *body = funcDefNode.getBody();
779 
780         if (mPipeline.alwaysRequiresLocalVariableDeclarationInMain())
781         {
782             ASSERT(mPipelineMainLocalVar.isTotallyFull());
783 
784             auto *newBody = new TIntermBlock();
785             newBody->appendStatement(new TIntermDeclaration(mPipelineMainLocalVar.internal,
786                                                             mEnv.getOptionalPipelineInitExpr()));
787 
788             if (mPipeline.type == Pipeline::Type::InvocationVertexGlobals ||
789                 mPipeline.type == Pipeline::Type::InvocationFragmentGlobals)
790             {
791                 // Populate struct instance with references to global pipeline variables.
792                 for (const TField *field : mPipelineStruct.external->fields())
793                 {
794                     auto *var        = new TVariable(&mSymbolTable, field->name(), field->type(),
795                                                      field->symbolType());
796                     auto *symbol     = new TIntermSymbol(var);
797                     auto &accessNode = AccessField(*mPipelineMainLocalVar.internal, var->name());
798                     auto *assignNode = new TIntermBinary(TOperator::EOpAssign, &accessNode, symbol);
799                     newBody->appendStatement(assignNode);
800                 }
801             }
802             else if (mPipeline.type == Pipeline::Type::Texture)
803             {
804                 const TFieldList &fields = mPipelineStruct.external->fields();
805 
806                 ASSERT(func.getParamCount() >= mEnv.getFirstParamIdxInMainFn() + 2 * fields.size());
807                 size_t paramIndex = mEnv.getFirstParamIdxInMainFn();
808 
809                 for (const TField *field : fields)
810                 {
811                     const TVariable &textureParam = *func.getParam(paramIndex++);
812                     const TVariable &samplerParam = *func.getParam(paramIndex++);
813 
814                     auto go = [&](TIntermTyped &env, const int *index) {
815                         TIntermTyped &textureField = AccessField(
816                             AccessIndex(*env.deepCopy(), index), ImmutableString("texture"));
817                         TIntermTyped &samplerField = AccessField(
818                             AccessIndex(*env.deepCopy(), index), ImmutableString("sampler"));
819 
820                         auto mkAssign = [&](TIntermTyped &field, const TVariable &param) {
821                             return new TIntermBinary(TOperator::EOpAssign, &field,
822                                                      &mSymbolEnv.callFunctionOverload(
823                                                          Name("addressof"), field.getType(),
824                                                          *new TIntermSequence{&AccessIndex(
825                                                              *new TIntermSymbol(&param), index)}));
826                         };
827 
828                         newBody->appendStatement(mkAssign(textureField, textureParam));
829                         newBody->appendStatement(mkAssign(samplerField, samplerParam));
830                     };
831 
832                     TIntermTyped &env = AccessField(*mPipelineMainLocalVar.internal, field->name());
833                     const TType &envType = env.getType();
834 
835                     if (envType.isArray())
836                     {
837                         ASSERT(!envType.isArrayOfArrays());
838                         const auto n = static_cast<int>(envType.getArraySizeProduct());
839                         for (int i = 0; i < n; ++i)
840                         {
841                             go(env, &i);
842                         }
843                     }
844                     else
845                     {
846                         go(env, nullptr);
847                     }
848                 }
849             }
850             else if (mPipeline.type == Pipeline::Type::InstanceId)
851             {
852                 auto varInstanceId   = new TIntermSymbol(&getExternalPipelineVariable(func));
853                 auto varBaseInstance = new TIntermSymbol(&getExternalExtraPipelineVariable(func));
854 
855                 newBody->appendStatement(new TIntermBinary(
856                     TOperator::EOpAssign,
857                     &AccessFieldByIndex(*new TIntermSymbol(&getInternalPipelineVariable(func)), 0),
858                     &AsType(
859                         mSymbolEnv, *new TType(TBasicType::EbtInt),
860                         *new TIntermBinary(TOperator::EOpSub, varInstanceId, varBaseInstance))));
861             }
862             else if (!mPipelineMainLocalVar.isUniform())
863             {
864                 newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.external});
865                 callModifiedToOriginal(*newBody);
866             }
867 
868             newBody->appendStatement(body);
869 
870             if (!mPipelineMainLocalVar.isUniform())
871             {
872                 callOriginalToModified(*newBody);
873             }
874 
875             if (mPipeline.isPipelineOut())
876             {
877                 newBody->appendStatement(new TIntermBranch(
878                     TOperator::EOpReturn, new TIntermSymbol(mPipelineMainLocalVar.external)));
879             }
880 
881             body = newBody;
882         }
883         else if (!mPipelineMainLocalVar.isUniform())
884         {
885             ASSERT(!mPipelineMainLocalVar.external);
886             ASSERT(mPipelineMainLocalVar.internal);
887 
888             auto *newBody = new TIntermBlock();
889             newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.internal});
890             callModifiedToOriginal(*newBody);
891             newBody->appendStatement(body);
892             callOriginalToModified(*newBody);
893             body = newBody;
894         }
895 
896         return *new TIntermFunctionDefinition(newFuncProtoNode, body);
897     }
898 };
899 
900 ////////////////////////////////////////////////////////////////////////////////
901 
UpdatePipelineSymbols(Pipeline::Type pipelineType,TCompiler & compiler,TIntermBlock & root,SymbolEnv & symbolEnv,const VariableSet & pipelineVariables,PipelineScoped<TVariable> pipelineMainLocalVar)902 bool UpdatePipelineSymbols(Pipeline::Type pipelineType,
903                            TCompiler &compiler,
904                            TIntermBlock &root,
905                            SymbolEnv &symbolEnv,
906                            const VariableSet &pipelineVariables,
907                            PipelineScoped<TVariable> pipelineMainLocalVar)
908 {
909     auto map = [&](const TFunction *owner, TIntermSymbol &symbol) -> TIntermNode & {
910         if (!owner)
911             return symbol;
912         const TVariable &var = symbol.variable();
913         if (pipelineVariables.find(&var) == pipelineVariables.end())
914         {
915             return symbol;
916         }
917         const TVariable *structInstanceVar;
918         if (owner->isMain() && pipelineType != Pipeline::Type::FragmentIn)
919         {
920             ASSERT(pipelineMainLocalVar.internal);
921             structInstanceVar = pipelineMainLocalVar.internal;
922         }
923         else
924         {
925             ASSERT(owner->getParamCount() > 0);
926             structInstanceVar = owner->getParam(0);
927         }
928         ASSERT(structInstanceVar);
929         return AccessField(*structInstanceVar, var.name());
930     };
931     return MapSymbols(compiler, root, map);
932 }
933 
934 ////////////////////////////////////////////////////////////////////////////////
935 
RewritePipeline(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,const std::vector<sh::ShaderVariable> * variableInfo,PipelineScoped<TStructure> & outStruct)936 bool RewritePipeline(TCompiler &compiler,
937                      TIntermBlock &root,
938                      IdGen &idGen,
939                      const Pipeline &pipeline,
940                      SymbolEnv &symbolEnv,
941                      const std::vector<sh::ShaderVariable> *variableInfo,
942                      PipelineScoped<TStructure> &outStruct)
943 {
944     ASSERT(outStruct.isTotallyEmpty());
945 
946     TSymbolTable &symbolTable = compiler.getSymbolTable();
947 
948     PipelineStructInfo psi;
949     if (!GeneratePipelineStruct::Exec(psi, compiler, root, idGen, pipeline, symbolEnv,
950                                       variableInfo))
951     {
952         return false;
953     }
954 
955     if (psi.isEmpty())
956     {
957         return true;
958     }
959 
960     const auto pipelineFunctions = DiscoverDependentFunctions(root, [&](const TVariable &var) {
961         return psi.pipelineVariables.find(&var) != psi.pipelineVariables.end();
962     });
963 
964     auto pipelineMainLocalVar =
965         CreatePipelineMainLocalVar(symbolTable, pipeline, psi.pipelineStruct);
966 
967     if (!UpdatePipelineFunctions::ThreadPipeline(
968             compiler, root, pipeline, pipelineFunctions, psi.pipelineStruct, pipelineMainLocalVar,
969             idGen, symbolEnv, psi.funcOriginalToModified, psi.funcModifiedToOriginal))
970     {
971         return false;
972     }
973 
974     if (!pipeline.globalInstanceVar)
975     {
976         if (!UpdatePipelineSymbols(pipeline.type, compiler, root, symbolEnv, psi.pipelineVariables,
977                                    pipelineMainLocalVar))
978         {
979             return false;
980         }
981     }
982 
983     if (!PruneNoOps(&compiler, &root, &compiler.getSymbolTable()))
984     {
985         return false;
986     }
987 
988     outStruct = psi.pipelineStruct;
989     return true;
990 }
991 
992 }  // anonymous namespace
993 
RewritePipelines(TCompiler & compiler,TIntermBlock & root,const std::vector<sh::ShaderVariable> & inputVaryings,const std::vector<sh::ShaderVariable> & outputVaryings,IdGen & idGen,DriverUniform & angleUniformsGlobalInstanceVar,SymbolEnv & symbolEnv,PipelineStructs & outStructs)994 bool sh::RewritePipelines(TCompiler &compiler,
995                           TIntermBlock &root,
996                           const std::vector<sh::ShaderVariable> &inputVaryings,
997                           const std::vector<sh::ShaderVariable> &outputVaryings,
998                           IdGen &idGen,
999                           DriverUniform &angleUniformsGlobalInstanceVar,
1000                           SymbolEnv &symbolEnv,
1001                           PipelineStructs &outStructs)
1002 {
1003     struct Info
1004     {
1005         Pipeline::Type pipelineType;
1006         PipelineScoped<TStructure> &outStruct;
1007         const TVariable *globalInstanceVar;
1008         const std::vector<sh::ShaderVariable> *variableInfo;
1009     };
1010 
1011     Info infos[] = {
1012         {Pipeline::Type::InstanceId, outStructs.instanceId, nullptr, nullptr},
1013         {Pipeline::Type::Texture, outStructs.image, nullptr, nullptr},
1014         {Pipeline::Type::Image, outStructs.texture, nullptr, nullptr},
1015         {Pipeline::Type::NonConstantGlobals, outStructs.nonConstantGlobals, nullptr, nullptr},
1016         {Pipeline::Type::AngleUniforms, outStructs.angleUniforms,
1017          angleUniformsGlobalInstanceVar.getDriverUniformsVariable(), nullptr},
1018         {Pipeline::Type::UserUniforms, outStructs.userUniforms, nullptr, nullptr},
1019         {Pipeline::Type::VertexIn, outStructs.vertexIn, nullptr, &inputVaryings},
1020         {Pipeline::Type::VertexOut, outStructs.vertexOut, nullptr, &outputVaryings},
1021         {Pipeline::Type::FragmentIn, outStructs.fragmentIn, nullptr, &inputVaryings},
1022         {Pipeline::Type::FragmentOut, outStructs.fragmentOut, nullptr, &outputVaryings},
1023         {Pipeline::Type::InvocationVertexGlobals, outStructs.invocationVertexGlobals, nullptr,
1024          nullptr},
1025         {Pipeline::Type::InvocationFragmentGlobals, outStructs.invocationFragmentGlobals, nullptr,
1026          &inputVaryings},
1027         {Pipeline::Type::UniformBuffer, outStructs.uniformBuffers, nullptr, nullptr},
1028     };
1029 
1030     for (Info &info : infos)
1031     {
1032         if ((compiler.getShaderType() != GL_VERTEX_SHADER &&
1033              (info.pipelineType == Pipeline::Type::VertexIn ||
1034               info.pipelineType == Pipeline::Type::VertexOut ||
1035               info.pipelineType == Pipeline::Type::InvocationVertexGlobals)) ||
1036             (compiler.getShaderType() != GL_FRAGMENT_SHADER &&
1037              (info.pipelineType == Pipeline::Type::FragmentIn ||
1038               info.pipelineType == Pipeline::Type::FragmentOut ||
1039               info.pipelineType == Pipeline::Type::InvocationFragmentGlobals)))
1040             continue;
1041 
1042         Pipeline pipeline{info.pipelineType, info.globalInstanceVar};
1043         if (!RewritePipeline(compiler, root, idGen, pipeline, symbolEnv, info.variableInfo,
1044                              info.outStruct))
1045         {
1046             return false;
1047         }
1048     }
1049 
1050     return true;
1051 }
1052