1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include <cstring>
8 #include <unordered_map>
9 #include <unordered_set>
10
11 #include "compiler/translator/TranslatorMetalDirect.h"
12 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
13 #include "compiler/translator/TranslatorMetalDirect/DiscoverDependentFunctions.h"
14 #include "compiler/translator/TranslatorMetalDirect/IdGen.h"
15 #include "compiler/translator/TranslatorMetalDirect/IntermRebuild.h"
16 #include "compiler/translator/TranslatorMetalDirect/MapSymbols.h"
17 #include "compiler/translator/TranslatorMetalDirect/Pipeline.h"
18 #include "compiler/translator/TranslatorMetalDirect/RewritePipelines.h"
19 #include "compiler/translator/TranslatorMetalDirect/SymbolEnv.h"
20 #include "compiler/translator/tree_ops/PruneNoOps.h"
21 #include "compiler/translator/tree_util/DriverUniform.h"
22 #include "compiler/translator/tree_util/FindMain.h"
23 #include "compiler/translator/tree_util/IntermTraverse.h"
24 using namespace sh;
25
26 ////////////////////////////////////////////////////////////////////////////////
27
28 namespace
29 {
30
// Unordered set of variable pointers, used for membership lookups.
using VariableSet = std::unordered_set<const TVariable *>;
// Sequence of variable pointers kept in insertion order.
using VariableList = std::vector<const TVariable *>;
33
34 ////////////////////////////////////////////////////////////////////////////////
35
36 struct PipelineStructInfo
37 {
38 VariableSet pipelineVariables;
39 PipelineScoped<TStructure> pipelineStruct;
40 const TFunction *funcOriginalToModified = nullptr;
41 const TFunction *funcModifiedToOriginal = nullptr;
42
isEmpty__anon4e3055eb0111::PipelineStructInfo43 bool isEmpty() const
44 {
45 if (pipelineStruct.isTotallyEmpty())
46 {
47 ASSERT(pipelineVariables.empty());
48 return true;
49 }
50 else
51 {
52 ASSERT(pipelineStruct.isTotallyFull());
53 ASSERT(!pipelineVariables.empty());
54 return false;
55 }
56 }
57 };
58
59 class GeneratePipelineStruct : private TIntermRebuild
60 {
61 private:
62 const Pipeline &mPipeline;
63 SymbolEnv &mSymbolEnv;
64 Invariants &mInvariants;
65 VariableList mPipelineVariableList;
66 IdGen &mIdGen;
67 PipelineStructInfo mInfo;
68
69 public:
Exec(PipelineStructInfo & out,TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,Invariants & invariants)70 static bool Exec(PipelineStructInfo &out,
71 TCompiler &compiler,
72 TIntermBlock &root,
73 IdGen &idGen,
74 const Pipeline &pipeline,
75 SymbolEnv &symbolEnv,
76 Invariants &invariants)
77 {
78 GeneratePipelineStruct self(compiler, idGen, pipeline, symbolEnv, invariants);
79 if (!self.exec(root))
80 {
81 return false;
82 }
83 out = self.mInfo;
84 return true;
85 }
86
87 private:
GeneratePipelineStruct(TCompiler & compiler,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,Invariants & invariants)88 GeneratePipelineStruct(TCompiler &compiler,
89 IdGen &idGen,
90 const Pipeline &pipeline,
91 SymbolEnv &symbolEnv,
92 Invariants &invariants)
93 : TIntermRebuild(compiler, true, true),
94 mPipeline(pipeline),
95 mSymbolEnv(symbolEnv),
96 mInvariants(invariants),
97 mIdGen(idGen)
98 {}
99
exec(TIntermBlock & root)100 bool exec(TIntermBlock &root)
101 {
102 if (!rebuildRoot(root))
103 {
104 return false;
105 }
106
107 if (mInfo.pipelineVariables.empty())
108 {
109 return true;
110 }
111
112 TIntermSequence seq;
113
114 const TStructure &pipelineStruct = [&]() -> const TStructure & {
115 if (mPipeline.globalInstanceVar)
116 {
117 return *mPipeline.globalInstanceVar->getType().getStruct();
118 }
119 else
120 {
121 return createInternalPipelineStruct(root, seq);
122 }
123 }();
124
125 ModifiedStructMachineries modifiedMachineries;
126 const bool isUBO = mPipeline.type == Pipeline::Type::UniformBuffer;
127 const bool modified = TryCreateModifiedStruct(
128 mCompiler, mSymbolEnv, mIdGen, mPipeline.externalStructModifyConfig(), pipelineStruct,
129 mPipeline.getStructTypeName(Pipeline::Variant::Modified), modifiedMachineries, isUBO,
130 !isUBO);
131
132 if (modified)
133 {
134 ASSERT(mPipeline.type != Pipeline::Type::Texture);
135 ASSERT(mPipeline.type == Pipeline::Type::AngleUniforms ||
136 !mPipeline.globalInstanceVar); // This shouldn't happen by construction.
137
138 auto getFunction = [](sh::TIntermFunctionDefinition *funcDecl) {
139 return funcDecl ? funcDecl->getFunction() : nullptr;
140 };
141
142 const size_t size = modifiedMachineries.size();
143 ASSERT(size > 0);
144 for (size_t i = 0; i < size; ++i)
145 {
146 const ModifiedStructMachinery &machinery = modifiedMachineries.at(i);
147 ASSERT(machinery.modifiedStruct);
148
149 seq.push_back(new TIntermDeclaration{
150 &CreateStructTypeVariable(mSymbolTable, *machinery.modifiedStruct)});
151
152 if (mPipeline.isPipelineOut())
153 {
154 ASSERT(machinery.funcOriginalToModified);
155 ASSERT(!machinery.funcModifiedToOriginal);
156 seq.push_back(machinery.funcOriginalToModified);
157 }
158 else
159 {
160 ASSERT(machinery.funcModifiedToOriginal);
161 ASSERT(!machinery.funcOriginalToModified);
162 seq.push_back(machinery.funcModifiedToOriginal);
163 }
164
165 if (i == size - 1)
166 {
167 mInfo.funcOriginalToModified = getFunction(machinery.funcOriginalToModified);
168 mInfo.funcModifiedToOriginal = getFunction(machinery.funcModifiedToOriginal);
169
170 mInfo.pipelineStruct.internal = &pipelineStruct;
171 mInfo.pipelineStruct.external =
172 modified ? machinery.modifiedStruct : &pipelineStruct;
173 }
174 }
175 }
176 else
177 {
178 mInfo.pipelineStruct.internal = &pipelineStruct;
179 mInfo.pipelineStruct.external = &pipelineStruct;
180 }
181
182 root.insertChildNodes(FindMainIndex(&root), seq);
183
184 return true;
185 }
186
187 private:
visitFunctionDefinitionPre(TIntermFunctionDefinition & node)188 PreResult visitFunctionDefinitionPre(TIntermFunctionDefinition &node) override
189 {
190 return {node, VisitBits::Neither};
191 }
visitDeclarationPost(TIntermDeclaration & declNode)192 PostResult visitDeclarationPost(TIntermDeclaration &declNode) override
193 {
194 Declaration decl = ViewDeclaration(declNode);
195 const TVariable &var = decl.symbol.variable();
196 if (mPipeline.uses(var))
197 {
198 ASSERT(mInfo.pipelineVariables.find(&var) == mInfo.pipelineVariables.end());
199 mInfo.pipelineVariables.insert(&var);
200 mPipelineVariableList.push_back(&var);
201 return nullptr;
202 }
203
204 return declNode;
205 }
206
createInternalPipelineStruct(TIntermBlock & root,TIntermSequence & outDeclSeq)207 const TStructure &createInternalPipelineStruct(TIntermBlock &root, TIntermSequence &outDeclSeq)
208 {
209 auto &fields = *new TFieldList();
210
211 switch (mPipeline.type)
212 {
213 case Pipeline::Type::Texture:
214 {
215 for (const TVariable *var : mPipelineVariableList)
216 {
217 ASSERT(!mInvariants.contains(*var));
218 const TType &varType = var->getType();
219 const TBasicType samplerType = varType.getBasicType();
220
221 const TStructure &textureEnv = mSymbolEnv.getTextureEnv(samplerType);
222 auto *textureEnvType = new TType(&textureEnv, false);
223 if (varType.isArray())
224 {
225 textureEnvType->makeArrays(varType.getArraySizes());
226 }
227
228 fields.push_back(
229 new TField(textureEnvType, var->name(), kNoSourceLoc, var->symbolType()));
230 }
231 }
232 break;
233
234 case Pipeline::Type::UniformBuffer:
235 {
236 for (const TVariable *var : mPipelineVariableList)
237 {
238 auto &type = CloneType(var->getType());
239 auto *field = new TField(&type, var->name(), kNoSourceLoc, var->symbolType());
240 mSymbolEnv.markAsPointer(*field, AddressSpace::Constant);
241 mSymbolEnv.markAsUBO(*field);
242 mSymbolEnv.markAsPointer(*var, AddressSpace::Constant);
243 fields.push_back(field);
244 }
245 }
246 break;
247 default:
248 {
249 for (const TVariable *var : mPipelineVariableList)
250 {
251 auto &type = CloneType(var->getType());
252 auto *field = new TField(&type, var->name(), kNoSourceLoc, var->symbolType());
253 fields.push_back(field);
254
255 if (mInvariants.contains(*var))
256 {
257 mInvariants.insert(*field);
258 }
259 }
260 }
261 break;
262 }
263
264 Name pipelineStructName = mPipeline.getStructTypeName(Pipeline::Variant::Original);
265 auto &s = *new TStructure(&mSymbolTable, pipelineStructName.rawName(), &fields,
266 pipelineStructName.symbolType());
267
268 outDeclSeq.push_back(new TIntermDeclaration{&CreateStructTypeVariable(mSymbolTable, s)});
269
270 return s;
271 }
272 };
273
274 ////////////////////////////////////////////////////////////////////////////////
275
CreatePipelineMainLocalVar(TSymbolTable & symbolTable,const Pipeline & pipeline,PipelineScoped<TStructure> pipelineStruct)276 PipelineScoped<TVariable> CreatePipelineMainLocalVar(TSymbolTable &symbolTable,
277 const Pipeline &pipeline,
278 PipelineScoped<TStructure> pipelineStruct)
279 {
280 ASSERT(pipelineStruct.isTotallyFull());
281
282 PipelineScoped<TVariable> pipelineMainLocalVar;
283
284 auto populateExternalMainLocalVar = [&]() {
285 ASSERT(!pipelineMainLocalVar.external);
286 pipelineMainLocalVar.external = &CreateInstanceVariable(
287 symbolTable, *pipelineStruct.external,
288 pipeline.getStructInstanceName(pipelineStruct.isUniform()
289 ? Pipeline::Variant::Original
290 : Pipeline::Variant::Modified));
291 };
292
293 auto populateDistinctInternalMainLocalVar = [&]() {
294 ASSERT(!pipelineMainLocalVar.internal);
295 pipelineMainLocalVar.internal =
296 &CreateInstanceVariable(symbolTable, *pipelineStruct.internal,
297 pipeline.getStructInstanceName(Pipeline::Variant::Original));
298 };
299
300 if (pipeline.type == Pipeline::Type::InstanceId)
301 {
302 populateDistinctInternalMainLocalVar();
303 }
304 else if (pipeline.alwaysRequiresLocalVariableDeclarationInMain())
305 {
306 populateExternalMainLocalVar();
307
308 if (pipelineStruct.isUniform())
309 {
310 pipelineMainLocalVar.internal = pipelineMainLocalVar.external;
311 }
312 else
313 {
314 populateDistinctInternalMainLocalVar();
315 }
316 }
317 else if (!pipelineStruct.isUniform())
318 {
319 populateDistinctInternalMainLocalVar();
320 }
321
322 return pipelineMainLocalVar;
323 }
324
325 class PipelineFunctionEnv
326 {
327 private:
328 TCompiler &mCompiler;
329 SymbolEnv &mSymbolEnv;
330 TSymbolTable &mSymbolTable;
331 IdGen &mIdGen;
332 const Pipeline &mPipeline;
333 const std::unordered_set<const TFunction *> &mPipelineFunctions;
334 const PipelineScoped<TStructure> mPipelineStruct;
335 PipelineScoped<TVariable> &mPipelineMainLocalVar;
336
337 std::unordered_map<const TFunction *, const TFunction *> mFuncMap;
338
339 public:
PipelineFunctionEnv(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen,const Pipeline & pipeline,const std::unordered_set<const TFunction * > & pipelineFunctions,PipelineScoped<TStructure> pipelineStruct,PipelineScoped<TVariable> & pipelineMainLocalVar)340 PipelineFunctionEnv(TCompiler &compiler,
341 SymbolEnv &symbolEnv,
342 IdGen &idGen,
343 const Pipeline &pipeline,
344 const std::unordered_set<const TFunction *> &pipelineFunctions,
345 PipelineScoped<TStructure> pipelineStruct,
346 PipelineScoped<TVariable> &pipelineMainLocalVar)
347 : mCompiler(compiler),
348 mSymbolEnv(symbolEnv),
349 mSymbolTable(symbolEnv.symbolTable()),
350 mIdGen(idGen),
351 mPipeline(pipeline),
352 mPipelineFunctions(pipelineFunctions),
353 mPipelineStruct(pipelineStruct),
354 mPipelineMainLocalVar(pipelineMainLocalVar)
355 {}
356
isOriginalPipelineFunction(const TFunction & func) const357 bool isOriginalPipelineFunction(const TFunction &func) const
358 {
359 return mPipelineFunctions.find(&func) != mPipelineFunctions.end();
360 }
361
isUpdatedPipelineFunction(const TFunction & func) const362 bool isUpdatedPipelineFunction(const TFunction &func) const
363 {
364 auto it = mFuncMap.find(&func);
365 if (it == mFuncMap.end())
366 {
367 return false;
368 }
369 return &func == it->second;
370 }
371
getUpdatedFunction(const TFunction & func)372 const TFunction &getUpdatedFunction(const TFunction &func)
373 {
374 ASSERT(isOriginalPipelineFunction(func) || isUpdatedPipelineFunction(func));
375
376 const TFunction *newFunc;
377
378 auto it = mFuncMap.find(&func);
379 if (it == mFuncMap.end())
380 {
381 const bool isMain = func.isMain();
382
383 if (isMain && mPipeline.isPipelineOut())
384 {
385 ASSERT(func.getReturnType().getBasicType() == TBasicType::EbtVoid);
386 newFunc = &CloneFunctionAndChangeReturnType(mSymbolTable, nullptr, func,
387 *mPipelineStruct.external);
388 }
389 else if (isMain && (mPipeline.type == Pipeline::Type::InvocationVertexGlobals ||
390 mPipeline.type == Pipeline::Type::InvocationFragmentGlobals))
391 {
392 std::vector<const TVariable *> variables;
393 for (const TField *field : mPipelineStruct.external->fields())
394 {
395 variables.push_back(new TVariable(&mSymbolTable, field->name(), field->type(),
396 field->symbolType()));
397 }
398 newFunc = &CloneFunctionAndAppendParams(mSymbolTable, nullptr, func, variables);
399 }
400 else if (isMain && mPipeline.type == Pipeline::Type::Texture)
401 {
402 std::vector<const TVariable *> variables;
403 TranslatorMetalReflection *reflection =
404 ((sh::TranslatorMetalDirect *)&mCompiler)->getTranslatorMetalReflection();
405 for (const TField *field : mPipelineStruct.external->fields())
406 {
407 const TStructure *textureEnv = field->type()->getStruct();
408 ASSERT(textureEnv && textureEnv->fields().size() == 2);
409 for (const TField *subfield : textureEnv->fields())
410 {
411 const Name name = mIdGen.createNewName({field->name(), subfield->name()});
412 TType &type = *new TType(*subfield->type());
413 ASSERT(!type.isArray());
414 type.makeArrays(field->type()->getArraySizes());
415 auto *var =
416 new TVariable(&mSymbolTable, name.rawName(), &type, name.symbolType());
417 variables.push_back(var);
418 reflection->addOriginalName(var->uniqueId().get(), field->name().data());
419 }
420 }
421 newFunc = &CloneFunctionAndAppendParams(mSymbolTable, nullptr, func, variables);
422 }
423 else if (isMain && mPipeline.type == Pipeline::Type::InstanceId)
424 {
425 Name name = mPipeline.getStructInstanceName(Pipeline::Variant::Modified);
426 auto *var = new TVariable(&mSymbolTable, name.rawName(),
427 new TType(TBasicType::EbtUInt), name.symbolType());
428 newFunc = &CloneFunctionAndPrependParam(mSymbolTable, nullptr, func, *var);
429 mPipelineMainLocalVar.external = var;
430 }
431 else if (isMain && mPipeline.alwaysRequiresLocalVariableDeclarationInMain())
432 {
433 ASSERT(mPipelineMainLocalVar.isTotallyFull());
434 newFunc = &func;
435 }
436 else
437 {
438 const TVariable *var;
439 AddressSpace addressSpace;
440
441 if (isMain && !mPipelineMainLocalVar.isUniform())
442 {
443 var = &CreateInstanceVariable(
444 mSymbolTable, *mPipelineStruct.external,
445 mPipeline.getStructInstanceName(Pipeline::Variant::Modified));
446 addressSpace = mPipeline.externalAddressSpace();
447 }
448 else
449 {
450 if (mPipeline.type == Pipeline::Type::UniformBuffer)
451 {
452 TranslatorMetalReflection *reflection =
453 ((sh::TranslatorMetalDirect *)&mCompiler)
454 ->getTranslatorMetalReflection();
455 // TODO: need more checks to make sure they line up? Could be reordered?
456 ASSERT(mPipelineStruct.external->fields().size() ==
457 mPipelineStruct.internal->fields().size());
458 for (size_t i = 0; i < mPipelineStruct.external->fields().size(); i++)
459 {
460 const TField *externalField = mPipelineStruct.external->fields()[i];
461 const TField *internalField = mPipelineStruct.internal->fields()[i];
462 const TType &externalType = *externalField->type();
463 const TType &internalType = *internalField->type();
464 ASSERT(externalType.getBasicType() == internalType.getBasicType());
465 if (externalType.getBasicType() == TBasicType::EbtStruct)
466 {
467 const TStructure *externalEnv = externalType.getStruct();
468 const TStructure *internalEnv = internalType.getStruct();
469 const std::string internalName =
470 reflection->getOriginalName(internalEnv->uniqueId().get());
471 reflection->addOriginalName(externalEnv->uniqueId().get(),
472 internalName);
473 }
474 }
475 }
476 var = &CreateInstanceVariable(
477 mSymbolTable, *mPipelineStruct.internal,
478 mPipeline.getStructInstanceName(Pipeline::Variant::Original));
479 addressSpace = mPipelineMainLocalVar.isUniform()
480 ? mPipeline.externalAddressSpace()
481 : AddressSpace::Thread;
482 }
483
484 bool markAsReference = true;
485 if (isMain)
486 {
487 switch (mPipeline.type)
488 {
489 case Pipeline::Type::VertexIn:
490 case Pipeline::Type::FragmentIn:
491 markAsReference = false;
492 break;
493
494 default:
495 break;
496 }
497 }
498
499 if (markAsReference)
500 {
501 mSymbolEnv.markAsReference(*var, addressSpace);
502 }
503
504 newFunc = &CloneFunctionAndPrependParam(mSymbolTable, nullptr, func, *var);
505 }
506
507 mFuncMap[&func] = newFunc;
508 mFuncMap[newFunc] = newFunc;
509 }
510 else
511 {
512 newFunc = it->second;
513 }
514
515 return *newFunc;
516 }
517
createUpdatedFunctionPrototype(TIntermFunctionPrototype & funcProtoNode)518 TIntermFunctionPrototype *createUpdatedFunctionPrototype(
519 TIntermFunctionPrototype &funcProtoNode)
520 {
521 const TFunction &func = *funcProtoNode.getFunction();
522 if (!isOriginalPipelineFunction(func) && !isUpdatedPipelineFunction(func))
523 {
524 return nullptr;
525 }
526 const TFunction &newFunc = getUpdatedFunction(func);
527 return new TIntermFunctionPrototype(&newFunc);
528 }
529 };
530
531 class UpdatePipelineFunctions : private TIntermRebuild
532 {
533 private:
534 const Pipeline &mPipeline;
535 const PipelineScoped<TStructure> mPipelineStruct;
536 PipelineScoped<TVariable> &mPipelineMainLocalVar;
537 SymbolEnv &mSymbolEnv;
538 PipelineFunctionEnv mEnv;
539 const TFunction *mFuncOriginalToModified;
540 const TFunction *mFuncModifiedToOriginal;
541
542 public:
ThreadPipeline(TCompiler & compiler,TIntermBlock & root,const Pipeline & pipeline,const std::unordered_set<const TFunction * > & pipelineFunctions,PipelineScoped<TStructure> pipelineStruct,PipelineScoped<TVariable> & pipelineMainLocalVar,IdGen & idGen,SymbolEnv & symbolEnv,const TFunction * funcOriginalToModified,const TFunction * funcModifiedToOriginal)543 static bool ThreadPipeline(TCompiler &compiler,
544 TIntermBlock &root,
545 const Pipeline &pipeline,
546 const std::unordered_set<const TFunction *> &pipelineFunctions,
547 PipelineScoped<TStructure> pipelineStruct,
548 PipelineScoped<TVariable> &pipelineMainLocalVar,
549 IdGen &idGen,
550 SymbolEnv &symbolEnv,
551 const TFunction *funcOriginalToModified,
552 const TFunction *funcModifiedToOriginal)
553 {
554 UpdatePipelineFunctions self(compiler, pipeline, pipelineFunctions, pipelineStruct,
555 pipelineMainLocalVar, idGen, symbolEnv, funcOriginalToModified,
556 funcModifiedToOriginal);
557 if (!self.rebuildRoot(root))
558 {
559 return false;
560 }
561 return true;
562 }
563
564 private:
UpdatePipelineFunctions(TCompiler & compiler,const Pipeline & pipeline,const std::unordered_set<const TFunction * > & pipelineFunctions,PipelineScoped<TStructure> pipelineStruct,PipelineScoped<TVariable> & pipelineMainLocalVar,IdGen & idGen,SymbolEnv & symbolEnv,const TFunction * funcOriginalToModified,const TFunction * funcModifiedToOriginal)565 UpdatePipelineFunctions(TCompiler &compiler,
566 const Pipeline &pipeline,
567 const std::unordered_set<const TFunction *> &pipelineFunctions,
568 PipelineScoped<TStructure> pipelineStruct,
569 PipelineScoped<TVariable> &pipelineMainLocalVar,
570 IdGen &idGen,
571 SymbolEnv &symbolEnv,
572 const TFunction *funcOriginalToModified,
573 const TFunction *funcModifiedToOriginal)
574 : TIntermRebuild(compiler, false, true),
575 mPipeline(pipeline),
576 mPipelineStruct(pipelineStruct),
577 mPipelineMainLocalVar(pipelineMainLocalVar),
578 mSymbolEnv(symbolEnv),
579 mEnv(compiler,
580 symbolEnv,
581 idGen,
582 pipeline,
583 pipelineFunctions,
584 pipelineStruct,
585 mPipelineMainLocalVar),
586 mFuncOriginalToModified(funcOriginalToModified),
587 mFuncModifiedToOriginal(funcModifiedToOriginal)
588 {
589 ASSERT(mPipelineStruct.isTotallyFull());
590 }
591
getInternalPipelineVariable(const TFunction & pipelineFunc)592 const TVariable &getInternalPipelineVariable(const TFunction &pipelineFunc)
593 {
594 if (pipelineFunc.isMain() && (mPipeline.alwaysRequiresLocalVariableDeclarationInMain() ||
595 !mPipelineMainLocalVar.isUniform()))
596 {
597 ASSERT(mPipelineMainLocalVar.internal);
598 return *mPipelineMainLocalVar.internal;
599 }
600 else
601 {
602 ASSERT(pipelineFunc.getParamCount() > 0);
603 return *pipelineFunc.getParam(0);
604 }
605 }
606
getExternalPipelineVariable(const TFunction & mainFunc)607 const TVariable &getExternalPipelineVariable(const TFunction &mainFunc)
608 {
609 ASSERT(mainFunc.isMain());
610 if (mPipelineMainLocalVar.external)
611 {
612 return *mPipelineMainLocalVar.external;
613 }
614 else
615 {
616 ASSERT(mainFunc.getParamCount() > 0);
617 return *mainFunc.getParam(0);
618 }
619 }
620
visitAggregatePost(TIntermAggregate & callNode)621 PostResult visitAggregatePost(TIntermAggregate &callNode) override
622 {
623 if (callNode.isConstructor())
624 {
625 return callNode;
626 }
627 else
628 {
629 const TFunction &oldCalledFunc = *callNode.getFunction();
630 if (!mEnv.isOriginalPipelineFunction(oldCalledFunc))
631 {
632 return callNode;
633 }
634 const TFunction &newCalledFunc = mEnv.getUpdatedFunction(oldCalledFunc);
635
636 const TFunction *oldOwnerFunc = getParentFunction();
637 ASSERT(oldOwnerFunc);
638 const TFunction &newOwnerFunc = mEnv.getUpdatedFunction(*oldOwnerFunc);
639
640 return *TIntermAggregate::CreateFunctionCall(
641 newCalledFunc, &CloneSequenceAndPrepend(
642 *callNode.getSequence(),
643 *new TIntermSymbol(&getInternalPipelineVariable(newOwnerFunc))));
644 }
645 }
646
visitFunctionPrototypePost(TIntermFunctionPrototype & funcProtoNode)647 PostResult visitFunctionPrototypePost(TIntermFunctionPrototype &funcProtoNode) override
648 {
649 TIntermFunctionPrototype *newFuncProtoNode =
650 mEnv.createUpdatedFunctionPrototype(funcProtoNode);
651 if (newFuncProtoNode == nullptr)
652 {
653 return funcProtoNode;
654 }
655 return *newFuncProtoNode;
656 }
657
visitFunctionDefinitionPost(TIntermFunctionDefinition & funcDefNode)658 PostResult visitFunctionDefinitionPost(TIntermFunctionDefinition &funcDefNode) override
659 {
660 if (funcDefNode.getFunction()->isMain())
661 {
662 return visitMain(funcDefNode);
663 }
664 else
665 {
666 return visitNonMain(funcDefNode);
667 }
668 }
669
visitNonMain(TIntermFunctionDefinition & funcDefNode)670 TIntermNode &visitNonMain(TIntermFunctionDefinition &funcDefNode)
671 {
672 TIntermFunctionPrototype &funcProtoNode = *funcDefNode.getFunctionPrototype();
673 ASSERT(!funcProtoNode.getFunction()->isMain());
674
675 TIntermFunctionPrototype *newFuncProtoNode =
676 mEnv.createUpdatedFunctionPrototype(funcProtoNode);
677 if (newFuncProtoNode == nullptr)
678 {
679 return funcDefNode;
680 }
681
682 const TFunction &func = *newFuncProtoNode->getFunction();
683 ASSERT(!func.isMain());
684
685 TIntermBlock *body = funcDefNode.getBody();
686
687 return *new TIntermFunctionDefinition(newFuncProtoNode, body);
688 }
689
visitMain(TIntermFunctionDefinition & funcDefNode)690 TIntermNode &visitMain(TIntermFunctionDefinition &funcDefNode)
691 {
692 TIntermFunctionPrototype &funcProtoNode = *funcDefNode.getFunctionPrototype();
693 ASSERT(funcProtoNode.getFunction()->isMain());
694
695 TIntermFunctionPrototype *newFuncProtoNode =
696 mEnv.createUpdatedFunctionPrototype(funcProtoNode);
697 if (newFuncProtoNode == nullptr)
698 {
699 return funcDefNode;
700 }
701
702 const TFunction &func = *newFuncProtoNode->getFunction();
703 ASSERT(func.isMain());
704
705 auto callModifiedToOriginal = [&](TIntermBlock &body) {
706 ASSERT(mPipelineMainLocalVar.internal);
707 if (!mPipeline.isPipelineOut())
708 {
709 ASSERT(mFuncModifiedToOriginal);
710 auto *m = new TIntermSymbol(&getExternalPipelineVariable(func));
711 auto *o = new TIntermSymbol(mPipelineMainLocalVar.internal);
712 body.appendStatement(TIntermAggregate::CreateFunctionCall(
713 *mFuncModifiedToOriginal, new TIntermSequence{m, o}));
714 }
715 };
716
717 auto callOriginalToModified = [&](TIntermBlock &body) {
718 ASSERT(mPipelineMainLocalVar.internal);
719 if (mPipeline.isPipelineOut())
720 {
721 ASSERT(mFuncOriginalToModified);
722 auto *o = new TIntermSymbol(mPipelineMainLocalVar.internal);
723 auto *m = new TIntermSymbol(&getExternalPipelineVariable(func));
724 body.appendStatement(TIntermAggregate::CreateFunctionCall(
725 *mFuncOriginalToModified, new TIntermSequence{o, m}));
726 }
727 };
728
729 TIntermBlock *body = funcDefNode.getBody();
730
731 if (mPipeline.alwaysRequiresLocalVariableDeclarationInMain())
732 {
733 ASSERT(mPipelineMainLocalVar.isTotallyFull());
734
735 auto *newBody = new TIntermBlock();
736 newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.internal});
737
738 if (mPipeline.type == Pipeline::Type::InvocationVertexGlobals ||
739 mPipeline.type == Pipeline::Type::InvocationFragmentGlobals)
740 {
741 // Populate struct instance with references to global pipeline variables.
742 for (const TField *field : mPipelineStruct.external->fields())
743 {
744 auto *var = new TVariable(&mSymbolTable, field->name(), field->type(),
745 field->symbolType());
746 auto *symbol = new TIntermSymbol(var);
747 auto &accessNode = AccessField(*mPipelineMainLocalVar.internal, var->name());
748 auto *assignNode = new TIntermBinary(TOperator::EOpAssign, &accessNode, symbol);
749 newBody->appendStatement(assignNode);
750 }
751 }
752 else if (mPipeline.type == Pipeline::Type::Texture)
753 {
754 const TFieldList &fields = mPipelineStruct.external->fields();
755
756 ASSERT(func.getParamCount() >= 2 * fields.size());
757 size_t paramIndex = func.getParamCount() - 2 * fields.size();
758
759 for (const TField *field : fields)
760 {
761 const TVariable &textureParam = *func.getParam(paramIndex++);
762 const TVariable &samplerParam = *func.getParam(paramIndex++);
763
764 auto go = [&](TIntermTyped &env, const int *index) {
765 TIntermTyped &textureField = AccessField(
766 AccessIndex(*env.deepCopy(), index), ImmutableString("texture"));
767 TIntermTyped &samplerField = AccessField(
768 AccessIndex(*env.deepCopy(), index), ImmutableString("sampler"));
769
770 auto mkAssign = [&](TIntermTyped &field, const TVariable ¶m) {
771 return new TIntermBinary(TOperator::EOpAssign, &field,
772 &mSymbolEnv.callFunctionOverload(
773 Name("addressof"), field.getType(),
774 *new TIntermSequence{&AccessIndex(
775 *new TIntermSymbol(¶m), index)}));
776 };
777
778 newBody->appendStatement(mkAssign(textureField, textureParam));
779 newBody->appendStatement(mkAssign(samplerField, samplerParam));
780 };
781
782 TIntermTyped &env = AccessField(*mPipelineMainLocalVar.internal, field->name());
783 const TType &envType = env.getType();
784
785 if (envType.isArray())
786 {
787 ASSERT(!envType.isArrayOfArrays());
788 const auto n = static_cast<int>(envType.getArraySizeProduct());
789 for (int i = 0; i < n; ++i)
790 {
791 go(env, &i);
792 }
793 }
794 else
795 {
796 go(env, nullptr);
797 }
798 }
799 }
800 else if (mPipeline.type == Pipeline::Type::InstanceId)
801 {
802 newBody->appendStatement(new TIntermBinary(
803 TOperator::EOpAssign,
804 &AccessFieldByIndex(*new TIntermSymbol(&getInternalPipelineVariable(func)), 0),
805 &AsType(mSymbolEnv, *new TType(TBasicType::EbtInt),
806 *new TIntermSymbol(&getExternalPipelineVariable(func)))));
807 }
808 else if (!mPipelineMainLocalVar.isUniform())
809 {
810 newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.external});
811 callModifiedToOriginal(*newBody);
812 }
813
814 newBody->appendStatement(body);
815
816 if (!mPipelineMainLocalVar.isUniform())
817 {
818 callOriginalToModified(*newBody);
819 }
820
821 if (mPipeline.isPipelineOut())
822 {
823 newBody->appendStatement(new TIntermBranch(
824 TOperator::EOpReturn, new TIntermSymbol(mPipelineMainLocalVar.external)));
825 }
826
827 body = newBody;
828 }
829 else if (!mPipelineMainLocalVar.isUniform())
830 {
831 ASSERT(!mPipelineMainLocalVar.external);
832 ASSERT(mPipelineMainLocalVar.internal);
833
834 auto *newBody = new TIntermBlock();
835 newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.internal});
836 callModifiedToOriginal(*newBody);
837 newBody->appendStatement(body);
838 callOriginalToModified(*newBody);
839 body = newBody;
840 }
841
842 return *new TIntermFunctionDefinition(newFuncProtoNode, body);
843 }
844 };
845
846 ////////////////////////////////////////////////////////////////////////////////
847
UpdatePipelineSymbols(Pipeline::Type pipelineType,TCompiler & compiler,TIntermBlock & root,SymbolEnv & symbolEnv,const VariableSet & pipelineVariables,PipelineScoped<TVariable> pipelineMainLocalVar)848 bool UpdatePipelineSymbols(Pipeline::Type pipelineType,
849 TCompiler &compiler,
850 TIntermBlock &root,
851 SymbolEnv &symbolEnv,
852 const VariableSet &pipelineVariables,
853 PipelineScoped<TVariable> pipelineMainLocalVar)
854 {
855 auto map = [&](const TFunction *owner, TIntermSymbol &symbol) -> TIntermNode & {
856 const TVariable &var = symbol.variable();
857 if (pipelineVariables.find(&var) == pipelineVariables.end())
858 {
859 return symbol;
860 }
861 ASSERT(owner);
862 const TVariable *structInstanceVar;
863 if (owner->isMain())
864 {
865 ASSERT(pipelineMainLocalVar.internal);
866 structInstanceVar = pipelineMainLocalVar.internal;
867 }
868 else
869 {
870 ASSERT(owner->getParamCount() > 0);
871 structInstanceVar = owner->getParam(0);
872 }
873 ASSERT(structInstanceVar);
874 return AccessField(*structInstanceVar, var.name());
875 };
876 return MapSymbols(compiler, root, map);
877 }
878
879 ////////////////////////////////////////////////////////////////////////////////
880
RewritePipeline(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,Invariants & invariants,PipelineScoped<TStructure> & outStruct)881 bool RewritePipeline(TCompiler &compiler,
882 TIntermBlock &root,
883 IdGen &idGen,
884 const Pipeline &pipeline,
885 SymbolEnv &symbolEnv,
886 Invariants &invariants,
887 PipelineScoped<TStructure> &outStruct)
888 {
889 ASSERT(outStruct.isTotallyEmpty());
890
891 TSymbolTable &symbolTable = compiler.getSymbolTable();
892
893 PipelineStructInfo psi;
894 if (!GeneratePipelineStruct::Exec(psi, compiler, root, idGen, pipeline, symbolEnv, invariants))
895 {
896 return false;
897 }
898
899 if (psi.isEmpty())
900 {
901 return true;
902 }
903
904 const auto pipelineFunctions = DiscoverDependentFunctions(root, [&](const TVariable &var) {
905 return psi.pipelineVariables.find(&var) != psi.pipelineVariables.end();
906 });
907
908 auto pipelineMainLocalVar =
909 CreatePipelineMainLocalVar(symbolTable, pipeline, psi.pipelineStruct);
910
911 if (!UpdatePipelineFunctions::ThreadPipeline(
912 compiler, root, pipeline, pipelineFunctions, psi.pipelineStruct, pipelineMainLocalVar,
913 idGen, symbolEnv, psi.funcOriginalToModified, psi.funcModifiedToOriginal))
914 {
915 return false;
916 }
917
918 if (!pipeline.globalInstanceVar)
919 {
920 if (!UpdatePipelineSymbols(pipeline.type, compiler, root, symbolEnv, psi.pipelineVariables,
921 pipelineMainLocalVar))
922 {
923 return false;
924 }
925 }
926
927 if (!PruneNoOps(&compiler, &root, &compiler.getSymbolTable()))
928 {
929 return false;
930 }
931
932 outStruct = psi.pipelineStruct;
933 return true;
934 }
935
936 } // anonymous namespace
937
RewritePipelines(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,DriverUniform & angleUniformsGlobalInstanceVar,SymbolEnv & symbolEnv,Invariants & invariants,PipelineStructs & outStructs)938 bool sh::RewritePipelines(TCompiler &compiler,
939 TIntermBlock &root,
940 IdGen &idGen,
941 DriverUniform &angleUniformsGlobalInstanceVar,
942 SymbolEnv &symbolEnv,
943 Invariants &invariants,
944 PipelineStructs &outStructs)
945 {
946 struct Info
947 {
948 Pipeline::Type pipelineType;
949 PipelineScoped<TStructure> &outStruct;
950 const TVariable *globalInstanceVar;
951 };
952
953 Info infos[] = {
954 {Pipeline::Type::InstanceId, outStructs.instanceId, nullptr},
955 {Pipeline::Type::Texture, outStructs.texture, nullptr},
956 {Pipeline::Type::NonConstantGlobals, outStructs.nonConstantGlobals, nullptr},
957 {Pipeline::Type::AngleUniforms, outStructs.angleUniforms,
958 angleUniformsGlobalInstanceVar.getDriverUniformsVariable()},
959 {Pipeline::Type::UserUniforms, outStructs.userUniforms, nullptr},
960 {Pipeline::Type::VertexIn, outStructs.vertexIn, nullptr},
961 {Pipeline::Type::VertexOut, outStructs.vertexOut, nullptr},
962 {Pipeline::Type::FragmentIn, outStructs.fragmentIn, nullptr},
963 {Pipeline::Type::FragmentOut, outStructs.fragmentOut, nullptr},
964 {Pipeline::Type::InvocationVertexGlobals, outStructs.invocationVertexGlobals, nullptr},
965 {Pipeline::Type::InvocationFragmentGlobals, outStructs.invocationFragmentGlobals, nullptr},
966 {Pipeline::Type::UniformBuffer, outStructs.uniformBuffers, nullptr},
967 };
968
969 for (Info &info : infos)
970 {
971 Pipeline pipeline{info.pipelineType, info.globalInstanceVar};
972 if (!RewritePipeline(compiler, root, idGen, pipeline, symbolEnv, invariants,
973 info.outStruct))
974 {
975 return false;
976 }
977 }
978
979 return true;
980 }
981