1 //
2 // Copyright 2018 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RewriteStructSamplers: Extract structs from samplers.
7 //
8
9 #include "compiler/translator/tree_ops/RewriteStructSamplers.h"
10
11 #include "compiler/translator/ImmutableStringBuilder.h"
12 #include "compiler/translator/StaticType.h"
13 #include "compiler/translator/SymbolTable.h"
14 #include "compiler/translator/tree_util/IntermNode_util.h"
15 #include "compiler/translator/tree_util/IntermTraverse.h"
16
17 namespace sh
18 {
19 namespace
20 {
21 // Helper method to get the sampler extracted struct type of a parameter.
GetStructSamplerParameterType(TSymbolTable * symbolTable,const TVariable & param)22 TType *GetStructSamplerParameterType(TSymbolTable *symbolTable, const TVariable ¶m)
23 {
24 const TStructure *structure = param.getType().getStruct();
25 const TSymbol *structSymbol = symbolTable->findUserDefined(structure->name());
26 ASSERT(structSymbol && structSymbol->isStruct());
27 const TStructure *structVar = static_cast<const TStructure *>(structSymbol);
28 TType *structType = new TType(structVar, false);
29
30 if (param.getType().isArray())
31 {
32 structType->makeArrays(param.getType().getArraySizes());
33 }
34
35 ASSERT(!structType->isStructureContainingSamplers());
36
37 return structType;
38 }
39
ReplaceTypeOfSymbolNode(TIntermSymbol * symbolNode,TSymbolTable * symbolTable)40 TIntermSymbol *ReplaceTypeOfSymbolNode(TIntermSymbol *symbolNode, TSymbolTable *symbolTable)
41 {
42 const TVariable &oldVariable = symbolNode->variable();
43
44 TType *newType = GetStructSamplerParameterType(symbolTable, oldVariable);
45
46 TVariable *newVariable =
47 new TVariable(oldVariable.uniqueId(), oldVariable.name(), oldVariable.symbolType(),
48 oldVariable.extension(), newType);
49 return new TIntermSymbol(newVariable);
50 }
51
ReplaceTypeOfTypedStructNode(TIntermTyped * argument,TSymbolTable * symbolTable)52 TIntermTyped *ReplaceTypeOfTypedStructNode(TIntermTyped *argument, TSymbolTable *symbolTable)
53 {
54 TIntermSymbol *asSymbol = argument->getAsSymbolNode();
55 if (asSymbol)
56 {
57 ASSERT(asSymbol->getType().getStruct());
58 return ReplaceTypeOfSymbolNode(asSymbol, symbolTable);
59 }
60
61 TIntermTyped *replacement = argument->deepCopy();
62 TIntermBinary *binary = replacement->getAsBinaryNode();
63 ASSERT(binary);
64
65 while (binary)
66 {
67 ASSERT(binary->getOp() == EOpIndexDirectStruct || binary->getOp() == EOpIndexDirect);
68
69 asSymbol = binary->getLeft()->getAsSymbolNode();
70
71 if (asSymbol)
72 {
73 ASSERT(asSymbol->getType().getStruct());
74 TIntermSymbol *newSymbol = ReplaceTypeOfSymbolNode(asSymbol, symbolTable);
75 binary->replaceChildNode(binary->getLeft(), newSymbol);
76 return replacement;
77 }
78
79 binary = binary->getLeft()->getAsBinaryNode();
80 }
81
82 UNREACHABLE();
83 return nullptr;
84 }
85
GenerateArrayStrides(const std::vector<size_t> & arraySizes,std::vector<size_t> * arrayStridesOut)86 void GenerateArrayStrides(const std::vector<size_t> &arraySizes,
87 std::vector<size_t> *arrayStridesOut)
88 {
89 auto &strides = *arrayStridesOut;
90
91 ASSERT(strides.empty());
92 strides.reserve(arraySizes.size() + 1);
93
94 size_t currentStride = 1;
95 strides.push_back(1);
96 for (auto it = arraySizes.rbegin(); it != arraySizes.rend(); ++it)
97 {
98 currentStride *= *it;
99 strides.push_back(currentStride);
100 }
101 }
102
103 // This returns an expression representing the correct index using the array
104 // index operations in node.
GetIndexExpressionFromTypedNode(TIntermTyped * node,const std::vector<size_t> & strides,TIntermTyped * offset)105 static TIntermTyped *GetIndexExpressionFromTypedNode(TIntermTyped *node,
106 const std::vector<size_t> &strides,
107 TIntermTyped *offset)
108 {
109 TIntermTyped *result = offset;
110 TIntermTyped *currentNode = node;
111
112 auto it = strides.end();
113 --it;
114 // If this is being used as an argument, not all indices may be present;
115 // count how many indices are there.
116 while (currentNode->getAsBinaryNode())
117 {
118 TIntermBinary *asBinary = currentNode->getAsBinaryNode();
119
120 switch (asBinary->getOp())
121 {
122 case EOpIndexDirectStruct:
123 break;
124
125 case EOpIndexDirect:
126 case EOpIndexIndirect:
127 --it;
128 break;
129
130 default:
131 UNREACHABLE();
132 break;
133 }
134
135 currentNode = asBinary->getLeft();
136 }
137
138 currentNode = node;
139
140 while (currentNode->getAsBinaryNode())
141 {
142 TIntermBinary *asBinary = currentNode->getAsBinaryNode();
143
144 switch (asBinary->getOp())
145 {
146 case EOpIndexDirectStruct:
147 break;
148
149 case EOpIndexDirect:
150 case EOpIndexIndirect:
151 {
152 TIntermBinary *multiply =
153 new TIntermBinary(EOpMul, CreateIndexNode(static_cast<int>(*it++)),
154 asBinary->getRight()->deepCopy());
155 result = new TIntermBinary(EOpAdd, result, multiply);
156 break;
157 }
158
159 default:
160 UNREACHABLE();
161 break;
162 }
163
164 currentNode = asBinary->getLeft();
165 }
166
167 return result;
168 }
169
// Structures for keeping track of function instantiations.

// An instantiation is identified by the flattened sizes of its sampler-array arguments, in order.
using Instantiation = std::vector<size_t>;

// Order-sensitive hash of an Instantiation, mixing each element in with the hash_combine recipe.
struct InstantiationHash
{
    size_t operator()(const Instantiation &sizes) const noexcept
    {
        const std::hash<size_t> elementHash{};
        size_t combined = 0;
        for (auto it = sizes.begin(); it != sizes.end(); ++it)
        {
            const size_t mixed = elementHash(*it) + 0x9e3779b9 + (combined << 6) + (combined >> 2);
            combined ^= mixed;
        }
        return combined;
    }
};
188
189 // Map from each function to a "set" of instantiations.
190 // We store a TFunction for each instantiation as its value.
191 typedef std::map<ImmutableString, std::unordered_map<Instantiation, TFunction *, InstantiationHash>>
192 FunctionInstantiations;
193
194 typedef std::unordered_map<const TFunction *, const TFunction *> FunctionMap;
195
196 // Generates a new function from the given function using the given
197 // instantiation; generatedInstantiations can be null.
GenerateFunctionFromArguments(const TFunction * function,const TIntermSequence * arguments,TSymbolTable * symbolTable,FunctionInstantiations * functionInstantiations,FunctionMap * functionMap,const FunctionInstantiations * generatedInstantiations)198 TFunction *GenerateFunctionFromArguments(const TFunction *function,
199 const TIntermSequence *arguments,
200 TSymbolTable *symbolTable,
201 FunctionInstantiations *functionInstantiations,
202 FunctionMap *functionMap,
203 const FunctionInstantiations *generatedInstantiations)
204 {
205 // Collect sizes of array arguments.
206 Instantiation instantiation;
207 for (TIntermNode *node : *arguments)
208 {
209 const TType &type = node->getAsTyped()->getType();
210 if (type.isArray() && type.isSampler())
211 {
212 ASSERT(type.getNumArraySizes() == 1);
213 instantiation.push_back(type.getArraySizes()[0]);
214 }
215 }
216
217 if (generatedInstantiations)
218 {
219 auto it1 = generatedInstantiations->find(function->name());
220 if (it1 != generatedInstantiations->end())
221 {
222 const auto &map = it1->second;
223 auto it2 = map.find(instantiation);
224 if (it2 != map.end())
225 {
226 return it2->second;
227 }
228 }
229 }
230
231 TFunction **newFunction = &(*functionInstantiations)[function->name()][instantiation];
232
233 if (!*newFunction)
234 {
235 *newFunction =
236 new TFunction(symbolTable, kEmptyImmutableString, SymbolType::AngleInternal,
237 &function->getReturnType(), function->isKnownToNotHaveSideEffects());
238 (*functionMap)[*newFunction] = function;
239 // Insert parameters from updated function.
240 TFunction *updatedFunction = symbolTable->findUserDefinedFunction(function->name());
241 size_t paramCount = updatedFunction->getParamCount();
242 auto it = instantiation.begin();
243 for (size_t paramIndex = 0; paramIndex < paramCount; ++paramIndex)
244 {
245 const TVariable *param = updatedFunction->getParam(paramIndex);
246 const TType ¶mType = param->getType();
247 if (paramType.isArray() && paramType.isSampler())
248 {
249 TType *replacementType = new TType(paramType);
250 size_t arraySize = *it++;
251 replacementType->setArraySize(0, static_cast<unsigned int>(arraySize));
252 param =
253 new TVariable(symbolTable, param->name(), replacementType, param->symbolType());
254 }
255 (*newFunction)->addParameter(param);
256 }
257 }
258 return *newFunction;
259 }
260
261 class ArrayTraverser
262 {
263 public:
ArrayTraverser()264 ArrayTraverser() { mCumulativeArraySizeStack.push_back(1); }
265
enterArray(const TType & arrayType)266 void enterArray(const TType &arrayType)
267 {
268 if (!arrayType.isArray())
269 return;
270 size_t currentArraySize = mCumulativeArraySizeStack.back();
271 const TSpan<const unsigned int> &arraySizes = arrayType.getArraySizes();
272 for (auto it = arraySizes.rbegin(); it != arraySizes.rend(); ++it)
273 {
274 unsigned int arraySize = *it;
275 currentArraySize *= arraySize;
276 mArraySizeStack.push_back(arraySize);
277 mCumulativeArraySizeStack.push_back(currentArraySize);
278 }
279 }
280
exitArray(const TType & arrayType)281 void exitArray(const TType &arrayType)
282 {
283 if (!arrayType.isArray())
284 return;
285 mArraySizeStack.resize(mArraySizeStack.size() - arrayType.getNumArraySizes());
286 mCumulativeArraySizeStack.resize(mCumulativeArraySizeStack.size() -
287 arrayType.getNumArraySizes());
288 }
289
290 protected:
291 std::vector<size_t> mArraySizeStack;
292 // The first element is 1; each successive element is the previous
293 // multiplied by the size of the next nested array in the current sampler.
294 // For example, with sampler2D foo[3][6], we would have {1, 3, 18}.
295 std::vector<size_t> mCumulativeArraySizeStack;
296 };
297
298 struct VariableExtraData
299 {
300 // The value consists of strides, starting from the outermost array.
301 // For example, with sampler2D foo[3][6], we would have {1, 6, 18}.
302 std::unordered_map<const TVariable *, std::vector<size_t>> arrayStrideMap;
303 // For each generated array parameter, holds the offset parameter.
304 std::unordered_map<const TVariable *, const TVariable *> paramOffsetMap;
305 };
306
307 class Traverser final : public TIntermTraverser, public ArrayTraverser
308 {
309 public:
Traverser(TSymbolTable * symbolTable)310 explicit Traverser(TSymbolTable *symbolTable)
311 : TIntermTraverser(true, false, true, symbolTable), mRemovedUniformsCount(0)
312 {
313 mSymbolTable->push();
314 }
315
~Traverser()316 ~Traverser() override { mSymbolTable->pop(); }
317
removedUniformsCount() const318 int removedUniformsCount() const { return mRemovedUniformsCount; }
319
320 // Each struct sampler declaration is stripped of its samplers. New uniforms are added for each
321 // stripped struct sampler. Flattens all arrays, including default uniforms.
visitDeclaration(Visit visit,TIntermDeclaration * decl)322 bool visitDeclaration(Visit visit, TIntermDeclaration *decl) override
323 {
324 if (visit != PreVisit)
325 return true;
326
327 if (!mInGlobalScope)
328 {
329 return true;
330 }
331
332 const TIntermSequence &sequence = *(decl->getSequence());
333 TIntermTyped *declarator = sequence.front()->getAsTyped();
334 const TType &type = declarator->getType();
335
336 if (type.isStructureContainingSamplers())
337 {
338 TIntermSequence *newSequence = new TIntermSequence;
339
340 if (type.isStructSpecifier())
341 {
342 stripStructSpecifierSamplers(type.getStruct(), newSequence);
343 }
344 else
345 {
346 TIntermSymbol *asSymbol = declarator->getAsSymbolNode();
347 ASSERT(asSymbol);
348 const TVariable &variable = asSymbol->variable();
349 ASSERT(variable.symbolType() != SymbolType::Empty);
350 extractStructSamplerUniforms(decl, variable, type.getStruct(), newSequence);
351 }
352
353 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), decl, *newSequence);
354 }
355
356 if (type.isSampler() && type.isArray())
357 {
358 TIntermSequence *newSequence = new TIntermSequence;
359 TIntermSymbol *asSymbol = declarator->getAsSymbolNode();
360 ASSERT(asSymbol);
361 const TVariable &variable = asSymbol->variable();
362 ASSERT(variable.symbolType() != SymbolType::Empty);
363 extractSampler(variable.name(), variable.symbolType(), variable.getType(), newSequence,
364 0);
365 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), decl, *newSequence);
366 }
367
368 return true;
369 }
370
371 // Each struct sampler reference is replaced with a reference to the new extracted sampler.
visitBinary(Visit visit,TIntermBinary * node)372 bool visitBinary(Visit visit, TIntermBinary *node) override
373 {
374 if (visit != PreVisit)
375 return true;
376 // If the node isn't a sampler or if this isn't the outermost access,
377 // continue.
378 if (!node->getType().isSampler() || node->getType().isArray())
379 {
380 return true;
381 }
382
383 if (node->getOp() == EOpIndexDirect || node->getOp() == EOpIndexIndirect ||
384 node->getOp() == EOpIndexDirectStruct)
385 {
386 ImmutableString newName = GetStructSamplerNameFromTypedNode(node);
387 const TVariable *samplerReplacement =
388 static_cast<const TVariable *>(mSymbolTable->findUserDefined(newName));
389 ASSERT(samplerReplacement);
390
391 TIntermTyped *replacement = new TIntermSymbol(samplerReplacement);
392
393 if (replacement->isArray())
394 {
395 // Add in an indirect index if contained in an array
396 const auto &strides = mVariableExtraData.arrayStrideMap[samplerReplacement];
397 ASSERT(!strides.empty());
398 if (strides.size() > 1)
399 {
400 auto it = mVariableExtraData.paramOffsetMap.find(samplerReplacement);
401
402 TIntermTyped *offset =
403 it == mVariableExtraData.paramOffsetMap.end()
404 ? static_cast<TIntermTyped *>(CreateIndexNode(0))
405 : static_cast<TIntermTyped *>(new TIntermSymbol(it->second));
406
407 TIntermTyped *index = GetIndexExpressionFromTypedNode(node, strides, offset);
408 replacement = new TIntermBinary(EOpIndexIndirect, replacement, index);
409 }
410 }
411
412 queueReplacement(replacement, OriginalNode::IS_DROPPED);
413 return true;
414 }
415
416 return true;
417 }
418
419 // In we are passing references to structs containing samplers we must new additional
420 // arguments. For each extracted struct sampler a new argument is added. This chains to nested
421 // structs.
visitFunctionPrototype(TIntermFunctionPrototype * node)422 void visitFunctionPrototype(TIntermFunctionPrototype *node) override
423 {
424 const TFunction *function = node->getFunction();
425
426 if (!function->hasSamplerInStructOrArrayParams())
427 {
428 return;
429 }
430
431 const TSymbol *foundFunction = mSymbolTable->findUserDefined(function->name());
432 if (foundFunction)
433 {
434 ASSERT(foundFunction->isFunction());
435 function = static_cast<const TFunction *>(foundFunction);
436 }
437 else
438 {
439 TFunction *newFunction = createStructSamplerFunction(function);
440 mSymbolTable->declareUserDefinedFunction(newFunction, true);
441 function = newFunction;
442 }
443
444 ASSERT(!function->hasSamplerInStructOrArrayOfArrayParams());
445 TIntermFunctionPrototype *newProto = new TIntermFunctionPrototype(function);
446 queueReplacement(newProto, OriginalNode::IS_DROPPED);
447 }
448
449 // We insert a new scope for each function definition so we can track the new parameters.
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)450 bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
451 {
452 if (visit == PreVisit)
453 {
454 mSymbolTable->push();
455 }
456 else
457 {
458 ASSERT(visit == PostVisit);
459 mSymbolTable->pop();
460 }
461 return true;
462 }
463
464 // For function call nodes we pass references to the extracted struct samplers in that scope.
visitAggregate(Visit visit,TIntermAggregate * node)465 bool visitAggregate(Visit visit, TIntermAggregate *node) override
466 {
467 if (visit != PreVisit)
468 return true;
469
470 if (!node->isFunctionCall())
471 return true;
472
473 const TFunction *function = node->getFunction();
474 if (!function->hasSamplerInStructOrArrayParams())
475 return true;
476
477 ASSERT(node->getOp() == EOpCallFunctionInAST);
478 TIntermSequence *newArguments = getStructSamplerArguments(function, node->getSequence());
479
480 TFunction *newFunction = GenerateFunctionFromArguments(
481 function, newArguments, mSymbolTable, &mFunctionInstantiations, &mFunctionMap, nullptr);
482
483 TIntermAggregate *newCall =
484 TIntermAggregate::CreateFunctionCall(*newFunction, newArguments);
485 queueReplacement(newCall, OriginalNode::IS_DROPPED);
486 return true;
487 }
488
getFunctionInstantiations()489 FunctionInstantiations *getFunctionInstantiations() { return &mFunctionInstantiations; }
490
getFunctionMap()491 std::unordered_map<const TFunction *, const TFunction *> *getFunctionMap()
492 {
493 return &mFunctionMap;
494 }
495
496 private:
497 // This returns the name of a struct sampler reference. References are always TIntermBinary.
GetStructSamplerNameFromTypedNode(TIntermTyped * node)498 static ImmutableString GetStructSamplerNameFromTypedNode(TIntermTyped *node)
499 {
500 std::string stringBuilder;
501
502 TIntermTyped *currentNode = node;
503 while (currentNode->getAsBinaryNode())
504 {
505 TIntermBinary *asBinary = currentNode->getAsBinaryNode();
506
507 switch (asBinary->getOp())
508 {
509 case EOpIndexDirectStruct:
510 {
511 stringBuilder.insert(0, asBinary->getIndexStructFieldName().data());
512 stringBuilder.insert(0, "_");
513 break;
514 }
515
516 case EOpIndexDirect:
517 case EOpIndexIndirect:
518 break;
519
520 default:
521 UNREACHABLE();
522 break;
523 }
524
525 currentNode = asBinary->getLeft();
526 }
527
528 const ImmutableString &variableName = currentNode->getAsSymbolNode()->variable().name();
529 stringBuilder.insert(0, variableName.data());
530
531 return stringBuilder;
532 }
533
534 // Removes all the struct samplers from a struct specifier.
stripStructSpecifierSamplers(const TStructure * structure,TIntermSequence * newSequence)535 void stripStructSpecifierSamplers(const TStructure *structure, TIntermSequence *newSequence)
536 {
537 TFieldList *newFieldList = new TFieldList;
538 ASSERT(structure->containsSamplers());
539
540 for (const TField *field : structure->fields())
541 {
542 const TType &fieldType = *field->type();
543 if (!fieldType.isSampler() && !isRemovedStructType(fieldType))
544 {
545 TType *newType = nullptr;
546
547 if (fieldType.isStructureContainingSamplers())
548 {
549 const TSymbol *structSymbol =
550 mSymbolTable->findUserDefined(fieldType.getStruct()->name());
551 ASSERT(structSymbol && structSymbol->isStruct());
552 const TStructure *fieldStruct = static_cast<const TStructure *>(structSymbol);
553 newType = new TType(fieldStruct, true);
554 if (fieldType.isArray())
555 {
556 newType->makeArrays(fieldType.getArraySizes());
557 }
558 }
559 else
560 {
561 newType = new TType(fieldType);
562 }
563
564 TField *newField =
565 new TField(newType, field->name(), field->line(), field->symbolType());
566 newFieldList->push_back(newField);
567 }
568 }
569
570 // Prune empty structs.
571 if (newFieldList->empty())
572 {
573 mRemovedStructs.insert(structure->name());
574 return;
575 }
576
577 TStructure *newStruct =
578 new TStructure(mSymbolTable, structure->name(), newFieldList, structure->symbolType());
579 TType *newStructType = new TType(newStruct, true);
580 TVariable *newStructVar =
581 new TVariable(mSymbolTable, kEmptyImmutableString, newStructType, SymbolType::Empty);
582 TIntermSymbol *newStructRef = new TIntermSymbol(newStructVar);
583
584 TIntermDeclaration *structDecl = new TIntermDeclaration;
585 structDecl->appendDeclarator(newStructRef);
586
587 newSequence->push_back(structDecl);
588
589 mSymbolTable->declare(newStruct);
590 }
591
592 // Returns true if the type is a struct that was removed because we extracted all the members.
isRemovedStructType(const TType & type) const593 bool isRemovedStructType(const TType &type) const
594 {
595 const TStructure *structure = type.getStruct();
596 return (structure && (mRemovedStructs.count(structure->name()) > 0));
597 }
598
599 // Removes samplers from struct uniforms. For each sampler removed also adds a new globally
600 // defined sampler uniform.
extractStructSamplerUniforms(TIntermDeclaration * oldDeclaration,const TVariable & variable,const TStructure * structure,TIntermSequence * newSequence)601 void extractStructSamplerUniforms(TIntermDeclaration *oldDeclaration,
602 const TVariable &variable,
603 const TStructure *structure,
604 TIntermSequence *newSequence)
605 {
606 ASSERT(structure->containsSamplers());
607
608 size_t nonSamplerCount = 0;
609
610 enterArray(variable.getType());
611
612 for (const TField *field : structure->fields())
613 {
614 nonSamplerCount +=
615 extractFieldSamplers(variable.name(), field, variable.getType(), newSequence);
616 }
617
618 if (nonSamplerCount > 0)
619 {
620 // Keep the old declaration around if it has other members.
621 newSequence->push_back(oldDeclaration);
622 }
623 else
624 {
625 mRemovedUniformsCount++;
626 }
627
628 exitArray(variable.getType());
629 }
630
631 // Extracts samplers from a field of a struct. Works with nested structs and arrays.
extractFieldSamplers(const ImmutableString & prefix,const TField * field,const TType & containingType,TIntermSequence * newSequence)632 size_t extractFieldSamplers(const ImmutableString &prefix,
633 const TField *field,
634 const TType &containingType,
635 TIntermSequence *newSequence)
636 {
637 return extractFieldSamplersImpl(prefix, field, newSequence);
638 }
639
640 // Extracts samplers from a field of a struct. Works with nested structs and arrays.
extractFieldSamplersImpl(const ImmutableString & prefix,const TField * field,TIntermSequence * newSequence)641 size_t extractFieldSamplersImpl(const ImmutableString &prefix,
642 const TField *field,
643 TIntermSequence *newSequence)
644 {
645 size_t nonSamplerCount = 0;
646
647 const TType &fieldType = *field->type();
648 if (fieldType.isSampler() || fieldType.isStructureContainingSamplers())
649 {
650 ImmutableStringBuilder stringBuilder(prefix.length() + field->name().length() + 1);
651 stringBuilder << prefix << "_" << field->name();
652 ImmutableString newPrefix(stringBuilder);
653
654 if (fieldType.isSampler())
655 {
656 extractSampler(newPrefix, SymbolType::AngleInternal, fieldType, newSequence, 0);
657 }
658 else
659 {
660 enterArray(fieldType);
661 const TStructure *structure = fieldType.getStruct();
662 for (const TField *nestedField : structure->fields())
663 {
664 nonSamplerCount +=
665 extractFieldSamplers(newPrefix, nestedField, fieldType, newSequence);
666 }
667 exitArray(fieldType);
668 }
669 }
670 else
671 {
672 nonSamplerCount++;
673 }
674
675 return nonSamplerCount;
676 }
677
678 // Extracts a sampler from a struct. Declares the new extracted sampler.
extractSampler(const ImmutableString & newName,SymbolType symbolType,const TType & fieldType,TIntermSequence * newSequence,size_t arrayLevel)679 void extractSampler(const ImmutableString &newName,
680 SymbolType symbolType,
681 const TType &fieldType,
682 TIntermSequence *newSequence,
683 size_t arrayLevel)
684 {
685 enterArray(fieldType);
686
687 TType *newType = new TType(fieldType);
688 while (newType->isArray())
689 {
690 newType->toArrayElementType();
691 }
692 if (!mArraySizeStack.empty())
693 {
694 newType->makeArray(static_cast<unsigned int>(mCumulativeArraySizeStack.back()));
695 }
696 newType->setQualifier(EvqUniform);
697 TVariable *newVariable = new TVariable(mSymbolTable, newName, newType, symbolType);
698 TIntermSymbol *newRef = new TIntermSymbol(newVariable);
699
700 TIntermDeclaration *samplerDecl = new TIntermDeclaration;
701 samplerDecl->appendDeclarator(newRef);
702
703 newSequence->push_back(samplerDecl);
704
705 // TODO(syoussefi): Use a SymbolType::Empty name instead of generating a name as currently
706 // done. There is no guarantee that these generated names cannot clash. Create a mapping
707 // from the previous name to the name assigned to the SymbolType::Empty variable so
708 // ShaderVariable::mappedName can be updated post-transformation.
709 // http://anglebug.com/4301
710 if (symbolType == SymbolType::AngleInternal)
711 {
712 mSymbolTable->declareInternal(newVariable);
713 }
714 else
715 {
716 mSymbolTable->declare(newVariable);
717 }
718
719 GenerateArrayStrides(mArraySizeStack, &mVariableExtraData.arrayStrideMap[newVariable]);
720
721 exitArray(fieldType);
722 }
723
724 // Returns the chained name of a sampler uniform field.
GetFieldName(const ImmutableString & paramName,const TField * field)725 static ImmutableString GetFieldName(const ImmutableString ¶mName, const TField *field)
726 {
727 ImmutableStringBuilder nameBuilder(paramName.length() + 1 + field->name().length());
728 nameBuilder << paramName << "_";
729 nameBuilder << field->name();
730
731 return nameBuilder;
732 }
733
734 // A pattern that visits every parameter of a function call. Uses different handlers for struct
735 // parameters, struct sampler parameters, and non-struct parameters.
736 class StructSamplerFunctionVisitor : angle::NonCopyable, public ArrayTraverser
737 {
738 public:
739 StructSamplerFunctionVisitor() = default;
740 virtual ~StructSamplerFunctionVisitor() = default;
741
traverse(const TFunction * function)742 virtual void traverse(const TFunction *function)
743 {
744 size_t paramCount = function->getParamCount();
745
746 for (size_t paramIndex = 0; paramIndex < paramCount; ++paramIndex)
747 {
748 const TVariable *param = function->getParam(paramIndex);
749 const TType ¶mType = param->getType();
750
751 if (paramType.isStructureContainingSamplers())
752 {
753 const ImmutableString &baseName = getNameFromIndex(function, paramIndex);
754 if (traverseStructContainingSamplers(baseName, paramType, paramIndex))
755 {
756 visitStructParam(function, paramIndex);
757 }
758 }
759 else if (paramType.isArray() && paramType.isSampler())
760 {
761 const ImmutableString ¶mName = getNameFromIndex(function, paramIndex);
762 traverseLeafSampler(paramName, paramType, paramIndex);
763 }
764 else
765 {
766 visitNonStructParam(function, paramIndex);
767 }
768 }
769 }
770
771 virtual ImmutableString getNameFromIndex(const TFunction *function, size_t paramIndex) = 0;
772 // Also includes samplers in arrays of arrays.
773 virtual void visitSamplerInStructParam(const ImmutableString &name,
774 const TType *type,
775 size_t paramIndex) = 0;
776 virtual void visitStructParam(const TFunction *function, size_t paramIndex) = 0;
777 virtual void visitNonStructParam(const TFunction *function, size_t paramIndex) = 0;
778
779 private:
traverseStructContainingSamplers(const ImmutableString & baseName,const TType & structType,size_t paramIndex)780 bool traverseStructContainingSamplers(const ImmutableString &baseName,
781 const TType &structType,
782 size_t paramIndex)
783 {
784 bool hasNonSamplerFields = false;
785 const TStructure *structure = structType.getStruct();
786 enterArray(structType);
787 for (const TField *field : structure->fields())
788 {
789 if (field->type()->isStructureContainingSamplers() || field->type()->isSampler())
790 {
791 if (traverseSamplerInStruct(baseName, structType, field, paramIndex))
792 {
793 hasNonSamplerFields = true;
794 }
795 }
796 else
797 {
798 hasNonSamplerFields = true;
799 }
800 }
801 exitArray(structType);
802 return hasNonSamplerFields;
803 }
804
traverseSamplerInStruct(const ImmutableString & baseName,const TType & baseType,const TField * field,size_t paramIndex)805 bool traverseSamplerInStruct(const ImmutableString &baseName,
806 const TType &baseType,
807 const TField *field,
808 size_t paramIndex)
809 {
810 bool hasNonSamplerParams = false;
811
812 if (field->type()->isStructureContainingSamplers())
813 {
814 ImmutableString name = GetFieldName(baseName, field);
815 hasNonSamplerParams =
816 traverseStructContainingSamplers(name, *field->type(), paramIndex);
817 }
818 else
819 {
820 ASSERT(field->type()->isSampler());
821 ImmutableString name = GetFieldName(baseName, field);
822 traverseLeafSampler(name, *field->type(), paramIndex);
823 }
824
825 return hasNonSamplerParams;
826 }
827
traverseLeafSampler(const ImmutableString & samplerName,const TType & samplerType,size_t paramIndex)828 void traverseLeafSampler(const ImmutableString &samplerName,
829 const TType &samplerType,
830 size_t paramIndex)
831 {
832 enterArray(samplerType);
833 visitSamplerInStructParam(samplerName, &samplerType, paramIndex);
834 exitArray(samplerType);
835 return;
836 }
837 };
838
839 // A visitor that replaces functions with struct sampler references. The struct sampler
840 // references are expanded to include new fields for the structs.
841 class CreateStructSamplerFunctionVisitor final : public StructSamplerFunctionVisitor
842 {
843 public:
CreateStructSamplerFunctionVisitor(TSymbolTable * symbolTable,VariableExtraData * extraData)844 CreateStructSamplerFunctionVisitor(TSymbolTable *symbolTable, VariableExtraData *extraData)
845 : mSymbolTable(symbolTable), mNewFunction(nullptr), mExtraData(extraData)
846 {}
847
getNameFromIndex(const TFunction * function,size_t paramIndex)848 ImmutableString getNameFromIndex(const TFunction *function, size_t paramIndex) override
849 {
850 const TVariable *param = function->getParam(paramIndex);
851 return param->name();
852 }
853
traverse(const TFunction * function)854 void traverse(const TFunction *function) override
855 {
856 mNewFunction =
857 new TFunction(mSymbolTable, function->name(), function->symbolType(),
858 &function->getReturnType(), function->isKnownToNotHaveSideEffects());
859
860 StructSamplerFunctionVisitor::traverse(function);
861 }
862
visitSamplerInStructParam(const ImmutableString & name,const TType * type,size_t paramIndex)863 void visitSamplerInStructParam(const ImmutableString &name,
864 const TType *type,
865 size_t paramIndex) override
866 {
867 if (mArraySizeStack.size() > 0)
868 {
869 TType *newType = new TType(*type);
870 newType->toArrayBaseType();
871 newType->makeArray(static_cast<unsigned int>(mCumulativeArraySizeStack.back()));
872 type = newType;
873 }
874 TVariable *fieldSampler =
875 new TVariable(mSymbolTable, name, type, SymbolType::AngleInternal);
876 mNewFunction->addParameter(fieldSampler);
877 mSymbolTable->declareInternal(fieldSampler);
878 if (mArraySizeStack.size() > 0)
879 {
880 // Also declare an offset parameter.
881 const TType *intType = StaticType::GetBasic<EbtInt>();
882 TVariable *samplerOffset = new TVariable(mSymbolTable, kEmptyImmutableString,
883 intType, SymbolType::AngleInternal);
884 mNewFunction->addParameter(samplerOffset);
885 GenerateArrayStrides(mArraySizeStack, &mExtraData->arrayStrideMap[fieldSampler]);
886 mExtraData->paramOffsetMap[fieldSampler] = samplerOffset;
887 }
888 }
889
visitStructParam(const TFunction * function,size_t paramIndex)890 void visitStructParam(const TFunction *function, size_t paramIndex) override
891 {
892 const TVariable *param = function->getParam(paramIndex);
893 TType *structType = GetStructSamplerParameterType(mSymbolTable, *param);
894 TVariable *newParam =
895 new TVariable(mSymbolTable, param->name(), structType, param->symbolType());
896 mNewFunction->addParameter(newParam);
897 }
898
visitNonStructParam(const TFunction * function,size_t paramIndex)899 void visitNonStructParam(const TFunction *function, size_t paramIndex) override
900 {
901 const TVariable *param = function->getParam(paramIndex);
902 mNewFunction->addParameter(param);
903 }
904
getNewFunction() const905 TFunction *getNewFunction() const { return mNewFunction; }
906
907 private:
908 TSymbolTable *mSymbolTable;
909 TFunction *mNewFunction;
910 VariableExtraData *mExtraData;
911 };
912
// Returns a new function equivalent to |function| with its struct-sampler and sampler-array
// parameters expanded; stride and offset bookkeeping goes into mVariableExtraData.
TFunction *createStructSamplerFunction(const TFunction *function)
{
    CreateStructSamplerFunctionVisitor functionBuilder(mSymbolTable, &mVariableExtraData);
    functionBuilder.traverse(function);
    return functionBuilder.getNewFunction();
}
919
920 // A visitor that replaces function calls with expanded struct sampler parameters.
921 class GetSamplerArgumentsVisitor final : public StructSamplerFunctionVisitor
922 {
923 public:
GetSamplerArgumentsVisitor(TSymbolTable * symbolTable,const TIntermSequence * arguments,VariableExtraData * extraData)924 GetSamplerArgumentsVisitor(TSymbolTable *symbolTable,
925 const TIntermSequence *arguments,
926 VariableExtraData *extraData)
927 : mSymbolTable(symbolTable),
928 mArguments(arguments),
929 mNewArguments(new TIntermSequence),
930 mExtraData(extraData)
931 {}
932
getNameFromIndex(const TFunction * function,size_t paramIndex)933 ImmutableString getNameFromIndex(const TFunction *function, size_t paramIndex) override
934 {
935 TIntermTyped *argument = (*mArguments)[paramIndex]->getAsTyped();
936 return GetStructSamplerNameFromTypedNode(argument);
937 }
938
visitSamplerInStructParam(const ImmutableString & name,const TType * type,size_t paramIndex)939 void visitSamplerInStructParam(const ImmutableString &name,
940 const TType *type,
941 size_t paramIndex) override
942 {
943 const TVariable *argSampler =
944 static_cast<const TVariable *>(mSymbolTable->findUserDefined(name));
945 ASSERT(argSampler);
946
947 TIntermTyped *argument = (*mArguments)[paramIndex]->getAsTyped();
948
949 auto it = mExtraData->paramOffsetMap.find(argSampler);
950 TIntermTyped *argOffset =
951 it == mExtraData->paramOffsetMap.end()
952 ? static_cast<TIntermTyped *>(CreateIndexNode(0))
953 : static_cast<TIntermTyped *>(new TIntermSymbol(it->second));
954
955 TIntermTyped *finalOffset = GetIndexExpressionFromTypedNode(
956 argument, mExtraData->arrayStrideMap[argSampler], argOffset);
957
958 TIntermSymbol *argSymbol = new TIntermSymbol(argSampler);
959
960 // If we have a regular sampler inside a struct (possibly an array
961 // of structs), handle this case separately.
962 if (!type->isArray() && mArraySizeStack.size() == 0)
963 {
964 if (argSampler->getType().isArray())
965 {
966 TIntermTyped *argIndex =
967 new TIntermBinary(EOpIndexIndirect, argSymbol, finalOffset);
968 mNewArguments->push_back(argIndex);
969 }
970 else
971 {
972 mNewArguments->push_back(argSymbol);
973 }
974 return;
975 }
976
977 mNewArguments->push_back(argSymbol);
978
979 mNewArguments->push_back(finalOffset);
980 // If array, we need to calculate the offset based on what indices
981 // are present in the argument.
982 }
983
visitStructParam(const TFunction * function,size_t paramIndex)984 void visitStructParam(const TFunction *function, size_t paramIndex) override
985 {
986 // The tree structure of the parameter is modified to point to the new type. This leaves
987 // the tree in a consistent state.
988 TIntermTyped *argument = (*mArguments)[paramIndex]->getAsTyped();
989 TIntermTyped *replacement = ReplaceTypeOfTypedStructNode(argument, mSymbolTable);
990 mNewArguments->push_back(replacement);
991 }
992
visitNonStructParam(const TFunction * function,size_t paramIndex)993 void visitNonStructParam(const TFunction *function, size_t paramIndex) override
994 {
995 TIntermTyped *argument = (*mArguments)[paramIndex]->getAsTyped();
996 mNewArguments->push_back(argument);
997 }
998
getNewArguments() const999 TIntermSequence *getNewArguments() const { return mNewArguments; }
1000
1001 private:
1002 TSymbolTable *mSymbolTable;
1003 const TIntermSequence *mArguments;
1004 TIntermSequence *mNewArguments;
1005 VariableExtraData *mExtraData;
1006 };
1007
TIntermSequence *getStructSamplerArguments(const TFunction *function,
                                           const TIntermSequence *arguments)
{
    // Walk |function|'s parameters alongside the call's |arguments| and collect the rewritten
    // argument list.
    GetSamplerArgumentsVisitor argumentBuilder(mSymbolTable, arguments, &mVariableExtraData);
    argumentBuilder.traverse(function);
    TIntermSequence *rewrittenArguments = argumentBuilder.getNewArguments();
    return rewrittenArguments;
}
1015
1016 int mRemovedUniformsCount;
1017 std::set<ImmutableString> mRemovedStructs;
1018 FunctionInstantiations mFunctionInstantiations;
1019 FunctionMap mFunctionMap;
1020 VariableExtraData mVariableExtraData;
1021 };
1022
1023 class MonomorphizeTraverser final : public TIntermTraverser
1024 {
1025 public:
1026 typedef std::unordered_map<const TVariable *, const TVariable *> VariableReplacementMap;
1027
MonomorphizeTraverser(TCompiler * compiler,TSymbolTable * symbolTable,FunctionInstantiations * functionInstantiations,std::unordered_map<const TFunction *,const TFunction * > * functionMap)1028 explicit MonomorphizeTraverser(
1029 TCompiler *compiler,
1030 TSymbolTable *symbolTable,
1031 FunctionInstantiations *functionInstantiations,
1032 std::unordered_map<const TFunction *, const TFunction *> *functionMap)
1033 : TIntermTraverser(true, false, true, symbolTable),
1034 mFunctionInstantiations(*functionInstantiations),
1035 mFunctionMap(functionMap),
1036 mCompiler(compiler),
1037 mSubpassesSucceeded(true)
1038 {}
1039
switchToPending()1040 void switchToPending()
1041 {
1042 mFunctionInstantiations.clear();
1043 mFunctionInstantiations.swap(mPendingInstantiations);
1044 }
1045
hasPending()1046 bool hasPending()
1047 {
1048 if (mPendingInstantiations.empty())
1049 return false;
1050 for (auto &entry : mPendingInstantiations)
1051 {
1052 if (!entry.second.empty())
1053 {
1054 return true;
1055 }
1056 }
1057 return false;
1058 }
1059
subpassesSucceeded()1060 bool subpassesSucceeded() { return mSubpassesSucceeded; }
1061
visitFunctionPrototype(TIntermFunctionPrototype * node)1062 void visitFunctionPrototype(TIntermFunctionPrototype *node) override
1063 {
1064 mReplacementPrototypes.clear();
1065 const TFunction *function = node->getFunction();
1066
1067 auto &generatedMap = mGeneratedInstantiations[function->name()];
1068
1069 auto it = mFunctionInstantiations.find(function->name());
1070 if (it == mFunctionInstantiations.end())
1071 return;
1072 for (const auto &instantiation : it->second)
1073 {
1074 TFunction *replacementFunction = instantiation.second;
1075 mReplacementPrototypes.push_back(new TIntermFunctionPrototype(replacementFunction));
1076 generatedMap[instantiation.first] = replacementFunction;
1077 }
1078 if (!mInFunctionDefinition)
1079 {
1080 insertStatementsInParentBlock(mReplacementPrototypes);
1081 }
1082 }
1083
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)1084 bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
1085 {
1086 mInFunctionDefinition = visit == PreVisit;
1087 if (visit != PostVisit)
1088 return true;
1089 TIntermSequence replacements;
1090 const TFunction *function = node->getFunction();
1091 size_t numParameters = function->getParamCount();
1092
1093 for (TIntermNode *replacementNode : mReplacementPrototypes)
1094 {
1095 TIntermFunctionPrototype *replacementPrototype =
1096 replacementNode->getAsFunctionPrototypeNode();
1097 const TFunction *replacementFunction = replacementPrototype->getFunction();
1098
1099 // Replace function parameters with correct array sizes.
1100 VariableReplacementMap variableReplacementMap;
1101 ASSERT(replacementPrototype->getFunction()->getParamCount() == numParameters);
1102 for (size_t i = 0; i < numParameters; i++)
1103 {
1104 const TVariable *origParam = function->getParam(i);
1105 const TVariable *newParam = replacementFunction->getParam(i);
1106 if (origParam != newParam)
1107 {
1108 variableReplacementMap[origParam] = newParam;
1109 }
1110 }
1111
1112 TIntermBlock *body = node->getBody()->deepCopy();
1113 ReplaceVariablesTraverser replaceVariables(mSymbolTable, &variableReplacementMap);
1114 body->traverse(&replaceVariables);
1115 mSubpassesSucceeded &= replaceVariables.updateTree(mCompiler, body);
1116 CollectNewInstantiationsTraverser collectNewInstantiations(
1117 mSymbolTable, &mPendingInstantiations, &mGeneratedInstantiations, mFunctionMap);
1118 body->traverse(&collectNewInstantiations);
1119 mSubpassesSucceeded &= collectNewInstantiations.updateTree(mCompiler, body);
1120 replacements.push_back(new TIntermFunctionDefinition(replacementPrototype, body));
1121 }
1122 insertStatementsInParentBlock(replacements);
1123 return true;
1124 }
1125
1126 private:
1127 bool mInFunctionDefinition;
1128 FunctionInstantiations mFunctionInstantiations;
1129 // Set of already-generated instantiations.
1130 FunctionInstantiations mGeneratedInstantiations;
1131 // New instantiations caused by other instantiations.
1132 FunctionInstantiations mPendingInstantiations;
1133 std::unordered_map<const TFunction *, const TFunction *> *mFunctionMap;
1134 TIntermSequence mReplacementPrototypes;
1135 TCompiler *mCompiler;
1136 bool mSubpassesSucceeded;
1137
1138 class ReplaceVariablesTraverser : public TIntermTraverser
1139 {
1140 public:
ReplaceVariablesTraverser(TSymbolTable * symbolTable,VariableReplacementMap * variableReplacementMap)1141 explicit ReplaceVariablesTraverser(TSymbolTable *symbolTable,
1142 VariableReplacementMap *variableReplacementMap)
1143 : TIntermTraverser(true, false, false, symbolTable),
1144 mVariableReplacementMap(variableReplacementMap)
1145 {}
1146
visitSymbol(TIntermSymbol * node)1147 void visitSymbol(TIntermSymbol *node) override
1148 {
1149 const TVariable *variable = &node->variable();
1150 auto it = mVariableReplacementMap->find(variable);
1151 if (it != mVariableReplacementMap->end())
1152 {
1153 queueReplacement(new TIntermSymbol(it->second), OriginalNode::IS_DROPPED);
1154 }
1155 }
1156
1157 private:
1158 VariableReplacementMap *mVariableReplacementMap;
1159 };
1160
1161 class CollectNewInstantiationsTraverser : public TIntermTraverser
1162 {
1163 public:
CollectNewInstantiationsTraverser(TSymbolTable * symbolTable,FunctionInstantiations * pendingInstantiations,FunctionInstantiations * generatedInstantiations,std::unordered_map<const TFunction *,const TFunction * > * functionMap)1164 explicit CollectNewInstantiationsTraverser(
1165 TSymbolTable *symbolTable,
1166 FunctionInstantiations *pendingInstantiations,
1167 FunctionInstantiations *generatedInstantiations,
1168 std::unordered_map<const TFunction *, const TFunction *> *functionMap)
1169 : TIntermTraverser(true, false, false, symbolTable),
1170 mPendingInstantiations(pendingInstantiations),
1171 mGeneratedInstantiations(generatedInstantiations),
1172 mFunctionMap(functionMap)
1173 {}
1174
visitAggregate(Visit visit,TIntermAggregate * node)1175 bool visitAggregate(Visit visit, TIntermAggregate *node) override
1176 {
1177 if (!node->isFunctionCall())
1178 return true;
1179 const TFunction *function = node->getFunction();
1180 const TFunction *oldFunction;
1181 {
1182 auto it = mFunctionMap->find(function);
1183 if (it == mFunctionMap->end())
1184 return true;
1185 oldFunction = it->second;
1186 }
1187 ASSERT(node->getOp() == EOpCallFunctionInAST);
1188 TIntermSequence *arguments = node->getSequence();
1189 TFunction *newFunction = GenerateFunctionFromArguments(
1190 oldFunction, arguments, mSymbolTable, mPendingInstantiations, mFunctionMap,
1191 mGeneratedInstantiations);
1192 queueReplacement(TIntermAggregate::CreateFunctionCall(*newFunction, arguments),
1193 OriginalNode::IS_DROPPED);
1194 return true;
1195 }
1196
1197 private:
1198 FunctionInstantiations *mPendingInstantiations;
1199 FunctionInstantiations *mGeneratedInstantiations;
1200 std::unordered_map<const TFunction *, const TFunction *> *mFunctionMap;
1201 };
1202 };
1203 } // anonymous namespace
1204
RewriteStructSamplers(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,int * removedUniformsCountOut)1205 bool RewriteStructSamplers(TCompiler *compiler,
1206 TIntermBlock *root,
1207 TSymbolTable *symbolTable,
1208 int *removedUniformsCountOut)
1209 {
1210 Traverser rewriteStructSamplers(symbolTable);
1211 root->traverse(&rewriteStructSamplers);
1212 if (!rewriteStructSamplers.updateTree(compiler, root))
1213 {
1214 return false;
1215 }
1216 *removedUniformsCountOut = rewriteStructSamplers.removedUniformsCount();
1217
1218 if (rewriteStructSamplers.getFunctionInstantiations()->empty())
1219 {
1220 return true;
1221 }
1222
1223 MonomorphizeTraverser monomorphizeFunctions(compiler, symbolTable,
1224 rewriteStructSamplers.getFunctionInstantiations(),
1225 rewriteStructSamplers.getFunctionMap());
1226 root->traverse(&monomorphizeFunctions);
1227 if (!monomorphizeFunctions.subpassesSucceeded())
1228 {
1229 return false;
1230 }
1231 if (!monomorphizeFunctions.updateTree(compiler, root))
1232 {
1233 return false;
1234 }
1235
1236 // Generate instantiations caused by other instantiations.
1237 while (monomorphizeFunctions.hasPending())
1238 {
1239 monomorphizeFunctions.switchToPending();
1240 root->traverse(&monomorphizeFunctions);
1241 if (!monomorphizeFunctions.subpassesSucceeded())
1242 {
1243 return false;
1244 }
1245 if (!monomorphizeFunctions.updateTree(compiler, root))
1246 {
1247 return false;
1248 }
1249 }
1250
1251 return true;
1252 }
1253 } // namespace sh
1254