• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2018 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RewriteStructSamplers: Extract structs from samplers.
7 //
8 
9 #include "compiler/translator/tree_ops/RewriteStructSamplers.h"
10 
11 #include "compiler/translator/ImmutableStringBuilder.h"
12 #include "compiler/translator/SymbolTable.h"
13 #include "compiler/translator/tree_util/IntermTraverse.h"
14 
15 namespace sh
16 {
17 namespace
18 {
19 // Helper method to get the sampler extracted struct type of a parameter.
GetStructSamplerParameterType(TSymbolTable * symbolTable,const TVariable & param)20 TType *GetStructSamplerParameterType(TSymbolTable *symbolTable, const TVariable &param)
21 {
22     const TStructure *structure = param.getType().getStruct();
23     const TSymbol *structSymbol = symbolTable->findUserDefined(structure->name());
24     ASSERT(structSymbol && structSymbol->isStruct());
25     const TStructure *structVar = static_cast<const TStructure *>(structSymbol);
26     TType *structType           = new TType(structVar, false);
27 
28     if (param.getType().isArray())
29     {
30         structType->makeArrays(param.getType().getArraySizes());
31     }
32 
33     ASSERT(!structType->isStructureContainingSamplers());
34 
35     return structType;
36 }
37 
ReplaceTypeOfSymbolNode(TIntermSymbol * symbolNode,TSymbolTable * symbolTable)38 TIntermSymbol *ReplaceTypeOfSymbolNode(TIntermSymbol *symbolNode, TSymbolTable *symbolTable)
39 {
40     const TVariable &oldVariable = symbolNode->variable();
41 
42     TType *newType = GetStructSamplerParameterType(symbolTable, oldVariable);
43 
44     TVariable *newVariable =
45         new TVariable(oldVariable.uniqueId(), oldVariable.name(), oldVariable.symbolType(),
46                       oldVariable.extension(), newType);
47     return new TIntermSymbol(newVariable);
48 }
49 
ReplaceTypeOfTypedStructNode(TIntermTyped * argument,TSymbolTable * symbolTable)50 TIntermTyped *ReplaceTypeOfTypedStructNode(TIntermTyped *argument, TSymbolTable *symbolTable)
51 {
52     TIntermSymbol *asSymbol = argument->getAsSymbolNode();
53     if (asSymbol)
54     {
55         ASSERT(asSymbol->getType().getStruct());
56         return ReplaceTypeOfSymbolNode(asSymbol, symbolTable);
57     }
58 
59     TIntermTyped *replacement = argument->deepCopy();
60     TIntermBinary *binary     = replacement->getAsBinaryNode();
61     ASSERT(binary);
62 
63     while (binary)
64     {
65         ASSERT(binary->getOp() == EOpIndexDirectStruct || binary->getOp() == EOpIndexDirect);
66 
67         asSymbol = binary->getLeft()->getAsSymbolNode();
68 
69         if (asSymbol)
70         {
71             ASSERT(asSymbol->getType().getStruct());
72             TIntermSymbol *newSymbol = ReplaceTypeOfSymbolNode(asSymbol, symbolTable);
73             binary->replaceChildNode(binary->getLeft(), newSymbol);
74             return replacement;
75         }
76 
77         binary = binary->getLeft()->getAsBinaryNode();
78     }
79 
80     UNREACHABLE();
81     return nullptr;
82 }
83 
84 // Maximum string size of a hex unsigned int.
85 constexpr size_t kHexSize = ImmutableStringBuilder::GetHexCharCount<unsigned int>();
86 
87 class Traverser final : public TIntermTraverser
88 {
89   public:
Traverser(TSymbolTable * symbolTable)90     explicit Traverser(TSymbolTable *symbolTable)
91         : TIntermTraverser(true, false, true, symbolTable), mRemovedUniformsCount(0)
92     {
93         mSymbolTable->push();
94     }
95 
~Traverser()96     ~Traverser() override { mSymbolTable->pop(); }
97 
removedUniformsCount() const98     int removedUniformsCount() const { return mRemovedUniformsCount; }
99 
100     // Each struct sampler declaration is stripped of its samplers. New uniforms are added for each
101     // stripped struct sampler.
visitDeclaration(Visit visit,TIntermDeclaration * decl)102     bool visitDeclaration(Visit visit, TIntermDeclaration *decl) override
103     {
104         if (visit != PreVisit)
105             return true;
106 
107         if (!mInGlobalScope)
108         {
109             return true;
110         }
111 
112         const TIntermSequence &sequence = *(decl->getSequence());
113         TIntermTyped *declarator        = sequence.front()->getAsTyped();
114         const TType &type               = declarator->getType();
115 
116         if (type.isStructureContainingSamplers())
117         {
118             TIntermSequence *newSequence = new TIntermSequence;
119 
120             if (type.isStructSpecifier())
121             {
122                 stripStructSpecifierSamplers(type.getStruct(), newSequence);
123             }
124             else
125             {
126                 TIntermSymbol *asSymbol = declarator->getAsSymbolNode();
127                 ASSERT(asSymbol);
128                 const TVariable &variable = asSymbol->variable();
129                 ASSERT(variable.symbolType() != SymbolType::Empty);
130                 extractStructSamplerUniforms(decl, variable, type.getStruct(), newSequence);
131             }
132 
133             mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), decl, *newSequence);
134         }
135 
136         return true;
137     }
138 
139     // Each struct sampler reference is replaced with a reference to the new extracted sampler.
visitBinary(Visit visit,TIntermBinary * node)140     bool visitBinary(Visit visit, TIntermBinary *node) override
141     {
142         if (visit != PreVisit)
143             return true;
144 
145         if (node->getOp() == EOpIndexDirectStruct && node->getType().isSampler())
146         {
147             ImmutableString newName = GetStructSamplerNameFromTypedNode(node);
148             const TVariable *samplerReplacement =
149                 static_cast<const TVariable *>(mSymbolTable->findUserDefined(newName));
150             ASSERT(samplerReplacement);
151 
152             TIntermSymbol *replacement = new TIntermSymbol(samplerReplacement);
153 
154             queueReplacement(replacement, OriginalNode::IS_DROPPED);
155             return true;
156         }
157 
158         return true;
159     }
160 
161     // In we are passing references to structs containing samplers we must new additional
162     // arguments. For each extracted struct sampler a new argument is added. This chains to nested
163     // structs.
visitFunctionPrototype(TIntermFunctionPrototype * node)164     void visitFunctionPrototype(TIntermFunctionPrototype *node) override
165     {
166         const TFunction *function = node->getFunction();
167 
168         if (!function->hasSamplerInStructOrArrayOfArrayParams())
169         {
170             return;
171         }
172 
173         const TSymbol *foundFunction = mSymbolTable->findUserDefined(function->name());
174         if (foundFunction)
175         {
176             ASSERT(foundFunction->isFunction());
177             function = static_cast<const TFunction *>(foundFunction);
178         }
179         else
180         {
181             TFunction *newFunction = createStructSamplerFunction(function);
182             mSymbolTable->declareUserDefinedFunction(newFunction, true);
183             function = newFunction;
184         }
185 
186         ASSERT(!function->hasSamplerInStructOrArrayOfArrayParams());
187         TIntermFunctionPrototype *newProto = new TIntermFunctionPrototype(function);
188         queueReplacement(newProto, OriginalNode::IS_DROPPED);
189     }
190 
191     // We insert a new scope for each function definition so we can track the new parameters.
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)192     bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
193     {
194         if (visit == PreVisit)
195         {
196             mSymbolTable->push();
197         }
198         else
199         {
200             ASSERT(visit == PostVisit);
201             mSymbolTable->pop();
202         }
203         return true;
204     }
205 
206     // For function call nodes we pass references to the extracted struct samplers in that scope.
visitAggregate(Visit visit,TIntermAggregate * node)207     bool visitAggregate(Visit visit, TIntermAggregate *node) override
208     {
209         if (visit != PreVisit)
210             return true;
211 
212         if (!node->isFunctionCall())
213             return true;
214 
215         const TFunction *function = node->getFunction();
216         if (!function->hasSamplerInStructOrArrayOfArrayParams())
217             return true;
218 
219         ASSERT(node->getOp() == EOpCallFunctionInAST);
220         TFunction *newFunction        = mSymbolTable->findUserDefinedFunction(function->name());
221         TIntermSequence *newArguments = getStructSamplerArguments(function, node->getSequence());
222 
223         TIntermAggregate *newCall =
224             TIntermAggregate::CreateFunctionCall(*newFunction, newArguments);
225         queueReplacement(newCall, OriginalNode::IS_DROPPED);
226         return true;
227     }
228 
229   private:
230     // This returns the name of a struct sampler reference. References are always TIntermBinary.
GetStructSamplerNameFromTypedNode(TIntermTyped * node)231     static ImmutableString GetStructSamplerNameFromTypedNode(TIntermTyped *node)
232     {
233         std::string stringBuilder;
234 
235         TIntermTyped *currentNode = node;
236         while (currentNode->getAsBinaryNode())
237         {
238             TIntermBinary *asBinary = currentNode->getAsBinaryNode();
239 
240             switch (asBinary->getOp())
241             {
242                 case EOpIndexDirect:
243                 {
244                     const int index = asBinary->getRight()->getAsConstantUnion()->getIConst(0);
245                     const std::string strInt = Str(index);
246                     stringBuilder.insert(0, strInt);
247                     stringBuilder.insert(0, "_");
248                     break;
249                 }
250                 case EOpIndexDirectStruct:
251                 {
252                     stringBuilder.insert(0, asBinary->getIndexStructFieldName().data());
253                     stringBuilder.insert(0, "_");
254                     break;
255                 }
256 
257                 default:
258                     UNREACHABLE();
259                     break;
260             }
261 
262             currentNode = asBinary->getLeft();
263         }
264 
265         const ImmutableString &variableName = currentNode->getAsSymbolNode()->variable().name();
266         stringBuilder.insert(0, variableName.data());
267 
268         return stringBuilder;
269     }
270 
271     // Removes all the struct samplers from a struct specifier.
stripStructSpecifierSamplers(const TStructure * structure,TIntermSequence * newSequence)272     void stripStructSpecifierSamplers(const TStructure *structure, TIntermSequence *newSequence)
273     {
274         TFieldList *newFieldList = new TFieldList;
275         ASSERT(structure->containsSamplers());
276 
277         for (const TField *field : structure->fields())
278         {
279             const TType &fieldType = *field->type();
280             if (!fieldType.isSampler() && !isRemovedStructType(fieldType))
281             {
282                 TType *newType = nullptr;
283 
284                 if (fieldType.isStructureContainingSamplers())
285                 {
286                     const TSymbol *structSymbol =
287                         mSymbolTable->findUserDefined(fieldType.getStruct()->name());
288                     ASSERT(structSymbol && structSymbol->isStruct());
289                     const TStructure *fieldStruct = static_cast<const TStructure *>(structSymbol);
290                     newType                       = new TType(fieldStruct, true);
291                     if (fieldType.isArray())
292                     {
293                         newType->makeArrays(fieldType.getArraySizes());
294                     }
295                 }
296                 else
297                 {
298                     newType = new TType(fieldType);
299                 }
300 
301                 TField *newField =
302                     new TField(newType, field->name(), field->line(), field->symbolType());
303                 newFieldList->push_back(newField);
304             }
305         }
306 
307         // Prune empty structs.
308         if (newFieldList->empty())
309         {
310             mRemovedStructs.insert(structure->name());
311             return;
312         }
313 
314         TStructure *newStruct =
315             new TStructure(mSymbolTable, structure->name(), newFieldList, structure->symbolType());
316         TType *newStructType = new TType(newStruct, true);
317         TVariable *newStructVar =
318             new TVariable(mSymbolTable, kEmptyImmutableString, newStructType, SymbolType::Empty);
319         TIntermSymbol *newStructRef = new TIntermSymbol(newStructVar);
320 
321         TIntermDeclaration *structDecl = new TIntermDeclaration;
322         structDecl->appendDeclarator(newStructRef);
323 
324         newSequence->push_back(structDecl);
325 
326         mSymbolTable->declare(newStruct);
327     }
328 
329     // Returns true if the type is a struct that was removed because we extracted all the members.
isRemovedStructType(const TType & type) const330     bool isRemovedStructType(const TType &type) const
331     {
332         const TStructure *structure = type.getStruct();
333         return (structure && (mRemovedStructs.count(structure->name()) > 0));
334     }
335 
336     // Removes samplers from struct uniforms. For each sampler removed also adds a new globally
337     // defined sampler uniform.
extractStructSamplerUniforms(TIntermDeclaration * oldDeclaration,const TVariable & variable,const TStructure * structure,TIntermSequence * newSequence)338     void extractStructSamplerUniforms(TIntermDeclaration *oldDeclaration,
339                                       const TVariable &variable,
340                                       const TStructure *structure,
341                                       TIntermSequence *newSequence)
342     {
343         ASSERT(structure->containsSamplers());
344 
345         size_t nonSamplerCount = 0;
346 
347         for (const TField *field : structure->fields())
348         {
349             nonSamplerCount +=
350                 extractFieldSamplers(variable.name(), field, variable.getType(), newSequence);
351         }
352 
353         if (nonSamplerCount > 0)
354         {
355             // Keep the old declaration around if it has other members.
356             newSequence->push_back(oldDeclaration);
357         }
358         else
359         {
360             mRemovedUniformsCount++;
361         }
362     }
363 
364     // Extracts samplers from a field of a struct. Works with nested structs and arrays.
extractFieldSamplers(const ImmutableString & prefix,const TField * field,const TType & containingType,TIntermSequence * newSequence)365     size_t extractFieldSamplers(const ImmutableString &prefix,
366                                 const TField *field,
367                                 const TType &containingType,
368                                 TIntermSequence *newSequence)
369     {
370         if (containingType.isArray())
371         {
372             size_t nonSamplerCount = 0;
373 
374             // Name the samplers internally as varName_<index>_fieldName
375             const TSpan<const unsigned int> &arraySizes = containingType.getArraySizes();
376             for (unsigned int arrayElement = 0; arrayElement < arraySizes[0]; ++arrayElement)
377             {
378                 ImmutableStringBuilder stringBuilder(prefix.length() + kHexSize + 1);
379                 stringBuilder << prefix << "_";
380                 stringBuilder.appendHex(arrayElement);
381                 nonSamplerCount = extractFieldSamplersImpl(stringBuilder, field, newSequence);
382             }
383 
384             return nonSamplerCount;
385         }
386 
387         return extractFieldSamplersImpl(prefix, field, newSequence);
388     }
389 
390     // Extracts samplers from a field of a struct. Works with nested structs and arrays.
extractFieldSamplersImpl(const ImmutableString & prefix,const TField * field,TIntermSequence * newSequence)391     size_t extractFieldSamplersImpl(const ImmutableString &prefix,
392                                     const TField *field,
393                                     TIntermSequence *newSequence)
394     {
395         size_t nonSamplerCount = 0;
396 
397         const TType &fieldType = *field->type();
398         if (fieldType.isSampler() || fieldType.isStructureContainingSamplers())
399         {
400             ImmutableStringBuilder stringBuilder(prefix.length() + field->name().length() + 1);
401             stringBuilder << prefix << "_" << field->name();
402             ImmutableString newPrefix(stringBuilder);
403 
404             if (fieldType.isSampler())
405             {
406                 extractSampler(newPrefix, fieldType, newSequence);
407             }
408             else
409             {
410                 const TStructure *structure = fieldType.getStruct();
411                 for (const TField *nestedField : structure->fields())
412                 {
413                     nonSamplerCount +=
414                         extractFieldSamplers(newPrefix, nestedField, fieldType, newSequence);
415                 }
416             }
417         }
418         else
419         {
420             nonSamplerCount++;
421         }
422 
423         return nonSamplerCount;
424     }
425 
426     // Extracts a sampler from a struct. Declares the new extracted sampler.
extractSampler(const ImmutableString & newName,const TType & fieldType,TIntermSequence * newSequence) const427     void extractSampler(const ImmutableString &newName,
428                         const TType &fieldType,
429                         TIntermSequence *newSequence) const
430     {
431         TType *newType = new TType(fieldType);
432         newType->setQualifier(EvqUniform);
433         TVariable *newVariable =
434             new TVariable(mSymbolTable, newName, newType, SymbolType::AngleInternal);
435         TIntermSymbol *newRef = new TIntermSymbol(newVariable);
436 
437         TIntermDeclaration *samplerDecl = new TIntermDeclaration;
438         samplerDecl->appendDeclarator(newRef);
439 
440         newSequence->push_back(samplerDecl);
441 
442         mSymbolTable->declareInternal(newVariable);
443     }
444 
445     // Returns the chained name of a sampler uniform field.
GetFieldName(const ImmutableString & paramName,const TField * field,unsigned arrayIndex)446     static ImmutableString GetFieldName(const ImmutableString &paramName,
447                                         const TField *field,
448                                         unsigned arrayIndex)
449     {
450         ImmutableStringBuilder nameBuilder(paramName.length() + kHexSize + 2 +
451                                            field->name().length());
452         nameBuilder << paramName << "_";
453 
454         if (arrayIndex < std::numeric_limits<unsigned>::max())
455         {
456             nameBuilder.appendHex(arrayIndex);
457             nameBuilder << "_";
458         }
459         nameBuilder << field->name();
460 
461         return nameBuilder;
462     }
463 
464     // A pattern that visits every parameter of a function call. Uses different handlers for struct
465     // parameters, struct sampler parameters, and non-struct parameters.
466     class StructSamplerFunctionVisitor : angle::NonCopyable
467     {
468       public:
469         StructSamplerFunctionVisitor()          = default;
470         virtual ~StructSamplerFunctionVisitor() = default;
471 
traverse(const TFunction * function)472         virtual void traverse(const TFunction *function)
473         {
474             size_t paramCount = function->getParamCount();
475 
476             for (size_t paramIndex = 0; paramIndex < paramCount; ++paramIndex)
477             {
478                 const TVariable *param = function->getParam(paramIndex);
479                 const TType &paramType = param->getType();
480 
481                 if (paramType.isStructureContainingSamplers())
482                 {
483                     const ImmutableString &baseName = getNameFromIndex(function, paramIndex);
484                     if (traverseStructContainingSamplers(baseName, paramType))
485                     {
486                         visitStructParam(function, paramIndex);
487                     }
488                 }
489                 else
490                 {
491                     visitNonStructParam(function, paramIndex);
492                 }
493             }
494         }
495 
496         virtual ImmutableString getNameFromIndex(const TFunction *function, size_t paramIndex) = 0;
497         virtual void visitSamplerInStructParam(const ImmutableString &name,
498                                                const TField *field)                            = 0;
499         virtual void visitStructParam(const TFunction *function, size_t paramIndex)            = 0;
500         virtual void visitNonStructParam(const TFunction *function, size_t paramIndex)         = 0;
501 
502       private:
traverseStructContainingSamplers(const ImmutableString & baseName,const TType & structType)503         bool traverseStructContainingSamplers(const ImmutableString &baseName,
504                                               const TType &structType)
505         {
506             bool hasNonSamplerFields    = false;
507             const TStructure *structure = structType.getStruct();
508             for (const TField *field : structure->fields())
509             {
510                 if (field->type()->isStructureContainingSamplers() || field->type()->isSampler())
511                 {
512                     if (traverseSamplerInStruct(baseName, structType, field))
513                     {
514                         hasNonSamplerFields = true;
515                     }
516                 }
517                 else
518                 {
519                     hasNonSamplerFields = true;
520                 }
521             }
522             return hasNonSamplerFields;
523         }
524 
traverseSamplerInStruct(const ImmutableString & baseName,const TType & baseType,const TField * field)525         bool traverseSamplerInStruct(const ImmutableString &baseName,
526                                      const TType &baseType,
527                                      const TField *field)
528         {
529             bool hasNonSamplerParams = false;
530 
531             if (baseType.isArray())
532             {
533                 const TSpan<const unsigned int> &arraySizes = baseType.getArraySizes();
534                 ASSERT(arraySizes.size() == 1);
535 
536                 for (unsigned int arrayIndex = 0; arrayIndex < arraySizes[0]; ++arrayIndex)
537                 {
538                     ImmutableString name = GetFieldName(baseName, field, arrayIndex);
539 
540                     if (field->type()->isStructureContainingSamplers())
541                     {
542                         if (traverseStructContainingSamplers(name, *field->type()))
543                         {
544                             hasNonSamplerParams = true;
545                         }
546                     }
547                     else
548                     {
549                         ASSERT(field->type()->isSampler());
550                         visitSamplerInStructParam(name, field);
551                     }
552                 }
553             }
554             else if (field->type()->isStructureContainingSamplers())
555             {
556                 ImmutableString name =
557                     GetFieldName(baseName, field, std::numeric_limits<unsigned>::max());
558                 hasNonSamplerParams = traverseStructContainingSamplers(name, *field->type());
559             }
560             else
561             {
562                 ASSERT(field->type()->isSampler());
563                 ImmutableString name =
564                     GetFieldName(baseName, field, std::numeric_limits<unsigned>::max());
565                 visitSamplerInStructParam(name, field);
566             }
567 
568             return hasNonSamplerParams;
569         }
570     };
571 
572     // A visitor that replaces functions with struct sampler references. The struct sampler
573     // references are expanded to include new fields for the structs.
574     class CreateStructSamplerFunctionVisitor final : public StructSamplerFunctionVisitor
575     {
576       public:
CreateStructSamplerFunctionVisitor(TSymbolTable * symbolTable)577         CreateStructSamplerFunctionVisitor(TSymbolTable *symbolTable)
578             : mSymbolTable(symbolTable), mNewFunction(nullptr)
579         {}
580 
getNameFromIndex(const TFunction * function,size_t paramIndex)581         ImmutableString getNameFromIndex(const TFunction *function, size_t paramIndex) override
582         {
583             const TVariable *param = function->getParam(paramIndex);
584             return param->name();
585         }
586 
traverse(const TFunction * function)587         void traverse(const TFunction *function) override
588         {
589             mNewFunction =
590                 new TFunction(mSymbolTable, function->name(), function->symbolType(),
591                               &function->getReturnType(), function->isKnownToNotHaveSideEffects());
592 
593             StructSamplerFunctionVisitor::traverse(function);
594         }
595 
visitSamplerInStructParam(const ImmutableString & name,const TField * field)596         void visitSamplerInStructParam(const ImmutableString &name, const TField *field) override
597         {
598             TVariable *fieldSampler =
599                 new TVariable(mSymbolTable, name, field->type(), SymbolType::AngleInternal);
600             mNewFunction->addParameter(fieldSampler);
601             mSymbolTable->declareInternal(fieldSampler);
602         }
603 
visitStructParam(const TFunction * function,size_t paramIndex)604         void visitStructParam(const TFunction *function, size_t paramIndex) override
605         {
606             const TVariable *param = function->getParam(paramIndex);
607             TType *structType      = GetStructSamplerParameterType(mSymbolTable, *param);
608             TVariable *newParam =
609                 new TVariable(mSymbolTable, param->name(), structType, param->symbolType());
610             mNewFunction->addParameter(newParam);
611         }
612 
visitNonStructParam(const TFunction * function,size_t paramIndex)613         void visitNonStructParam(const TFunction *function, size_t paramIndex) override
614         {
615             const TVariable *param = function->getParam(paramIndex);
616             mNewFunction->addParameter(param);
617         }
618 
getNewFunction() const619         TFunction *getNewFunction() const { return mNewFunction; }
620 
621       private:
622         TSymbolTable *mSymbolTable;
623         TFunction *mNewFunction;
624     };
625 
createStructSamplerFunction(const TFunction * function) const626     TFunction *createStructSamplerFunction(const TFunction *function) const
627     {
628         CreateStructSamplerFunctionVisitor visitor(mSymbolTable);
629         visitor.traverse(function);
630         return visitor.getNewFunction();
631     }
632 
633     // A visitor that replaces function calls with expanded struct sampler parameters.
634     class GetSamplerArgumentsVisitor final : public StructSamplerFunctionVisitor
635     {
636       public:
GetSamplerArgumentsVisitor(TSymbolTable * symbolTable,const TIntermSequence * arguments)637         GetSamplerArgumentsVisitor(TSymbolTable *symbolTable, const TIntermSequence *arguments)
638             : mSymbolTable(symbolTable), mArguments(arguments), mNewArguments(new TIntermSequence)
639         {}
640 
getNameFromIndex(const TFunction * function,size_t paramIndex)641         ImmutableString getNameFromIndex(const TFunction *function, size_t paramIndex) override
642         {
643             TIntermTyped *argument = (*mArguments)[paramIndex]->getAsTyped();
644             return GetStructSamplerNameFromTypedNode(argument);
645         }
646 
visitSamplerInStructParam(const ImmutableString & name,const TField * field)647         void visitSamplerInStructParam(const ImmutableString &name, const TField *field) override
648         {
649             TVariable *argSampler =
650                 new TVariable(mSymbolTable, name, field->type(), SymbolType::AngleInternal);
651             TIntermSymbol *argSymbol = new TIntermSymbol(argSampler);
652             mNewArguments->push_back(argSymbol);
653         }
654 
visitStructParam(const TFunction * function,size_t paramIndex)655         void visitStructParam(const TFunction *function, size_t paramIndex) override
656         {
657             // The tree structure of the parameter is modified to point to the new type. This leaves
658             // the tree in a consistent state.
659             TIntermTyped *argument    = (*mArguments)[paramIndex]->getAsTyped();
660             TIntermTyped *replacement = ReplaceTypeOfTypedStructNode(argument, mSymbolTable);
661             mNewArguments->push_back(replacement);
662         }
663 
visitNonStructParam(const TFunction * function,size_t paramIndex)664         void visitNonStructParam(const TFunction *function, size_t paramIndex) override
665         {
666             TIntermTyped *argument = (*mArguments)[paramIndex]->getAsTyped();
667             mNewArguments->push_back(argument);
668         }
669 
getNewArguments() const670         TIntermSequence *getNewArguments() const { return mNewArguments; }
671 
672       private:
673         TSymbolTable *mSymbolTable;
674         const TIntermSequence *mArguments;
675         TIntermSequence *mNewArguments;
676     };
677 
getStructSamplerArguments(const TFunction * function,const TIntermSequence * arguments) const678     TIntermSequence *getStructSamplerArguments(const TFunction *function,
679                                                const TIntermSequence *arguments) const
680     {
681         GetSamplerArgumentsVisitor visitor(mSymbolTable, arguments);
682         visitor.traverse(function);
683         return visitor.getNewArguments();
684     }
685 
686     int mRemovedUniformsCount;
687     std::set<ImmutableString> mRemovedStructs;
688 };
689 }  // anonymous namespace
690 
RewriteStructSamplersOld(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,int * removedUniformsCountOut)691 bool RewriteStructSamplersOld(TCompiler *compiler,
692                               TIntermBlock *root,
693                               TSymbolTable *symbolTable,
694                               int *removedUniformsCountOut)
695 {
696     Traverser rewriteStructSamplers(symbolTable);
697     root->traverse(&rewriteStructSamplers);
698     if (!rewriteStructSamplers.updateTree(compiler, root))
699     {
700         return false;
701     }
702     *removedUniformsCountOut = rewriteStructSamplers.removedUniformsCount();
703     return true;
704 }
705 }  // namespace sh
706