1 //
2 // Copyright 2018 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RewriteStructSamplers: Extract structs from samplers.
7 //
8
9 #include "compiler/translator/tree_ops/RewriteStructSamplers.h"
10
11 #include "compiler/translator/ImmutableStringBuilder.h"
12 #include "compiler/translator/SymbolTable.h"
13 #include "compiler/translator/tree_util/IntermTraverse.h"
14
15 namespace sh
16 {
17 namespace
18 {
19 // Helper method to get the sampler extracted struct type of a parameter.
GetStructSamplerParameterType(TSymbolTable * symbolTable,const TVariable & param)20 TType *GetStructSamplerParameterType(TSymbolTable *symbolTable, const TVariable ¶m)
21 {
22 const TStructure *structure = param.getType().getStruct();
23 const TSymbol *structSymbol = symbolTable->findUserDefined(structure->name());
24 ASSERT(structSymbol && structSymbol->isStruct());
25 const TStructure *structVar = static_cast<const TStructure *>(structSymbol);
26 TType *structType = new TType(structVar, false);
27
28 if (param.getType().isArray())
29 {
30 structType->makeArrays(param.getType().getArraySizes());
31 }
32
33 ASSERT(!structType->isStructureContainingSamplers());
34
35 return structType;
36 }
37
ReplaceTypeOfSymbolNode(TIntermSymbol * symbolNode,TSymbolTable * symbolTable)38 TIntermSymbol *ReplaceTypeOfSymbolNode(TIntermSymbol *symbolNode, TSymbolTable *symbolTable)
39 {
40 const TVariable &oldVariable = symbolNode->variable();
41
42 TType *newType = GetStructSamplerParameterType(symbolTable, oldVariable);
43
44 TVariable *newVariable =
45 new TVariable(oldVariable.uniqueId(), oldVariable.name(), oldVariable.symbolType(),
46 oldVariable.extension(), newType);
47 return new TIntermSymbol(newVariable);
48 }
49
ReplaceTypeOfTypedStructNode(TIntermTyped * argument,TSymbolTable * symbolTable)50 TIntermTyped *ReplaceTypeOfTypedStructNode(TIntermTyped *argument, TSymbolTable *symbolTable)
51 {
52 TIntermSymbol *asSymbol = argument->getAsSymbolNode();
53 if (asSymbol)
54 {
55 ASSERT(asSymbol->getType().getStruct());
56 return ReplaceTypeOfSymbolNode(asSymbol, symbolTable);
57 }
58
59 TIntermTyped *replacement = argument->deepCopy();
60 TIntermBinary *binary = replacement->getAsBinaryNode();
61 ASSERT(binary);
62
63 while (binary)
64 {
65 ASSERT(binary->getOp() == EOpIndexDirectStruct || binary->getOp() == EOpIndexDirect);
66
67 asSymbol = binary->getLeft()->getAsSymbolNode();
68
69 if (asSymbol)
70 {
71 ASSERT(asSymbol->getType().getStruct());
72 TIntermSymbol *newSymbol = ReplaceTypeOfSymbolNode(asSymbol, symbolTable);
73 binary->replaceChildNode(binary->getLeft(), newSymbol);
74 return replacement;
75 }
76
77 binary = binary->getLeft()->getAsBinaryNode();
78 }
79
80 UNREACHABLE();
81 return nullptr;
82 }
83
84 // Maximum string size of a hex unsigned int.
85 constexpr size_t kHexSize = ImmutableStringBuilder::GetHexCharCount<unsigned int>();
86
87 class Traverser final : public TIntermTraverser
88 {
89 public:
Traverser(TSymbolTable * symbolTable)90 explicit Traverser(TSymbolTable *symbolTable)
91 : TIntermTraverser(true, false, true, symbolTable), mRemovedUniformsCount(0)
92 {
93 mSymbolTable->push();
94 }
95
~Traverser()96 ~Traverser() override { mSymbolTable->pop(); }
97
removedUniformsCount() const98 int removedUniformsCount() const { return mRemovedUniformsCount; }
99
100 // Each struct sampler declaration is stripped of its samplers. New uniforms are added for each
101 // stripped struct sampler.
visitDeclaration(Visit visit,TIntermDeclaration * decl)102 bool visitDeclaration(Visit visit, TIntermDeclaration *decl) override
103 {
104 if (visit != PreVisit)
105 return true;
106
107 if (!mInGlobalScope)
108 {
109 return true;
110 }
111
112 const TIntermSequence &sequence = *(decl->getSequence());
113 TIntermTyped *declarator = sequence.front()->getAsTyped();
114 const TType &type = declarator->getType();
115
116 if (type.isStructureContainingSamplers())
117 {
118 TIntermSequence *newSequence = new TIntermSequence;
119
120 if (type.isStructSpecifier())
121 {
122 stripStructSpecifierSamplers(type.getStruct(), newSequence);
123 }
124 else
125 {
126 TIntermSymbol *asSymbol = declarator->getAsSymbolNode();
127 ASSERT(asSymbol);
128 const TVariable &variable = asSymbol->variable();
129 ASSERT(variable.symbolType() != SymbolType::Empty);
130 extractStructSamplerUniforms(decl, variable, type.getStruct(), newSequence);
131 }
132
133 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), decl, *newSequence);
134 }
135
136 return true;
137 }
138
139 // Each struct sampler reference is replaced with a reference to the new extracted sampler.
visitBinary(Visit visit,TIntermBinary * node)140 bool visitBinary(Visit visit, TIntermBinary *node) override
141 {
142 if (visit != PreVisit)
143 return true;
144
145 if (node->getOp() == EOpIndexDirectStruct && node->getType().isSampler())
146 {
147 ImmutableString newName = GetStructSamplerNameFromTypedNode(node);
148 const TVariable *samplerReplacement =
149 static_cast<const TVariable *>(mSymbolTable->findUserDefined(newName));
150 ASSERT(samplerReplacement);
151
152 TIntermSymbol *replacement = new TIntermSymbol(samplerReplacement);
153
154 queueReplacement(replacement, OriginalNode::IS_DROPPED);
155 return true;
156 }
157
158 return true;
159 }
160
161 // In we are passing references to structs containing samplers we must new additional
162 // arguments. For each extracted struct sampler a new argument is added. This chains to nested
163 // structs.
visitFunctionPrototype(TIntermFunctionPrototype * node)164 void visitFunctionPrototype(TIntermFunctionPrototype *node) override
165 {
166 const TFunction *function = node->getFunction();
167
168 if (!function->hasSamplerInStructOrArrayOfArrayParams())
169 {
170 return;
171 }
172
173 const TSymbol *foundFunction = mSymbolTable->findUserDefined(function->name());
174 if (foundFunction)
175 {
176 ASSERT(foundFunction->isFunction());
177 function = static_cast<const TFunction *>(foundFunction);
178 }
179 else
180 {
181 TFunction *newFunction = createStructSamplerFunction(function);
182 mSymbolTable->declareUserDefinedFunction(newFunction, true);
183 function = newFunction;
184 }
185
186 ASSERT(!function->hasSamplerInStructOrArrayOfArrayParams());
187 TIntermFunctionPrototype *newProto = new TIntermFunctionPrototype(function);
188 queueReplacement(newProto, OriginalNode::IS_DROPPED);
189 }
190
191 // We insert a new scope for each function definition so we can track the new parameters.
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)192 bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
193 {
194 if (visit == PreVisit)
195 {
196 mSymbolTable->push();
197 }
198 else
199 {
200 ASSERT(visit == PostVisit);
201 mSymbolTable->pop();
202 }
203 return true;
204 }
205
206 // For function call nodes we pass references to the extracted struct samplers in that scope.
visitAggregate(Visit visit,TIntermAggregate * node)207 bool visitAggregate(Visit visit, TIntermAggregate *node) override
208 {
209 if (visit != PreVisit)
210 return true;
211
212 if (!node->isFunctionCall())
213 return true;
214
215 const TFunction *function = node->getFunction();
216 if (!function->hasSamplerInStructOrArrayOfArrayParams())
217 return true;
218
219 ASSERT(node->getOp() == EOpCallFunctionInAST);
220 TFunction *newFunction = mSymbolTable->findUserDefinedFunction(function->name());
221 TIntermSequence *newArguments = getStructSamplerArguments(function, node->getSequence());
222
223 TIntermAggregate *newCall =
224 TIntermAggregate::CreateFunctionCall(*newFunction, newArguments);
225 queueReplacement(newCall, OriginalNode::IS_DROPPED);
226 return true;
227 }
228
229 private:
230 // This returns the name of a struct sampler reference. References are always TIntermBinary.
GetStructSamplerNameFromTypedNode(TIntermTyped * node)231 static ImmutableString GetStructSamplerNameFromTypedNode(TIntermTyped *node)
232 {
233 std::string stringBuilder;
234
235 TIntermTyped *currentNode = node;
236 while (currentNode->getAsBinaryNode())
237 {
238 TIntermBinary *asBinary = currentNode->getAsBinaryNode();
239
240 switch (asBinary->getOp())
241 {
242 case EOpIndexDirect:
243 {
244 const int index = asBinary->getRight()->getAsConstantUnion()->getIConst(0);
245 const std::string strInt = Str(index);
246 stringBuilder.insert(0, strInt);
247 stringBuilder.insert(0, "_");
248 break;
249 }
250 case EOpIndexDirectStruct:
251 {
252 stringBuilder.insert(0, asBinary->getIndexStructFieldName().data());
253 stringBuilder.insert(0, "_");
254 break;
255 }
256
257 default:
258 UNREACHABLE();
259 break;
260 }
261
262 currentNode = asBinary->getLeft();
263 }
264
265 const ImmutableString &variableName = currentNode->getAsSymbolNode()->variable().name();
266 stringBuilder.insert(0, variableName.data());
267
268 return stringBuilder;
269 }
270
271 // Removes all the struct samplers from a struct specifier.
stripStructSpecifierSamplers(const TStructure * structure,TIntermSequence * newSequence)272 void stripStructSpecifierSamplers(const TStructure *structure, TIntermSequence *newSequence)
273 {
274 TFieldList *newFieldList = new TFieldList;
275 ASSERT(structure->containsSamplers());
276
277 for (const TField *field : structure->fields())
278 {
279 const TType &fieldType = *field->type();
280 if (!fieldType.isSampler() && !isRemovedStructType(fieldType))
281 {
282 TType *newType = nullptr;
283
284 if (fieldType.isStructureContainingSamplers())
285 {
286 const TSymbol *structSymbol =
287 mSymbolTable->findUserDefined(fieldType.getStruct()->name());
288 ASSERT(structSymbol && structSymbol->isStruct());
289 const TStructure *fieldStruct = static_cast<const TStructure *>(structSymbol);
290 newType = new TType(fieldStruct, true);
291 if (fieldType.isArray())
292 {
293 newType->makeArrays(fieldType.getArraySizes());
294 }
295 }
296 else
297 {
298 newType = new TType(fieldType);
299 }
300
301 TField *newField =
302 new TField(newType, field->name(), field->line(), field->symbolType());
303 newFieldList->push_back(newField);
304 }
305 }
306
307 // Prune empty structs.
308 if (newFieldList->empty())
309 {
310 mRemovedStructs.insert(structure->name());
311 return;
312 }
313
314 TStructure *newStruct =
315 new TStructure(mSymbolTable, structure->name(), newFieldList, structure->symbolType());
316 TType *newStructType = new TType(newStruct, true);
317 TVariable *newStructVar =
318 new TVariable(mSymbolTable, kEmptyImmutableString, newStructType, SymbolType::Empty);
319 TIntermSymbol *newStructRef = new TIntermSymbol(newStructVar);
320
321 TIntermDeclaration *structDecl = new TIntermDeclaration;
322 structDecl->appendDeclarator(newStructRef);
323
324 newSequence->push_back(structDecl);
325
326 mSymbolTable->declare(newStruct);
327 }
328
329 // Returns true if the type is a struct that was removed because we extracted all the members.
isRemovedStructType(const TType & type) const330 bool isRemovedStructType(const TType &type) const
331 {
332 const TStructure *structure = type.getStruct();
333 return (structure && (mRemovedStructs.count(structure->name()) > 0));
334 }
335
336 // Removes samplers from struct uniforms. For each sampler removed also adds a new globally
337 // defined sampler uniform.
extractStructSamplerUniforms(TIntermDeclaration * oldDeclaration,const TVariable & variable,const TStructure * structure,TIntermSequence * newSequence)338 void extractStructSamplerUniforms(TIntermDeclaration *oldDeclaration,
339 const TVariable &variable,
340 const TStructure *structure,
341 TIntermSequence *newSequence)
342 {
343 ASSERT(structure->containsSamplers());
344
345 size_t nonSamplerCount = 0;
346
347 for (const TField *field : structure->fields())
348 {
349 nonSamplerCount +=
350 extractFieldSamplers(variable.name(), field, variable.getType(), newSequence);
351 }
352
353 if (nonSamplerCount > 0)
354 {
355 // Keep the old declaration around if it has other members.
356 newSequence->push_back(oldDeclaration);
357 }
358 else
359 {
360 mRemovedUniformsCount++;
361 }
362 }
363
364 // Extracts samplers from a field of a struct. Works with nested structs and arrays.
extractFieldSamplers(const ImmutableString & prefix,const TField * field,const TType & containingType,TIntermSequence * newSequence)365 size_t extractFieldSamplers(const ImmutableString &prefix,
366 const TField *field,
367 const TType &containingType,
368 TIntermSequence *newSequence)
369 {
370 if (containingType.isArray())
371 {
372 size_t nonSamplerCount = 0;
373
374 // Name the samplers internally as varName_<index>_fieldName
375 const TSpan<const unsigned int> &arraySizes = containingType.getArraySizes();
376 for (unsigned int arrayElement = 0; arrayElement < arraySizes[0]; ++arrayElement)
377 {
378 ImmutableStringBuilder stringBuilder(prefix.length() + kHexSize + 1);
379 stringBuilder << prefix << "_";
380 stringBuilder.appendHex(arrayElement);
381 nonSamplerCount = extractFieldSamplersImpl(stringBuilder, field, newSequence);
382 }
383
384 return nonSamplerCount;
385 }
386
387 return extractFieldSamplersImpl(prefix, field, newSequence);
388 }
389
390 // Extracts samplers from a field of a struct. Works with nested structs and arrays.
extractFieldSamplersImpl(const ImmutableString & prefix,const TField * field,TIntermSequence * newSequence)391 size_t extractFieldSamplersImpl(const ImmutableString &prefix,
392 const TField *field,
393 TIntermSequence *newSequence)
394 {
395 size_t nonSamplerCount = 0;
396
397 const TType &fieldType = *field->type();
398 if (fieldType.isSampler() || fieldType.isStructureContainingSamplers())
399 {
400 ImmutableStringBuilder stringBuilder(prefix.length() + field->name().length() + 1);
401 stringBuilder << prefix << "_" << field->name();
402 ImmutableString newPrefix(stringBuilder);
403
404 if (fieldType.isSampler())
405 {
406 extractSampler(newPrefix, fieldType, newSequence);
407 }
408 else
409 {
410 const TStructure *structure = fieldType.getStruct();
411 for (const TField *nestedField : structure->fields())
412 {
413 nonSamplerCount +=
414 extractFieldSamplers(newPrefix, nestedField, fieldType, newSequence);
415 }
416 }
417 }
418 else
419 {
420 nonSamplerCount++;
421 }
422
423 return nonSamplerCount;
424 }
425
426 // Extracts a sampler from a struct. Declares the new extracted sampler.
extractSampler(const ImmutableString & newName,const TType & fieldType,TIntermSequence * newSequence) const427 void extractSampler(const ImmutableString &newName,
428 const TType &fieldType,
429 TIntermSequence *newSequence) const
430 {
431 TType *newType = new TType(fieldType);
432 newType->setQualifier(EvqUniform);
433 TVariable *newVariable =
434 new TVariable(mSymbolTable, newName, newType, SymbolType::AngleInternal);
435 TIntermSymbol *newRef = new TIntermSymbol(newVariable);
436
437 TIntermDeclaration *samplerDecl = new TIntermDeclaration;
438 samplerDecl->appendDeclarator(newRef);
439
440 newSequence->push_back(samplerDecl);
441
442 mSymbolTable->declareInternal(newVariable);
443 }
444
445 // Returns the chained name of a sampler uniform field.
GetFieldName(const ImmutableString & paramName,const TField * field,unsigned arrayIndex)446 static ImmutableString GetFieldName(const ImmutableString ¶mName,
447 const TField *field,
448 unsigned arrayIndex)
449 {
450 ImmutableStringBuilder nameBuilder(paramName.length() + kHexSize + 2 +
451 field->name().length());
452 nameBuilder << paramName << "_";
453
454 if (arrayIndex < std::numeric_limits<unsigned>::max())
455 {
456 nameBuilder.appendHex(arrayIndex);
457 nameBuilder << "_";
458 }
459 nameBuilder << field->name();
460
461 return nameBuilder;
462 }
463
464 // A pattern that visits every parameter of a function call. Uses different handlers for struct
465 // parameters, struct sampler parameters, and non-struct parameters.
466 class StructSamplerFunctionVisitor : angle::NonCopyable
467 {
468 public:
469 StructSamplerFunctionVisitor() = default;
470 virtual ~StructSamplerFunctionVisitor() = default;
471
traverse(const TFunction * function)472 virtual void traverse(const TFunction *function)
473 {
474 size_t paramCount = function->getParamCount();
475
476 for (size_t paramIndex = 0; paramIndex < paramCount; ++paramIndex)
477 {
478 const TVariable *param = function->getParam(paramIndex);
479 const TType ¶mType = param->getType();
480
481 if (paramType.isStructureContainingSamplers())
482 {
483 const ImmutableString &baseName = getNameFromIndex(function, paramIndex);
484 if (traverseStructContainingSamplers(baseName, paramType))
485 {
486 visitStructParam(function, paramIndex);
487 }
488 }
489 else
490 {
491 visitNonStructParam(function, paramIndex);
492 }
493 }
494 }
495
496 virtual ImmutableString getNameFromIndex(const TFunction *function, size_t paramIndex) = 0;
497 virtual void visitSamplerInStructParam(const ImmutableString &name,
498 const TField *field) = 0;
499 virtual void visitStructParam(const TFunction *function, size_t paramIndex) = 0;
500 virtual void visitNonStructParam(const TFunction *function, size_t paramIndex) = 0;
501
502 private:
traverseStructContainingSamplers(const ImmutableString & baseName,const TType & structType)503 bool traverseStructContainingSamplers(const ImmutableString &baseName,
504 const TType &structType)
505 {
506 bool hasNonSamplerFields = false;
507 const TStructure *structure = structType.getStruct();
508 for (const TField *field : structure->fields())
509 {
510 if (field->type()->isStructureContainingSamplers() || field->type()->isSampler())
511 {
512 if (traverseSamplerInStruct(baseName, structType, field))
513 {
514 hasNonSamplerFields = true;
515 }
516 }
517 else
518 {
519 hasNonSamplerFields = true;
520 }
521 }
522 return hasNonSamplerFields;
523 }
524
traverseSamplerInStruct(const ImmutableString & baseName,const TType & baseType,const TField * field)525 bool traverseSamplerInStruct(const ImmutableString &baseName,
526 const TType &baseType,
527 const TField *field)
528 {
529 bool hasNonSamplerParams = false;
530
531 if (baseType.isArray())
532 {
533 const TSpan<const unsigned int> &arraySizes = baseType.getArraySizes();
534 ASSERT(arraySizes.size() == 1);
535
536 for (unsigned int arrayIndex = 0; arrayIndex < arraySizes[0]; ++arrayIndex)
537 {
538 ImmutableString name = GetFieldName(baseName, field, arrayIndex);
539
540 if (field->type()->isStructureContainingSamplers())
541 {
542 if (traverseStructContainingSamplers(name, *field->type()))
543 {
544 hasNonSamplerParams = true;
545 }
546 }
547 else
548 {
549 ASSERT(field->type()->isSampler());
550 visitSamplerInStructParam(name, field);
551 }
552 }
553 }
554 else if (field->type()->isStructureContainingSamplers())
555 {
556 ImmutableString name =
557 GetFieldName(baseName, field, std::numeric_limits<unsigned>::max());
558 hasNonSamplerParams = traverseStructContainingSamplers(name, *field->type());
559 }
560 else
561 {
562 ASSERT(field->type()->isSampler());
563 ImmutableString name =
564 GetFieldName(baseName, field, std::numeric_limits<unsigned>::max());
565 visitSamplerInStructParam(name, field);
566 }
567
568 return hasNonSamplerParams;
569 }
570 };
571
572 // A visitor that replaces functions with struct sampler references. The struct sampler
573 // references are expanded to include new fields for the structs.
574 class CreateStructSamplerFunctionVisitor final : public StructSamplerFunctionVisitor
575 {
576 public:
CreateStructSamplerFunctionVisitor(TSymbolTable * symbolTable)577 CreateStructSamplerFunctionVisitor(TSymbolTable *symbolTable)
578 : mSymbolTable(symbolTable), mNewFunction(nullptr)
579 {}
580
getNameFromIndex(const TFunction * function,size_t paramIndex)581 ImmutableString getNameFromIndex(const TFunction *function, size_t paramIndex) override
582 {
583 const TVariable *param = function->getParam(paramIndex);
584 return param->name();
585 }
586
traverse(const TFunction * function)587 void traverse(const TFunction *function) override
588 {
589 mNewFunction =
590 new TFunction(mSymbolTable, function->name(), function->symbolType(),
591 &function->getReturnType(), function->isKnownToNotHaveSideEffects());
592
593 StructSamplerFunctionVisitor::traverse(function);
594 }
595
visitSamplerInStructParam(const ImmutableString & name,const TField * field)596 void visitSamplerInStructParam(const ImmutableString &name, const TField *field) override
597 {
598 TVariable *fieldSampler =
599 new TVariable(mSymbolTable, name, field->type(), SymbolType::AngleInternal);
600 mNewFunction->addParameter(fieldSampler);
601 mSymbolTable->declareInternal(fieldSampler);
602 }
603
visitStructParam(const TFunction * function,size_t paramIndex)604 void visitStructParam(const TFunction *function, size_t paramIndex) override
605 {
606 const TVariable *param = function->getParam(paramIndex);
607 TType *structType = GetStructSamplerParameterType(mSymbolTable, *param);
608 TVariable *newParam =
609 new TVariable(mSymbolTable, param->name(), structType, param->symbolType());
610 mNewFunction->addParameter(newParam);
611 }
612
visitNonStructParam(const TFunction * function,size_t paramIndex)613 void visitNonStructParam(const TFunction *function, size_t paramIndex) override
614 {
615 const TVariable *param = function->getParam(paramIndex);
616 mNewFunction->addParameter(param);
617 }
618
getNewFunction() const619 TFunction *getNewFunction() const { return mNewFunction; }
620
621 private:
622 TSymbolTable *mSymbolTable;
623 TFunction *mNewFunction;
624 };
625
createStructSamplerFunction(const TFunction * function) const626 TFunction *createStructSamplerFunction(const TFunction *function) const
627 {
628 CreateStructSamplerFunctionVisitor visitor(mSymbolTable);
629 visitor.traverse(function);
630 return visitor.getNewFunction();
631 }
632
633 // A visitor that replaces function calls with expanded struct sampler parameters.
634 class GetSamplerArgumentsVisitor final : public StructSamplerFunctionVisitor
635 {
636 public:
GetSamplerArgumentsVisitor(TSymbolTable * symbolTable,const TIntermSequence * arguments)637 GetSamplerArgumentsVisitor(TSymbolTable *symbolTable, const TIntermSequence *arguments)
638 : mSymbolTable(symbolTable), mArguments(arguments), mNewArguments(new TIntermSequence)
639 {}
640
getNameFromIndex(const TFunction * function,size_t paramIndex)641 ImmutableString getNameFromIndex(const TFunction *function, size_t paramIndex) override
642 {
643 TIntermTyped *argument = (*mArguments)[paramIndex]->getAsTyped();
644 return GetStructSamplerNameFromTypedNode(argument);
645 }
646
visitSamplerInStructParam(const ImmutableString & name,const TField * field)647 void visitSamplerInStructParam(const ImmutableString &name, const TField *field) override
648 {
649 TVariable *argSampler =
650 new TVariable(mSymbolTable, name, field->type(), SymbolType::AngleInternal);
651 TIntermSymbol *argSymbol = new TIntermSymbol(argSampler);
652 mNewArguments->push_back(argSymbol);
653 }
654
visitStructParam(const TFunction * function,size_t paramIndex)655 void visitStructParam(const TFunction *function, size_t paramIndex) override
656 {
657 // The tree structure of the parameter is modified to point to the new type. This leaves
658 // the tree in a consistent state.
659 TIntermTyped *argument = (*mArguments)[paramIndex]->getAsTyped();
660 TIntermTyped *replacement = ReplaceTypeOfTypedStructNode(argument, mSymbolTable);
661 mNewArguments->push_back(replacement);
662 }
663
visitNonStructParam(const TFunction * function,size_t paramIndex)664 void visitNonStructParam(const TFunction *function, size_t paramIndex) override
665 {
666 TIntermTyped *argument = (*mArguments)[paramIndex]->getAsTyped();
667 mNewArguments->push_back(argument);
668 }
669
getNewArguments() const670 TIntermSequence *getNewArguments() const { return mNewArguments; }
671
672 private:
673 TSymbolTable *mSymbolTable;
674 const TIntermSequence *mArguments;
675 TIntermSequence *mNewArguments;
676 };
677
getStructSamplerArguments(const TFunction * function,const TIntermSequence * arguments) const678 TIntermSequence *getStructSamplerArguments(const TFunction *function,
679 const TIntermSequence *arguments) const
680 {
681 GetSamplerArgumentsVisitor visitor(mSymbolTable, arguments);
682 visitor.traverse(function);
683 return visitor.getNewArguments();
684 }
685
686 int mRemovedUniformsCount;
687 std::set<ImmutableString> mRemovedStructs;
688 };
689 } // anonymous namespace
690
RewriteStructSamplersOld(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,int * removedUniformsCountOut)691 bool RewriteStructSamplersOld(TCompiler *compiler,
692 TIntermBlock *root,
693 TSymbolTable *symbolTable,
694 int *removedUniformsCountOut)
695 {
696 Traverser rewriteStructSamplers(symbolTable);
697 root->traverse(&rewriteStructSamplers);
698 if (!rewriteStructSamplers.updateTree(compiler, root))
699 {
700 return false;
701 }
702 *removedUniformsCountOut = rewriteStructSamplers.removedUniformsCount();
703 return true;
704 }
705 } // namespace sh
706