1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include <cstring>
8 #include <unordered_map>
9 #include <unordered_set>
10
11 #include "compiler/translator/TranslatorMetalDirect.h"
12 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
13 #include "compiler/translator/TranslatorMetalDirect/DiscoverDependentFunctions.h"
14 #include "compiler/translator/TranslatorMetalDirect/IdGen.h"
15 #include "compiler/translator/TranslatorMetalDirect/MapSymbols.h"
16 #include "compiler/translator/TranslatorMetalDirect/Pipeline.h"
17 #include "compiler/translator/TranslatorMetalDirect/RewritePipelines.h"
18 #include "compiler/translator/TranslatorMetalDirect/SymbolEnv.h"
19 #include "compiler/translator/tree_ops/PruneNoOps.h"
20 #include "compiler/translator/tree_util/DriverUniform.h"
21 #include "compiler/translator/tree_util/FindMain.h"
22 #include "compiler/translator/tree_util/IntermRebuild.h"
23 #include "compiler/translator/tree_util/IntermTraverse.h"
24 using namespace sh;
25
26 ////////////////////////////////////////////////////////////////////////////////
27
28 namespace
29 {
30
IsVariableInvariant(const std::vector<sh::ShaderVariable> & mVars,const ImmutableString & name)31 bool IsVariableInvariant(const std::vector<sh::ShaderVariable> &mVars, const ImmutableString &name)
32 {
33 for (const auto &var : mVars)
34 {
35 if (name == var.name)
36 {
37 return var.isInvariant;
38 }
39 }
40 // TODO(kpidington): this should be UNREACHABLE() but isn't because the translator generates
41 // declarations to unused built-in variables.
42 return false;
43 }
44
45 using VariableSet = std::unordered_set<const TVariable *>;
46 using VariableList = std::vector<const TVariable *>;
47
48 ////////////////////////////////////////////////////////////////////////////////
49
50 struct PipelineStructInfo
51 {
52 VariableSet pipelineVariables;
53 PipelineScoped<TStructure> pipelineStruct;
54 const TFunction *funcOriginalToModified = nullptr;
55 const TFunction *funcModifiedToOriginal = nullptr;
56
isEmpty__anon986796d00111::PipelineStructInfo57 bool isEmpty() const
58 {
59 if (pipelineStruct.isTotallyEmpty())
60 {
61 ASSERT(pipelineVariables.empty());
62 return true;
63 }
64 else
65 {
66 ASSERT(pipelineStruct.isTotallyFull());
67 ASSERT(!pipelineVariables.empty());
68 return false;
69 }
70 }
71 };
72
73 class GeneratePipelineStruct : private TIntermRebuild
74 {
75 private:
76 const Pipeline &mPipeline;
77 SymbolEnv &mSymbolEnv;
78 const std::vector<sh::ShaderVariable> *mVariableInfos;
79 VariableList mPipelineVariableList;
80 IdGen &mIdGen;
81 PipelineStructInfo mInfo;
82
83 public:
Exec(PipelineStructInfo & out,TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,const std::vector<sh::ShaderVariable> * variableInfos)84 static bool Exec(PipelineStructInfo &out,
85 TCompiler &compiler,
86 TIntermBlock &root,
87 IdGen &idGen,
88 const Pipeline &pipeline,
89 SymbolEnv &symbolEnv,
90 const std::vector<sh::ShaderVariable> *variableInfos)
91 {
92 GeneratePipelineStruct self(compiler, idGen, pipeline, symbolEnv, variableInfos);
93 if (!self.exec(root))
94 {
95 return false;
96 }
97 out = self.mInfo;
98 return true;
99 }
100
101 private:
GeneratePipelineStruct(TCompiler & compiler,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,const std::vector<sh::ShaderVariable> * variableInfos)102 GeneratePipelineStruct(TCompiler &compiler,
103 IdGen &idGen,
104 const Pipeline &pipeline,
105 SymbolEnv &symbolEnv,
106 const std::vector<sh::ShaderVariable> *variableInfos)
107 : TIntermRebuild(compiler, true, true),
108 mPipeline(pipeline),
109 mSymbolEnv(symbolEnv),
110 mVariableInfos(variableInfos),
111 mIdGen(idGen)
112 {}
113
exec(TIntermBlock & root)114 bool exec(TIntermBlock &root)
115 {
116 if (!rebuildRoot(root))
117 {
118 return false;
119 }
120
121 if (mInfo.pipelineVariables.empty())
122 {
123 return true;
124 }
125
126 TIntermSequence seq;
127
128 const TStructure &pipelineStruct = [&]() -> const TStructure & {
129 if (mPipeline.globalInstanceVar)
130 {
131 return *mPipeline.globalInstanceVar->getType().getStruct();
132 }
133 else
134 {
135 return createInternalPipelineStruct(root, seq);
136 }
137 }();
138
139 ModifiedStructMachineries modifiedMachineries;
140 const bool isUBO = mPipeline.type == Pipeline::Type::UniformBuffer;
141 const bool modified = TryCreateModifiedStruct(
142 mCompiler, mSymbolEnv, mIdGen, mPipeline.externalStructModifyConfig(), pipelineStruct,
143 mPipeline.getStructTypeName(Pipeline::Variant::Modified), modifiedMachineries, isUBO,
144 !isUBO);
145
146 if (modified)
147 {
148 ASSERT(mPipeline.type != Pipeline::Type::Texture);
149 ASSERT(mPipeline.type == Pipeline::Type::AngleUniforms ||
150 !mPipeline.globalInstanceVar); // This shouldn't happen by construction.
151
152 auto getFunction = [](sh::TIntermFunctionDefinition *funcDecl) {
153 return funcDecl ? funcDecl->getFunction() : nullptr;
154 };
155
156 const size_t size = modifiedMachineries.size();
157 ASSERT(size > 0);
158 for (size_t i = 0; i < size; ++i)
159 {
160 const ModifiedStructMachinery &machinery = modifiedMachineries.at(i);
161 ASSERT(machinery.modifiedStruct);
162
163 seq.push_back(new TIntermDeclaration{
164 &CreateStructTypeVariable(mSymbolTable, *machinery.modifiedStruct)});
165
166 if (mPipeline.isPipelineOut())
167 {
168 ASSERT(machinery.funcOriginalToModified);
169 ASSERT(!machinery.funcModifiedToOriginal);
170 seq.push_back(machinery.funcOriginalToModified);
171 }
172 else
173 {
174 ASSERT(machinery.funcModifiedToOriginal);
175 ASSERT(!machinery.funcOriginalToModified);
176 seq.push_back(machinery.funcModifiedToOriginal);
177 }
178
179 if (i == size - 1)
180 {
181 mInfo.funcOriginalToModified = getFunction(machinery.funcOriginalToModified);
182 mInfo.funcModifiedToOriginal = getFunction(machinery.funcModifiedToOriginal);
183
184 mInfo.pipelineStruct.internal = &pipelineStruct;
185 mInfo.pipelineStruct.external =
186 modified ? machinery.modifiedStruct : &pipelineStruct;
187 }
188 }
189 }
190 else
191 {
192 mInfo.pipelineStruct.internal = &pipelineStruct;
193 mInfo.pipelineStruct.external = &pipelineStruct;
194 }
195
196 root.insertChildNodes(FindMainIndex(&root), seq);
197
198 return true;
199 }
200
201 private:
visitFunctionDefinitionPre(TIntermFunctionDefinition & node)202 PreResult visitFunctionDefinitionPre(TIntermFunctionDefinition &node) override
203 {
204 return {node, VisitBits::Neither};
205 }
visitDeclarationPost(TIntermDeclaration & declNode)206 PostResult visitDeclarationPost(TIntermDeclaration &declNode) override
207 {
208 Declaration decl = ViewDeclaration(declNode);
209 const TVariable &var = decl.symbol.variable();
210 if (mPipeline.uses(var))
211 {
212 ASSERT(mInfo.pipelineVariables.find(&var) == mInfo.pipelineVariables.end());
213 mInfo.pipelineVariables.insert(&var);
214 mPipelineVariableList.push_back(&var);
215 return nullptr;
216 }
217
218 return declNode;
219 }
220
createInternalPipelineStruct(TIntermBlock & root,TIntermSequence & outDeclSeq)221 const TStructure &createInternalPipelineStruct(TIntermBlock &root, TIntermSequence &outDeclSeq)
222 {
223 auto &fields = *new TFieldList();
224
225 switch (mPipeline.type)
226 {
227 case Pipeline::Type::Texture:
228 {
229 for (const TVariable *var : mPipelineVariableList)
230 {
231 const TType &varType = var->getType();
232 const TBasicType samplerType = varType.getBasicType();
233
234 const TStructure &textureEnv = mSymbolEnv.getTextureEnv(samplerType);
235 auto *textureEnvType = new TType(&textureEnv, false);
236 if (varType.isArray())
237 {
238 textureEnvType->makeArrays(varType.getArraySizes());
239 }
240
241 fields.push_back(
242 new TField(textureEnvType, var->name(), kNoSourceLoc, var->symbolType()));
243 }
244 }
245 break;
246
247 case Pipeline::Type::UniformBuffer:
248 {
249 for (const TVariable *var : mPipelineVariableList)
250 {
251 auto &type = CloneType(var->getType());
252 auto *field = new TField(&type, var->name(), kNoSourceLoc, var->symbolType());
253 mSymbolEnv.markAsPointer(*field, AddressSpace::Constant);
254 mSymbolEnv.markAsUBO(*field);
255 mSymbolEnv.markAsPointer(*var, AddressSpace::Constant);
256 fields.push_back(field);
257 }
258 }
259 break;
260 default:
261 {
262 for (const TVariable *var : mPipelineVariableList)
263 {
264 auto &type = CloneType(var->getType());
265 if (mVariableInfos && IsVariableInvariant(*mVariableInfos, var->name()))
266 {
267 type.setInvariant(true);
268 }
269 auto *field = new TField(&type, var->name(), kNoSourceLoc, var->symbolType());
270 fields.push_back(field);
271 }
272 }
273 break;
274 }
275
276 Name pipelineStructName = mPipeline.getStructTypeName(Pipeline::Variant::Original);
277 auto &s = *new TStructure(&mSymbolTable, pipelineStructName.rawName(), &fields,
278 pipelineStructName.symbolType());
279
280 outDeclSeq.push_back(new TIntermDeclaration{&CreateStructTypeVariable(mSymbolTable, s)});
281
282 return s;
283 }
284 };
285
286 ////////////////////////////////////////////////////////////////////////////////
287
CreatePipelineMainLocalVar(TSymbolTable & symbolTable,const Pipeline & pipeline,PipelineScoped<TStructure> pipelineStruct)288 PipelineScoped<TVariable> CreatePipelineMainLocalVar(TSymbolTable &symbolTable,
289 const Pipeline &pipeline,
290 PipelineScoped<TStructure> pipelineStruct)
291 {
292 ASSERT(pipelineStruct.isTotallyFull());
293
294 PipelineScoped<TVariable> pipelineMainLocalVar;
295
296 auto populateExternalMainLocalVar = [&]() {
297 ASSERT(!pipelineMainLocalVar.external);
298 pipelineMainLocalVar.external = &CreateInstanceVariable(
299 symbolTable, *pipelineStruct.external,
300 pipeline.getStructInstanceName(pipelineStruct.isUniform()
301 ? Pipeline::Variant::Original
302 : Pipeline::Variant::Modified));
303 };
304
305 auto populateDistinctInternalMainLocalVar = [&]() {
306 ASSERT(!pipelineMainLocalVar.internal);
307 pipelineMainLocalVar.internal =
308 &CreateInstanceVariable(symbolTable, *pipelineStruct.internal,
309 pipeline.getStructInstanceName(Pipeline::Variant::Original));
310 };
311
312 if (pipeline.type == Pipeline::Type::InstanceId)
313 {
314 populateDistinctInternalMainLocalVar();
315 }
316 else if (pipeline.alwaysRequiresLocalVariableDeclarationInMain())
317 {
318 populateExternalMainLocalVar();
319
320 if (pipelineStruct.isUniform())
321 {
322 pipelineMainLocalVar.internal = pipelineMainLocalVar.external;
323 }
324 else
325 {
326 populateDistinctInternalMainLocalVar();
327 }
328 }
329 else if (!pipelineStruct.isUniform())
330 {
331 populateDistinctInternalMainLocalVar();
332 }
333
334 return pipelineMainLocalVar;
335 }
336
337 class PipelineFunctionEnv
338 {
339 private:
340 TCompiler &mCompiler;
341 SymbolEnv &mSymbolEnv;
342 TSymbolTable &mSymbolTable;
343 IdGen &mIdGen;
344 const Pipeline &mPipeline;
345 const std::unordered_set<const TFunction *> &mPipelineFunctions;
346 const PipelineScoped<TStructure> mPipelineStruct;
347 PipelineScoped<TVariable> &mPipelineMainLocalVar;
348
349 std::unordered_map<const TFunction *, const TFunction *> mFuncMap;
350
351 public:
PipelineFunctionEnv(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen,const Pipeline & pipeline,const std::unordered_set<const TFunction * > & pipelineFunctions,PipelineScoped<TStructure> pipelineStruct,PipelineScoped<TVariable> & pipelineMainLocalVar)352 PipelineFunctionEnv(TCompiler &compiler,
353 SymbolEnv &symbolEnv,
354 IdGen &idGen,
355 const Pipeline &pipeline,
356 const std::unordered_set<const TFunction *> &pipelineFunctions,
357 PipelineScoped<TStructure> pipelineStruct,
358 PipelineScoped<TVariable> &pipelineMainLocalVar)
359 : mCompiler(compiler),
360 mSymbolEnv(symbolEnv),
361 mSymbolTable(symbolEnv.symbolTable()),
362 mIdGen(idGen),
363 mPipeline(pipeline),
364 mPipelineFunctions(pipelineFunctions),
365 mPipelineStruct(pipelineStruct),
366 mPipelineMainLocalVar(pipelineMainLocalVar)
367 {}
368
isOriginalPipelineFunction(const TFunction & func) const369 bool isOriginalPipelineFunction(const TFunction &func) const
370 {
371 return mPipelineFunctions.find(&func) != mPipelineFunctions.end();
372 }
373
isUpdatedPipelineFunction(const TFunction & func) const374 bool isUpdatedPipelineFunction(const TFunction &func) const
375 {
376 auto it = mFuncMap.find(&func);
377 if (it == mFuncMap.end())
378 {
379 return false;
380 }
381 return &func == it->second;
382 }
383
getUpdatedFunction(const TFunction & func)384 const TFunction &getUpdatedFunction(const TFunction &func)
385 {
386 ASSERT(isOriginalPipelineFunction(func) || isUpdatedPipelineFunction(func));
387
388 const TFunction *newFunc;
389
390 auto it = mFuncMap.find(&func);
391 if (it == mFuncMap.end())
392 {
393 const bool isMain = func.isMain();
394
395 if (isMain && mPipeline.isPipelineOut())
396 {
397 ASSERT(func.getReturnType().getBasicType() == TBasicType::EbtVoid);
398 newFunc = &CloneFunctionAndChangeReturnType(mSymbolTable, nullptr, func,
399 *mPipelineStruct.external);
400 }
401 else if (isMain && (mPipeline.type == Pipeline::Type::InvocationVertexGlobals ||
402 mPipeline.type == Pipeline::Type::InvocationFragmentGlobals))
403 {
404 std::vector<const TVariable *> variables;
405 for (const TField *field : mPipelineStruct.external->fields())
406 {
407 variables.push_back(new TVariable(&mSymbolTable, field->name(), field->type(),
408 field->symbolType()));
409 }
410 newFunc = &CloneFunctionAndAppendParams(mSymbolTable, nullptr, func, variables);
411 }
412 else if (isMain && mPipeline.type == Pipeline::Type::Texture)
413 {
414 std::vector<const TVariable *> variables;
415 TranslatorMetalReflection *reflection =
416 mtl::getTranslatorMetalReflection(&mCompiler);
417 for (const TField *field : mPipelineStruct.external->fields())
418 {
419 const TStructure *textureEnv = field->type()->getStruct();
420 ASSERT(textureEnv && textureEnv->fields().size() == 2);
421 for (const TField *subfield : textureEnv->fields())
422 {
423 const Name name = mIdGen.createNewName({field->name(), subfield->name()});
424 TType &type = *new TType(*subfield->type());
425 ASSERT(!type.isArray());
426 type.makeArrays(field->type()->getArraySizes());
427 auto *var =
428 new TVariable(&mSymbolTable, name.rawName(), &type, name.symbolType());
429 variables.push_back(var);
430 reflection->addOriginalName(var->uniqueId().get(), field->name().data());
431 }
432 }
433 newFunc = &CloneFunctionAndAppendParams(mSymbolTable, nullptr, func, variables);
434 }
435 else if (isMain && mPipeline.type == Pipeline::Type::InstanceId)
436 {
437 Name instanceIdName = mPipeline.getStructInstanceName(Pipeline::Variant::Modified);
438 auto *instanceIdVar =
439 new TVariable(&mSymbolTable, instanceIdName.rawName(),
440 new TType(TBasicType::EbtUInt), instanceIdName.symbolType());
441
442 auto *baseInstanceVar =
443 new TVariable(&mSymbolTable, kBaseInstanceName.rawName(),
444 new TType(TBasicType::EbtUInt), kBaseInstanceName.symbolType());
445
446 newFunc = &CloneFunctionAndPrependTwoParams(mSymbolTable, nullptr, func,
447 *instanceIdVar, *baseInstanceVar);
448 mPipelineMainLocalVar.external = instanceIdVar;
449 mPipelineMainLocalVar.externalExtra = baseInstanceVar;
450 }
451 else if (isMain && mPipeline.alwaysRequiresLocalVariableDeclarationInMain())
452 {
453 ASSERT(mPipelineMainLocalVar.isTotallyFull());
454 newFunc = &func;
455 }
456 else
457 {
458 const TVariable *var;
459 AddressSpace addressSpace;
460
461 if (isMain && !mPipelineMainLocalVar.isUniform())
462 {
463 var = &CreateInstanceVariable(
464 mSymbolTable, *mPipelineStruct.external,
465 mPipeline.getStructInstanceName(Pipeline::Variant::Modified));
466 addressSpace = mPipeline.externalAddressSpace();
467 }
468 else
469 {
470 if (mPipeline.type == Pipeline::Type::UniformBuffer)
471 {
472 TranslatorMetalReflection *reflection =
473 ((sh::TranslatorMetalDirect *)&mCompiler)
474 ->getTranslatorMetalReflection();
475 // TODO: need more checks to make sure they line up? Could be reordered?
476 ASSERT(mPipelineStruct.external->fields().size() ==
477 mPipelineStruct.internal->fields().size());
478 for (size_t i = 0; i < mPipelineStruct.external->fields().size(); i++)
479 {
480 const TField *externalField = mPipelineStruct.external->fields()[i];
481 const TField *internalField = mPipelineStruct.internal->fields()[i];
482 const TType &externalType = *externalField->type();
483 const TType &internalType = *internalField->type();
484 ASSERT(externalType.getBasicType() == internalType.getBasicType());
485 if (externalType.getBasicType() == TBasicType::EbtStruct)
486 {
487 const TStructure *externalEnv = externalType.getStruct();
488 const TStructure *internalEnv = internalType.getStruct();
489 const std::string internalName =
490 reflection->getOriginalName(internalEnv->uniqueId().get());
491 reflection->addOriginalName(externalEnv->uniqueId().get(),
492 internalName);
493 }
494 }
495 }
496 var = &CreateInstanceVariable(
497 mSymbolTable, *mPipelineStruct.internal,
498 mPipeline.getStructInstanceName(Pipeline::Variant::Original));
499 addressSpace = mPipelineMainLocalVar.isUniform()
500 ? mPipeline.externalAddressSpace()
501 : AddressSpace::Thread;
502 }
503
504 bool markAsReference = true;
505 if (isMain)
506 {
507 switch (mPipeline.type)
508 {
509 case Pipeline::Type::VertexIn:
510 case Pipeline::Type::FragmentIn:
511 markAsReference = false;
512 break;
513
514 default:
515 break;
516 }
517 }
518
519 if (markAsReference)
520 {
521 mSymbolEnv.markAsReference(*var, addressSpace);
522 }
523
524 newFunc = &CloneFunctionAndPrependParam(mSymbolTable, nullptr, func, *var);
525 }
526
527 mFuncMap[&func] = newFunc;
528 mFuncMap[newFunc] = newFunc;
529 }
530 else
531 {
532 newFunc = it->second;
533 }
534
535 return *newFunc;
536 }
537
createUpdatedFunctionPrototype(TIntermFunctionPrototype & funcProtoNode)538 TIntermFunctionPrototype *createUpdatedFunctionPrototype(
539 TIntermFunctionPrototype &funcProtoNode)
540 {
541 const TFunction &func = *funcProtoNode.getFunction();
542 if (!isOriginalPipelineFunction(func) && !isUpdatedPipelineFunction(func))
543 {
544 return nullptr;
545 }
546 const TFunction &newFunc = getUpdatedFunction(func);
547 return new TIntermFunctionPrototype(&newFunc);
548 }
549 };
550
551 class UpdatePipelineFunctions : private TIntermRebuild
552 {
553 private:
554 const Pipeline &mPipeline;
555 const PipelineScoped<TStructure> mPipelineStruct;
556 PipelineScoped<TVariable> &mPipelineMainLocalVar;
557 SymbolEnv &mSymbolEnv;
558 PipelineFunctionEnv mEnv;
559 const TFunction *mFuncOriginalToModified;
560 const TFunction *mFuncModifiedToOriginal;
561
562 public:
ThreadPipeline(TCompiler & compiler,TIntermBlock & root,const Pipeline & pipeline,const std::unordered_set<const TFunction * > & pipelineFunctions,PipelineScoped<TStructure> pipelineStruct,PipelineScoped<TVariable> & pipelineMainLocalVar,IdGen & idGen,SymbolEnv & symbolEnv,const TFunction * funcOriginalToModified,const TFunction * funcModifiedToOriginal)563 static bool ThreadPipeline(TCompiler &compiler,
564 TIntermBlock &root,
565 const Pipeline &pipeline,
566 const std::unordered_set<const TFunction *> &pipelineFunctions,
567 PipelineScoped<TStructure> pipelineStruct,
568 PipelineScoped<TVariable> &pipelineMainLocalVar,
569 IdGen &idGen,
570 SymbolEnv &symbolEnv,
571 const TFunction *funcOriginalToModified,
572 const TFunction *funcModifiedToOriginal)
573 {
574 UpdatePipelineFunctions self(compiler, pipeline, pipelineFunctions, pipelineStruct,
575 pipelineMainLocalVar, idGen, symbolEnv, funcOriginalToModified,
576 funcModifiedToOriginal);
577 if (!self.rebuildRoot(root))
578 {
579 return false;
580 }
581 return true;
582 }
583
584 private:
UpdatePipelineFunctions(TCompiler & compiler,const Pipeline & pipeline,const std::unordered_set<const TFunction * > & pipelineFunctions,PipelineScoped<TStructure> pipelineStruct,PipelineScoped<TVariable> & pipelineMainLocalVar,IdGen & idGen,SymbolEnv & symbolEnv,const TFunction * funcOriginalToModified,const TFunction * funcModifiedToOriginal)585 UpdatePipelineFunctions(TCompiler &compiler,
586 const Pipeline &pipeline,
587 const std::unordered_set<const TFunction *> &pipelineFunctions,
588 PipelineScoped<TStructure> pipelineStruct,
589 PipelineScoped<TVariable> &pipelineMainLocalVar,
590 IdGen &idGen,
591 SymbolEnv &symbolEnv,
592 const TFunction *funcOriginalToModified,
593 const TFunction *funcModifiedToOriginal)
594 : TIntermRebuild(compiler, false, true),
595 mPipeline(pipeline),
596 mPipelineStruct(pipelineStruct),
597 mPipelineMainLocalVar(pipelineMainLocalVar),
598 mSymbolEnv(symbolEnv),
599 mEnv(compiler,
600 symbolEnv,
601 idGen,
602 pipeline,
603 pipelineFunctions,
604 pipelineStruct,
605 mPipelineMainLocalVar),
606 mFuncOriginalToModified(funcOriginalToModified),
607 mFuncModifiedToOriginal(funcModifiedToOriginal)
608 {
609 ASSERT(mPipelineStruct.isTotallyFull());
610 }
611
getInternalPipelineVariable(const TFunction & pipelineFunc)612 const TVariable &getInternalPipelineVariable(const TFunction &pipelineFunc)
613 {
614 if (pipelineFunc.isMain() && (mPipeline.alwaysRequiresLocalVariableDeclarationInMain() ||
615 !mPipelineMainLocalVar.isUniform()))
616 {
617 ASSERT(mPipelineMainLocalVar.internal);
618 return *mPipelineMainLocalVar.internal;
619 }
620 else
621 {
622 ASSERT(pipelineFunc.getParamCount() > 0);
623 return *pipelineFunc.getParam(0);
624 }
625 }
626
getExternalPipelineVariable(const TFunction & mainFunc)627 const TVariable &getExternalPipelineVariable(const TFunction &mainFunc)
628 {
629 ASSERT(mainFunc.isMain());
630 if (mPipelineMainLocalVar.external)
631 {
632 return *mPipelineMainLocalVar.external;
633 }
634 else
635 {
636 ASSERT(mainFunc.getParamCount() > 0);
637 return *mainFunc.getParam(0);
638 }
639 }
640
getExternalExtraPipelineVariable(const TFunction & mainFunc)641 const TVariable &getExternalExtraPipelineVariable(const TFunction &mainFunc)
642 {
643 ASSERT(mainFunc.isMain());
644 if (mPipelineMainLocalVar.externalExtra)
645 {
646 return *mPipelineMainLocalVar.externalExtra;
647 }
648 else
649 {
650 ASSERT(mainFunc.getParamCount() > 1);
651 return *mainFunc.getParam(1);
652 }
653 }
654
visitAggregatePost(TIntermAggregate & callNode)655 PostResult visitAggregatePost(TIntermAggregate &callNode) override
656 {
657 if (callNode.isConstructor())
658 {
659 return callNode;
660 }
661 else
662 {
663 const TFunction &oldCalledFunc = *callNode.getFunction();
664 if (!mEnv.isOriginalPipelineFunction(oldCalledFunc))
665 {
666 return callNode;
667 }
668 const TFunction &newCalledFunc = mEnv.getUpdatedFunction(oldCalledFunc);
669
670 const TFunction *oldOwnerFunc = getParentFunction();
671 ASSERT(oldOwnerFunc);
672 const TFunction &newOwnerFunc = mEnv.getUpdatedFunction(*oldOwnerFunc);
673
674 return *TIntermAggregate::CreateFunctionCall(
675 newCalledFunc, &CloneSequenceAndPrepend(
676 *callNode.getSequence(),
677 *new TIntermSymbol(&getInternalPipelineVariable(newOwnerFunc))));
678 }
679 }
680
visitFunctionPrototypePost(TIntermFunctionPrototype & funcProtoNode)681 PostResult visitFunctionPrototypePost(TIntermFunctionPrototype &funcProtoNode) override
682 {
683 TIntermFunctionPrototype *newFuncProtoNode =
684 mEnv.createUpdatedFunctionPrototype(funcProtoNode);
685 if (newFuncProtoNode == nullptr)
686 {
687 return funcProtoNode;
688 }
689 return *newFuncProtoNode;
690 }
691
visitFunctionDefinitionPost(TIntermFunctionDefinition & funcDefNode)692 PostResult visitFunctionDefinitionPost(TIntermFunctionDefinition &funcDefNode) override
693 {
694 if (funcDefNode.getFunction()->isMain())
695 {
696 return visitMain(funcDefNode);
697 }
698 else
699 {
700 return visitNonMain(funcDefNode);
701 }
702 }
703
visitNonMain(TIntermFunctionDefinition & funcDefNode)704 TIntermNode &visitNonMain(TIntermFunctionDefinition &funcDefNode)
705 {
706 TIntermFunctionPrototype &funcProtoNode = *funcDefNode.getFunctionPrototype();
707 ASSERT(!funcProtoNode.getFunction()->isMain());
708
709 TIntermFunctionPrototype *newFuncProtoNode =
710 mEnv.createUpdatedFunctionPrototype(funcProtoNode);
711 if (newFuncProtoNode == nullptr)
712 {
713 return funcDefNode;
714 }
715
716 const TFunction &func = *newFuncProtoNode->getFunction();
717 ASSERT(!func.isMain());
718
719 TIntermBlock *body = funcDefNode.getBody();
720
721 return *new TIntermFunctionDefinition(newFuncProtoNode, body);
722 }
723
visitMain(TIntermFunctionDefinition & funcDefNode)724 TIntermNode &visitMain(TIntermFunctionDefinition &funcDefNode)
725 {
726 TIntermFunctionPrototype &funcProtoNode = *funcDefNode.getFunctionPrototype();
727 ASSERT(funcProtoNode.getFunction()->isMain());
728
729 TIntermFunctionPrototype *newFuncProtoNode =
730 mEnv.createUpdatedFunctionPrototype(funcProtoNode);
731 if (newFuncProtoNode == nullptr)
732 {
733 return funcDefNode;
734 }
735
736 const TFunction &func = *newFuncProtoNode->getFunction();
737 ASSERT(func.isMain());
738
739 auto callModifiedToOriginal = [&](TIntermBlock &body) {
740 ASSERT(mPipelineMainLocalVar.internal);
741 if (!mPipeline.isPipelineOut())
742 {
743 ASSERT(mFuncModifiedToOriginal);
744 auto *m = new TIntermSymbol(&getExternalPipelineVariable(func));
745 auto *o = new TIntermSymbol(mPipelineMainLocalVar.internal);
746 body.appendStatement(TIntermAggregate::CreateFunctionCall(
747 *mFuncModifiedToOriginal, new TIntermSequence{m, o}));
748 }
749 };
750
751 auto callOriginalToModified = [&](TIntermBlock &body) {
752 ASSERT(mPipelineMainLocalVar.internal);
753 if (mPipeline.isPipelineOut())
754 {
755 ASSERT(mFuncOriginalToModified);
756 auto *o = new TIntermSymbol(mPipelineMainLocalVar.internal);
757 auto *m = new TIntermSymbol(&getExternalPipelineVariable(func));
758 body.appendStatement(TIntermAggregate::CreateFunctionCall(
759 *mFuncOriginalToModified, new TIntermSequence{o, m}));
760 }
761 };
762
763 TIntermBlock *body = funcDefNode.getBody();
764
765 if (mPipeline.alwaysRequiresLocalVariableDeclarationInMain())
766 {
767 ASSERT(mPipelineMainLocalVar.isTotallyFull());
768
769 auto *newBody = new TIntermBlock();
770 newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.internal});
771
772 if (mPipeline.type == Pipeline::Type::InvocationVertexGlobals ||
773 mPipeline.type == Pipeline::Type::InvocationFragmentGlobals)
774 {
775 // Populate struct instance with references to global pipeline variables.
776 for (const TField *field : mPipelineStruct.external->fields())
777 {
778 auto *var = new TVariable(&mSymbolTable, field->name(), field->type(),
779 field->symbolType());
780 auto *symbol = new TIntermSymbol(var);
781 auto &accessNode = AccessField(*mPipelineMainLocalVar.internal, var->name());
782 auto *assignNode = new TIntermBinary(TOperator::EOpAssign, &accessNode, symbol);
783 newBody->appendStatement(assignNode);
784 }
785 }
786 else if (mPipeline.type == Pipeline::Type::Texture)
787 {
788 const TFieldList &fields = mPipelineStruct.external->fields();
789
790 ASSERT(func.getParamCount() >= 2 * fields.size());
791 size_t paramIndex = func.getParamCount() - 2 * fields.size();
792
793 for (const TField *field : fields)
794 {
795 const TVariable &textureParam = *func.getParam(paramIndex++);
796 const TVariable &samplerParam = *func.getParam(paramIndex++);
797
798 auto go = [&](TIntermTyped &env, const int *index) {
799 TIntermTyped &textureField = AccessField(
800 AccessIndex(*env.deepCopy(), index), ImmutableString("texture"));
801 TIntermTyped &samplerField = AccessField(
802 AccessIndex(*env.deepCopy(), index), ImmutableString("sampler"));
803
804 auto mkAssign = [&](TIntermTyped &field, const TVariable ¶m) {
805 return new TIntermBinary(TOperator::EOpAssign, &field,
806 &mSymbolEnv.callFunctionOverload(
807 Name("addressof"), field.getType(),
808 *new TIntermSequence{&AccessIndex(
809 *new TIntermSymbol(¶m), index)}));
810 };
811
812 newBody->appendStatement(mkAssign(textureField, textureParam));
813 newBody->appendStatement(mkAssign(samplerField, samplerParam));
814 };
815
816 TIntermTyped &env = AccessField(*mPipelineMainLocalVar.internal, field->name());
817 const TType &envType = env.getType();
818
819 if (envType.isArray())
820 {
821 ASSERT(!envType.isArrayOfArrays());
822 const auto n = static_cast<int>(envType.getArraySizeProduct());
823 for (int i = 0; i < n; ++i)
824 {
825 go(env, &i);
826 }
827 }
828 else
829 {
830 go(env, nullptr);
831 }
832 }
833 }
834 else if (mPipeline.type == Pipeline::Type::InstanceId)
835 {
836 auto varInstanceId = new TIntermSymbol(&getExternalPipelineVariable(func));
837 auto varBaseInstance = new TIntermSymbol(&getExternalExtraPipelineVariable(func));
838
839 newBody->appendStatement(new TIntermBinary(
840 TOperator::EOpAssign,
841 &AccessFieldByIndex(*new TIntermSymbol(&getInternalPipelineVariable(func)), 0),
842 &AsType(
843 mSymbolEnv, *new TType(TBasicType::EbtInt),
844 *new TIntermBinary(TOperator::EOpSub, varInstanceId, varBaseInstance))));
845 }
846 else if (!mPipelineMainLocalVar.isUniform())
847 {
848 newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.external});
849 callModifiedToOriginal(*newBody);
850 }
851
852 newBody->appendStatement(body);
853
854 if (!mPipelineMainLocalVar.isUniform())
855 {
856 callOriginalToModified(*newBody);
857 }
858
859 if (mPipeline.isPipelineOut())
860 {
861 newBody->appendStatement(new TIntermBranch(
862 TOperator::EOpReturn, new TIntermSymbol(mPipelineMainLocalVar.external)));
863 }
864
865 body = newBody;
866 }
867 else if (!mPipelineMainLocalVar.isUniform())
868 {
869 ASSERT(!mPipelineMainLocalVar.external);
870 ASSERT(mPipelineMainLocalVar.internal);
871
872 auto *newBody = new TIntermBlock();
873 newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.internal});
874 callModifiedToOriginal(*newBody);
875 newBody->appendStatement(body);
876 callOriginalToModified(*newBody);
877 body = newBody;
878 }
879
880 return *new TIntermFunctionDefinition(newFuncProtoNode, body);
881 }
882 };
883
884 ////////////////////////////////////////////////////////////////////////////////
885
UpdatePipelineSymbols(Pipeline::Type pipelineType,TCompiler & compiler,TIntermBlock & root,SymbolEnv & symbolEnv,const VariableSet & pipelineVariables,PipelineScoped<TVariable> pipelineMainLocalVar)886 bool UpdatePipelineSymbols(Pipeline::Type pipelineType,
887 TCompiler &compiler,
888 TIntermBlock &root,
889 SymbolEnv &symbolEnv,
890 const VariableSet &pipelineVariables,
891 PipelineScoped<TVariable> pipelineMainLocalVar)
892 {
893 auto map = [&](const TFunction *owner, TIntermSymbol &symbol) -> TIntermNode & {
894 if (!owner)
895 return symbol;
896 const TVariable &var = symbol.variable();
897 if (pipelineVariables.find(&var) == pipelineVariables.end())
898 {
899 return symbol;
900 }
901 const TVariable *structInstanceVar;
902 if (owner->isMain())
903 {
904 ASSERT(pipelineMainLocalVar.internal);
905 structInstanceVar = pipelineMainLocalVar.internal;
906 }
907 else
908 {
909 ASSERT(owner->getParamCount() > 0);
910 structInstanceVar = owner->getParam(0);
911 }
912 ASSERT(structInstanceVar);
913 return AccessField(*structInstanceVar, var.name());
914 };
915 return MapSymbols(compiler, root, map);
916 }
917
918 ////////////////////////////////////////////////////////////////////////////////
919
RewritePipeline(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,const std::vector<sh::ShaderVariable> * variableInfo,PipelineScoped<TStructure> & outStruct)920 bool RewritePipeline(TCompiler &compiler,
921 TIntermBlock &root,
922 IdGen &idGen,
923 const Pipeline &pipeline,
924 SymbolEnv &symbolEnv,
925 const std::vector<sh::ShaderVariable> *variableInfo,
926 PipelineScoped<TStructure> &outStruct)
927 {
928 ASSERT(outStruct.isTotallyEmpty());
929
930 TSymbolTable &symbolTable = compiler.getSymbolTable();
931
932 PipelineStructInfo psi;
933 if (!GeneratePipelineStruct::Exec(psi, compiler, root, idGen, pipeline, symbolEnv,
934 variableInfo))
935 {
936 return false;
937 }
938
939 if (psi.isEmpty())
940 {
941 return true;
942 }
943
944 const auto pipelineFunctions = DiscoverDependentFunctions(root, [&](const TVariable &var) {
945 return psi.pipelineVariables.find(&var) != psi.pipelineVariables.end();
946 });
947
948 auto pipelineMainLocalVar =
949 CreatePipelineMainLocalVar(symbolTable, pipeline, psi.pipelineStruct);
950
951 if (!UpdatePipelineFunctions::ThreadPipeline(
952 compiler, root, pipeline, pipelineFunctions, psi.pipelineStruct, pipelineMainLocalVar,
953 idGen, symbolEnv, psi.funcOriginalToModified, psi.funcModifiedToOriginal))
954 {
955 return false;
956 }
957
958 if (!pipeline.globalInstanceVar)
959 {
960 if (!UpdatePipelineSymbols(pipeline.type, compiler, root, symbolEnv, psi.pipelineVariables,
961 pipelineMainLocalVar))
962 {
963 return false;
964 }
965 }
966
967 if (!PruneNoOps(&compiler, &root, &compiler.getSymbolTable()))
968 {
969 return false;
970 }
971
972 outStruct = psi.pipelineStruct;
973 return true;
974 }
975
976 } // anonymous namespace
977
RewritePipelines(TCompiler & compiler,TIntermBlock & root,const std::vector<sh::ShaderVariable> & inputVaryings,const std::vector<sh::ShaderVariable> & outputVaryings,IdGen & idGen,DriverUniform & angleUniformsGlobalInstanceVar,SymbolEnv & symbolEnv,PipelineStructs & outStructs)978 bool sh::RewritePipelines(TCompiler &compiler,
979 TIntermBlock &root,
980 const std::vector<sh::ShaderVariable> &inputVaryings,
981 const std::vector<sh::ShaderVariable> &outputVaryings,
982 IdGen &idGen,
983 DriverUniform &angleUniformsGlobalInstanceVar,
984 SymbolEnv &symbolEnv,
985 PipelineStructs &outStructs)
986 {
987 struct Info
988 {
989 Pipeline::Type pipelineType;
990 PipelineScoped<TStructure> &outStruct;
991 const TVariable *globalInstanceVar;
992 const std::vector<sh::ShaderVariable> *variableInfo;
993 };
994
995 Info infos[] = {
996 {Pipeline::Type::InstanceId, outStructs.instanceId, nullptr, nullptr},
997 {Pipeline::Type::Texture, outStructs.texture, nullptr, nullptr},
998 {Pipeline::Type::NonConstantGlobals, outStructs.nonConstantGlobals, nullptr, nullptr},
999 {Pipeline::Type::AngleUniforms, outStructs.angleUniforms,
1000 angleUniformsGlobalInstanceVar.getDriverUniformsVariable(), nullptr},
1001 {Pipeline::Type::UserUniforms, outStructs.userUniforms, nullptr, nullptr},
1002 {Pipeline::Type::VertexIn, outStructs.vertexIn, nullptr, &inputVaryings},
1003 {Pipeline::Type::VertexOut, outStructs.vertexOut, nullptr, &outputVaryings},
1004 {Pipeline::Type::FragmentIn, outStructs.fragmentIn, nullptr, &inputVaryings},
1005 {Pipeline::Type::FragmentOut, outStructs.fragmentOut, nullptr, &outputVaryings},
1006 {Pipeline::Type::InvocationVertexGlobals, outStructs.invocationVertexGlobals, nullptr,
1007 nullptr},
1008 {Pipeline::Type::InvocationFragmentGlobals, outStructs.invocationFragmentGlobals, nullptr,
1009 &inputVaryings},
1010 {Pipeline::Type::UniformBuffer, outStructs.uniformBuffers, nullptr, nullptr},
1011 };
1012
1013 for (Info &info : infos)
1014 {
1015 Pipeline pipeline{info.pipelineType, info.globalInstanceVar};
1016 if (!RewritePipeline(compiler, root, idGen, pipeline, symbolEnv, info.variableInfo,
1017 info.outStruct))
1018 {
1019 return false;
1020 }
1021 }
1022
1023 return true;
1024 }
1025