• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2019 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RewriteRowMajorMatrices: Rewrite row-major matrices as column-major.
7 //
8 
9 #include "compiler/translator/tree_ops/glsl/apple/RewriteRowMajorMatrices.h"
10 
11 #include "common/span.h"
12 #include "compiler/translator/Compiler.h"
13 #include "compiler/translator/ImmutableStringBuilder.h"
14 #include "compiler/translator/StaticType.h"
15 #include "compiler/translator/SymbolTable.h"
16 #include "compiler/translator/tree_util/IntermNode_util.h"
17 #include "compiler/translator/tree_util/IntermTraverse.h"
18 #include "compiler/translator/tree_util/ReplaceVariable.h"
19 
20 namespace sh
21 {
22 namespace
23 {
// Only structs with matrices are tracked.  If layout(row_major) is applied to a struct that doesn't
// have matrices, it's silently dropped.  This is also used to avoid creating duplicates for inner
// structs that don't have matrices.
//
// Per-struct conversion state, keyed by the original TStructure.  Every member starts out null;
// convertedStruct is filled in the first time the struct is converted (see convertStruct()), and a
// non-null value marks the struct as already converted.
struct StructConversionData
{
    // The converted struct with every matrix transposed.
    TStructure *convertedStruct = nullptr;

    // The copy-from and copy-to functions copying from a struct to its converted version and back.
    TFunction *copyFromOriginal = nullptr;
    TFunction *copyToOriginal   = nullptr;
};
36 
DoesFieldContainRowMajorMatrix(const TField * field,bool isBlockRowMajor)37 bool DoesFieldContainRowMajorMatrix(const TField *field, bool isBlockRowMajor)
38 {
39     TLayoutMatrixPacking matrixPacking = field->type()->getLayoutQualifier().matrixPacking;
40 
41     // The field is row major if either explicitly specified as such, or if it inherits it from the
42     // block layout qualifier.
43     if (matrixPacking == EmpColumnMajor || (matrixPacking == EmpUnspecified && !isBlockRowMajor))
44     {
45         return false;
46     }
47 
48     // The field is qualified with row_major, but if it's not a matrix or a struct containing
49     // matrices, that's a useless qualifier.
50     const TType *type = field->type();
51     return type->isMatrix() || type->isStructureContainingMatrices();
52 }
53 
// Returns a new TField with its own copy of |field|'s type.  The name, source line and symbol type
// are the same as |field|'s, so the copy's type can be modified without touching the original.
TField *DuplicateField(const TField *field)
{
    TType *typeCopy = new TType(*field->type());
    return new TField(typeCopy, field->name(), field->line(), field->symbolType());
}
58 
SetColumnMajor(TType * type)59 void SetColumnMajor(TType *type)
60 {
61     TLayoutQualifier layoutQualifier = type->getLayoutQualifier();
62     layoutQualifier.matrixPacking    = EmpColumnMajor;
63     type->setLayoutQualifier(layoutQualifier);
64 }
65 
// Returns a new column-major copy of the matrix type |type| with its column and row counts
// swapped.  The primary size is set from the original row count and the secondary size from the
// original column count.  Everything else, including array sizes, comes from copying |type|.
TType *TransposeMatrixType(const TType *type)
{
    const auto originalRows = type->getRows();
    const auto originalCols = type->getCols();

    TType *transposed = new TType(*type);
    SetColumnMajor(transposed);
    transposed->setPrimarySize(originalRows);
    transposed->setSecondarySize(originalCols);
    return transposed;
}
77 
CopyArraySizes(const TType * from,TType * to)78 void CopyArraySizes(const TType *from, TType *to)
79 {
80     if (from->isArray())
81     {
82         to->makeArrays(from->getArraySizes());
83     }
84 }
85 
86 // Determine if the node is an index node (array index or struct field selection).  For the purposes
87 // of this transformation, swizzle nodes are considered index nodes too.
IsIndexNode(TIntermNode * node,TIntermNode * child)88 bool IsIndexNode(TIntermNode *node, TIntermNode *child)
89 {
90     if (node->getAsSwizzleNode())
91     {
92         return true;
93     }
94 
95     TIntermBinary *binaryNode = node->getAsBinaryNode();
96     if (binaryNode == nullptr || child != binaryNode->getLeft())
97     {
98         return false;
99     }
100 
101     TOperator op = binaryNode->getOp();
102 
103     return op == EOpIndexDirect || op == EOpIndexDirectInterfaceBlock ||
104            op == EOpIndexDirectStruct || op == EOpIndexIndirect;
105 }
106 
CopyToTempVariable(TSymbolTable * symbolTable,TIntermTyped * node,TIntermSequence * prependStatements)107 TIntermSymbol *CopyToTempVariable(TSymbolTable *symbolTable,
108                                   TIntermTyped *node,
109                                   TIntermSequence *prependStatements)
110 {
111     TVariable *temp              = CreateTempVariable(symbolTable, &node->getType(), EvqTemporary);
112     TIntermDeclaration *tempDecl = CreateTempInitDeclarationNode(temp, node);
113     prependStatements->push_back(tempDecl);
114 
115     return new TIntermSymbol(temp);
116 }
117 
CreateStructCopyCall(const TFunction * copyFunc,TIntermTyped * expression)118 TIntermAggregate *CreateStructCopyCall(const TFunction *copyFunc, TIntermTyped *expression)
119 {
120     TIntermSequence args = {expression};
121     return TIntermAggregate::CreateFunctionCall(*copyFunc, &args);
122 }
123 
CreateTransposeCall(TSymbolTable * symbolTable,TIntermTyped * expression)124 TIntermTyped *CreateTransposeCall(TSymbolTable *symbolTable, TIntermTyped *expression)
125 {
126     TIntermSequence args = {expression};
127     return CreateBuiltInFunctionCallNode("transpose", &args, *symbolTable, 300);
128 }
129 
GetIndex(TSymbolTable * symbolTable,TIntermNode * node,TIntermSequence * indices,TIntermSequence * prependStatements)130 TOperator GetIndex(TSymbolTable *symbolTable,
131                    TIntermNode *node,
132                    TIntermSequence *indices,
133                    TIntermSequence *prependStatements)
134 {
135     // Swizzle nodes are converted EOpIndexDirect for simplicity, with one index per swizzle
136     // channel.
137     TIntermSwizzle *asSwizzle = node->getAsSwizzleNode();
138     if (asSwizzle)
139     {
140         for (int channel : asSwizzle->getSwizzleOffsets())
141         {
142             indices->push_back(CreateIndexNode(channel));
143         }
144         return EOpIndexDirect;
145     }
146 
147     TIntermBinary *binaryNode = node->getAsBinaryNode();
148     ASSERT(binaryNode);
149 
150     TOperator op = binaryNode->getOp();
151     ASSERT(op == EOpIndexDirect || op == EOpIndexDirectInterfaceBlock ||
152            op == EOpIndexDirectStruct || op == EOpIndexIndirect);
153 
154     TIntermTyped *rhs = binaryNode->getRight()->deepCopy();
155     if (rhs->getAsConstantUnion() == nullptr)
156     {
157         rhs = CopyToTempVariable(symbolTable, rhs, prependStatements);
158     }
159 
160     indices->push_back(rhs);
161     return op;
162 }
163 
ReplicateIndexNode(TSymbolTable * symbolTable,TIntermNode * node,TIntermTyped * lhs,TIntermSequence * indices)164 TIntermTyped *ReplicateIndexNode(TSymbolTable *symbolTable,
165                                  TIntermNode *node,
166                                  TIntermTyped *lhs,
167                                  TIntermSequence *indices)
168 {
169     TIntermSwizzle *asSwizzle = node->getAsSwizzleNode();
170     if (asSwizzle)
171     {
172         return new TIntermSwizzle(lhs, asSwizzle->getSwizzleOffsets());
173     }
174 
175     TIntermBinary *binaryNode = node->getAsBinaryNode();
176     ASSERT(binaryNode);
177 
178     ASSERT(indices->size() == 1);
179     TIntermTyped *rhs = indices->front()->getAsTyped();
180 
181     return new TIntermBinary(binaryNode->getOp(), lhs, rhs);
182 }
183 
GetIndexOp(TIntermNode * node)184 TOperator GetIndexOp(TIntermNode *node)
185 {
186     return node->getAsConstantUnion() ? EOpIndexDirect : EOpIndexIndirect;
187 }
188 
IsConvertedField(TIntermTyped * indexNode,const angle::HashMap<const TField *,bool> & convertedFields)189 bool IsConvertedField(TIntermTyped *indexNode,
190                       const angle::HashMap<const TField *, bool> &convertedFields)
191 {
192     TIntermBinary *asBinary = indexNode->getAsBinaryNode();
193     if (asBinary == nullptr)
194     {
195         return false;
196     }
197 
198     if (asBinary->getOp() != EOpIndexDirectInterfaceBlock)
199     {
200         return false;
201     }
202 
203     const TInterfaceBlock *interfaceBlock = asBinary->getLeft()->getType().getInterfaceBlock();
204     ASSERT(interfaceBlock);
205 
206     TIntermConstantUnion *fieldIndexNode = asBinary->getRight()->getAsConstantUnion();
207     ASSERT(fieldIndexNode);
208     ASSERT(fieldIndexNode->getConstantValue() != nullptr);
209 
210     int fieldIndex      = fieldIndexNode->getConstantValue()->getIConst();
211     const TField *field = interfaceBlock->fields()[fieldIndex];
212 
213     return convertedFields.count(field) > 0 && convertedFields.at(field);
214 }
215 
216 // A helper class to transform expressions of array type.  Iterates over every element of the
217 // array.
218 class TransformArrayHelper
219 {
220   public:
TransformArrayHelper(TIntermTyped * baseExpression)221     TransformArrayHelper(TIntermTyped *baseExpression)
222         : mBaseExpression(baseExpression),
223           mBaseExpressionType(baseExpression->getType()),
224           mArrayIndices(mBaseExpressionType.getArraySizes().size(), 0)
225     {}
226 
getNextElement(TIntermTyped * valueExpression,TIntermTyped ** valueElementOut)227     TIntermTyped *getNextElement(TIntermTyped *valueExpression, TIntermTyped **valueElementOut)
228     {
229         const angle::Span<const unsigned int> &arraySizes = mBaseExpressionType.getArraySizes();
230 
231         // If the last index overflows, element enumeration is done.
232         if (mArrayIndices.back() >= arraySizes.back())
233         {
234             return nullptr;
235         }
236 
237         TIntermTyped *element = getCurrentElement(mBaseExpression);
238         if (valueExpression)
239         {
240             *valueElementOut = getCurrentElement(valueExpression);
241         }
242 
243         incrementIndices(arraySizes);
244         return element;
245     }
246 
accumulateForRead(TSymbolTable * symbolTable,TIntermTyped * transformedElement,TIntermSequence * prependStatements)247     void accumulateForRead(TSymbolTable *symbolTable,
248                            TIntermTyped *transformedElement,
249                            TIntermSequence *prependStatements)
250     {
251         TIntermTyped *temp = CopyToTempVariable(symbolTable, transformedElement, prependStatements);
252         mReadTransformConstructorArgs.push_back(temp);
253     }
254 
constructReadTransformExpression()255     TIntermTyped *constructReadTransformExpression()
256     {
257         const angle::Span<const unsigned int> &baseTypeArraySizes =
258             mBaseExpressionType.getArraySizes();
259         TVector<unsigned int> arraySizes(baseTypeArraySizes.begin(), baseTypeArraySizes.end());
260         TIntermTyped *firstElement = mReadTransformConstructorArgs.front()->getAsTyped();
261         const TType &baseType      = firstElement->getType();
262 
263         // If N dimensions, acc[0] == size[0] and acc[i] == size[i] * acc[i-1].
264         // The last value is unused, and is not present.
265         TVector<unsigned int> accumulatedArraySizes(arraySizes.size() - 1);
266 
267         if (accumulatedArraySizes.size() > 0)
268         {
269             accumulatedArraySizes[0] = arraySizes[0];
270         }
271         for (size_t index = 1; index + 1 < arraySizes.size(); ++index)
272         {
273             accumulatedArraySizes[index] = accumulatedArraySizes[index - 1] * arraySizes[index];
274         }
275 
276         return constructReadTransformExpressionHelper(arraySizes, accumulatedArraySizes, baseType,
277                                                       0);
278     }
279 
280   private:
getCurrentElement(TIntermTyped * expression)281     TIntermTyped *getCurrentElement(TIntermTyped *expression)
282     {
283         TIntermTyped *element = expression->deepCopy();
284         for (auto it = mArrayIndices.rbegin(); it != mArrayIndices.rend(); ++it)
285         {
286             unsigned int index = *it;
287             element            = new TIntermBinary(EOpIndexDirect, element, CreateIndexNode(index));
288         }
289         return element;
290     }
291 
incrementIndices(const angle::Span<const unsigned int> & arraySizes)292     void incrementIndices(const angle::Span<const unsigned int> &arraySizes)
293     {
294         // Assume mArrayIndices is an N digit number, where digit i is in the range
295         // [0, arraySizes[i]).  This function increments this number.  Last digit is the most
296         // significant digit.
297         for (size_t digitIndex = 0; digitIndex < arraySizes.size(); ++digitIndex)
298         {
299             ++mArrayIndices[digitIndex];
300             if (mArrayIndices[digitIndex] < arraySizes[digitIndex])
301             {
302                 break;
303             }
304             if (digitIndex + 1 != arraySizes.size())
305             {
306                 // This digit has now overflown and is reset to 0, carry will be added to the next
307                 // digit.  The most significant digit will keep the overflow though, to make it
308                 // clear we have exhausted the range.
309                 mArrayIndices[digitIndex] = 0;
310             }
311         }
312     }
313 
constructReadTransformExpressionHelper(const TVector<unsigned int> & arraySizes,const TVector<unsigned int> & accumulatedArraySizes,const TType & baseType,size_t elementsOffset)314     TIntermTyped *constructReadTransformExpressionHelper(
315         const TVector<unsigned int> &arraySizes,
316         const TVector<unsigned int> &accumulatedArraySizes,
317         const TType &baseType,
318         size_t elementsOffset)
319     {
320         ASSERT(!arraySizes.empty());
321 
322         TType *transformedType = new TType(baseType);
323         transformedType->makeArrays(arraySizes);
324 
325         // If one dimensional, create the constructor with the given elements.
326         if (arraySizes.size() == 1)
327         {
328             ASSERT(accumulatedArraySizes.size() == 0);
329 
330             auto sliceStart = mReadTransformConstructorArgs.begin() + elementsOffset;
331             TIntermSequence slice(sliceStart, sliceStart + arraySizes[0]);
332 
333             return TIntermAggregate::CreateConstructor(*transformedType, &slice);
334         }
335 
336         // If not, create constructors for every column recursively.
337         TVector<unsigned int> subArraySizes(arraySizes.begin(), arraySizes.end() - 1);
338         TVector<unsigned int> subArrayAccumulatedSizes(accumulatedArraySizes.begin(),
339                                                        accumulatedArraySizes.end() - 1);
340 
341         TIntermSequence constructorArgs;
342         unsigned int colStride = accumulatedArraySizes.back();
343         for (size_t col = 0; col < arraySizes.back(); ++col)
344         {
345             size_t colElementsOffset = elementsOffset + col * colStride;
346 
347             constructorArgs.push_back(constructReadTransformExpressionHelper(
348                 subArraySizes, subArrayAccumulatedSizes, baseType, colElementsOffset));
349         }
350 
351         return TIntermAggregate::CreateConstructor(*transformedType, &constructorArgs);
352     }
353 
354     TIntermTyped *mBaseExpression;
355     const TType &mBaseExpressionType;
356     TVector<unsigned int> mArrayIndices;
357 
358     TIntermSequence mReadTransformConstructorArgs;
359 };
360 
361 // Traverser that:
362 //
363 // 1. Converts |layout(row_major) matCxR M| to |layout(column_major) matRxC Mt|.
364 // 2. Converts |layout(row_major) S s| to |layout(column_major) St st|, where S is a struct that
365 //    contains matrices, and St is a new struct with the transformation in 1 applied to matrix
366 //    members (recursively).
367 // 3. When read from, the following transformations are applied:
368 //
369 //            M       -> transpose(Mt)
370 //            M[c]    -> gvecN(Mt[0][c], Mt[1][c], ..., Mt[N-1][c])
371 //            M[c][r] -> Mt[r][c]
372 //            M[c].yz -> gvec2(Mt[1][c], Mt[2][c])
373 //            MArr    -> MType[D1]..[DN](transpose(MtArr[0]...[0]), ...)
374 //            s       -> copy_St_to_S(st)
375 //            sArr    -> SType[D1]...[DN](copy_St_to_S(stArr[0]..[0]), ...)
376 //            (matrix reads through struct are transformed similarly to M)
377 //
378 // 4. When written to, the following transformations are applied:
379 //
380 //      M = exp       -> Mt = transpose(exp)
381 //      M[c] = exp    -> temp = exp
382 //                       Mt[0][c] = temp[0]
383 //                       Mt[1][c] = temp[1]
384 //                       ...
385 //                       Mt[N-1][c] = temp[N-1]
386 //      M[c][r] = exp -> Mt[r][c] = exp
387 //      M[c].yz = exp -> temp = exp
388 //                       Mt[1][c] = temp[0]
389 //                       Mt[2][c] = temp[1]
390 //      MArr = exp    -> temp = exp
391 //                       MtArr = MtType[D1]...[DN](transpose(temp[0]...[0]), ...)
392 //      s = exp       -> st = copy_S_to_St(exp)
393 //      sArr = exp    -> temp = exp
394 //                       stArr = StType[D1]...[DN](copy_S_to_St(temp[0]...[0]), ...)
395 //      (matrix writes through struct are transformed similarly to M)
396 //
397 // 5. If any of the above is passed to an `inout` parameter, both transformations are applied:
398 //
399 //            f(M[c]) -> temp = gvecN(Mt[0][c], Mt[1][c], ..., Mt[N-1][c])
400 //                       f(temp)
401 //                       Mt[0][c] = temp[0]
402 //                       Mt[1][c] = temp[1]
403 //                       ...
404 //                       Mt[N-1][c] = temp[N-1]
405 //
406 //               f(s) -> temp = copy_St_to_S(st)
407 //                       f(temp)
408 //                       st = copy_S_to_St(temp)
409 //
410 //    If passed to an `out` parameter, the `temp` parameter is simply not initialized.
411 //
412 // 6. If the expression leading to the matrix or struct has array subscripts, temp values are
413 //    created for them to avoid duplicating side effects.
414 //
415 class RewriteRowMajorMatricesTraverser : public TIntermTraverser
416 {
417   public:
// Constructor for the outer pass.
//  - The shared conversion state lives in this traverser's own mOuterPass: the struct map,
//    interface block map, converted-field set and generated copy functions.
//  - The *Out / *In members point at (or refer to) that state.
//  - This traverser has no outer traverser and no inner-pass root.
// The base class is constructed with (true, true, true, symbolTable); presumably that enables
// pre-, in- and post-visits (TODO confirm against TIntermTraverser).
// NOTE(review): keep this initializer list in member-declaration order.
RewriteRowMajorMatricesTraverser(TCompiler *compiler, TSymbolTable *symbolTable)
    : TIntermTraverser(true, true, true, symbolTable),
      mCompiler(compiler),
      mStructMapOut(&mOuterPass.structMap),
      mInterfaceBlockMap(&mOuterPass.interfaceBlockMap),
      mInterfaceBlockFieldConvertedIn(mOuterPass.interfaceBlockFieldConverted),
      mCopyFunctionDefinitionsOut(&mOuterPass.copyFunctionDefinitions),
      mOuterTraverser(nullptr),
      mInnerPassRoot(nullptr),
      mIsProcessingInnerPassSubtree(false)
{}
429 
visitDeclaration(Visit visit,TIntermDeclaration * node)430     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
431     {
432         // No need to process declarations in inner passes.
433         if (mInnerPassRoot != nullptr)
434         {
435             return true;
436         }
437 
438         if (visit != PreVisit)
439         {
440             return true;
441         }
442 
443         const TIntermSequence &sequence = *(node->getSequence());
444 
445         TIntermTyped *variable = sequence.front()->getAsTyped();
446         const TType &type      = variable->getType();
447 
448         // If it's a struct declaration that has matrices, remember it.  If a row-major instance
449         // of it is created, it will have to be converted.
450         if (type.isStructSpecifier() && type.isStructureContainingMatrices())
451         {
452             const TStructure *structure = type.getStruct();
453             ASSERT(structure);
454 
455             ASSERT(mOuterPass.structMap.count(structure) == 0);
456 
457             StructConversionData structData;
458             mOuterPass.structMap[structure] = structData;
459 
460             return false;
461         }
462 
463         // If it's an interface block, it may have to be converted if it contains any row-major
464         // fields.
465         if (type.isInterfaceBlock() && type.getInterfaceBlock()->containsMatrices())
466         {
467             const TInterfaceBlock *block = type.getInterfaceBlock();
468             ASSERT(block);
469             bool isBlockRowMajor = type.getLayoutQualifier().matrixPacking == EmpRowMajor;
470 
471             const TFieldList &fields = block->fields();
472             bool anyRowMajor         = isBlockRowMajor;
473 
474             for (const TField *field : fields)
475             {
476                 if (DoesFieldContainRowMajorMatrix(field, isBlockRowMajor))
477                 {
478                     anyRowMajor = true;
479                     break;
480                 }
481             }
482 
483             if (anyRowMajor)
484             {
485                 convertInterfaceBlock(node);
486             }
487 
488             return false;
489         }
490 
491         return true;
492     }
493 
visitSymbol(TIntermSymbol * symbol)494     void visitSymbol(TIntermSymbol *symbol) override
495     {
496         // If in inner pass, only process if the symbol is under that root.
497         if (mInnerPassRoot != nullptr && !mIsProcessingInnerPassSubtree)
498         {
499             return;
500         }
501 
502         const TVariable *variable = &symbol->variable();
503         bool needsRewrite         = mInterfaceBlockMap->count(variable) != 0;
504 
505         // If it's a field of a nameless interface block, it may still need conversion.
506         if (!needsRewrite)
507         {
508             // Nameless interface block field symbols have the interface block pointer set, but are
509             // not interface blocks.
510             if (symbol->getType().getInterfaceBlock() && !variable->getType().isInterfaceBlock())
511             {
512                 needsRewrite = convertNamelessInterfaceBlockField(symbol);
513             }
514         }
515 
516         if (needsRewrite)
517         {
518             transformExpression(symbol);
519         }
520     }
521 
visitBinary(Visit visit,TIntermBinary * node)522     bool visitBinary(Visit visit, TIntermBinary *node) override
523     {
524         if (node == mInnerPassRoot)
525         {
526             // We only want to process the right-hand side of an assignment in inner passes.  When
527             // visit is InVisit, the left-hand side is already processed, and the right-hand side is
528             // next.  Set a flag to mark this duration.
529             mIsProcessingInnerPassSubtree = visit == InVisit;
530         }
531 
532         return true;
533     }
534 
getStructCopyFunctions()535     TIntermSequence *getStructCopyFunctions() { return &mOuterPass.copyFunctionDefinitions; }
536 
537   private:
538     typedef angle::HashMap<const TStructure *, StructConversionData> StructMap;
539     typedef angle::HashMap<const TVariable *, TVariable *> InterfaceBlockMap;
540     typedef angle::HashMap<const TField *, bool> InterfaceBlockFieldConverted;
541 
// Constructor for an inner pass.
//  - The pass processes only the right-hand side of the assignment |innerPassRoot|.
//  - It shares the outer pass's state through the given pointers and reference.
// mCompiler is taken from |outerTraverser|.  The previous version left it out of the initializer
// list, unlike the public constructor, so it could end up indeterminate in inner passes.
// NOTE(review): keep this initializer list in member-declaration order (same order as the public
// constructor).
RewriteRowMajorMatricesTraverser(
    TSymbolTable *symbolTable,
    RewriteRowMajorMatricesTraverser *outerTraverser,
    InterfaceBlockMap *interfaceBlockMap,
    const InterfaceBlockFieldConverted &interfaceBlockFieldConverted,
    StructMap *structMap,
    TIntermSequence *copyFunctionDefinitions,
    TIntermBinary *innerPassRoot)
    : TIntermTraverser(true, true, true, symbolTable),
      mCompiler(outerTraverser != nullptr ? outerTraverser->mCompiler : nullptr),
      mStructMapOut(structMap),
      mInterfaceBlockMap(interfaceBlockMap),
      mInterfaceBlockFieldConvertedIn(interfaceBlockFieldConverted),
      mCopyFunctionDefinitionsOut(copyFunctionDefinitions),
      mOuterTraverser(outerTraverser),
      mInnerPassRoot(innerPassRoot),
      mIsProcessingInnerPassSubtree(false)
{}
559 
// Rewrites the interface block declared by |node|.  Outer pass only.
//  - Every field for which DoesFieldContainRowMajorMatrix() holds is replaced via convertField()
//    and recorded in mOuterPass.interfaceBlockFieldConverted.
//  - All other fields are duplicated unchanged.
//  - The new block's layout qualifier is the original's with matrixPacking set to column_major.
//  - A replacement block variable is created and recorded in mOuterPass.interfaceBlockMap.
//  - |node| is scheduled for replacement (via mMultiReplacements) by any struct declarations
//    produced during field conversion, followed by the new block declaration.
// NOTE(review): several calls below pass mSymbolTable to TInterfaceBlock/TVariable constructors,
// which presumably allocate symbol ids in order; keep the statement order as is.
void convertInterfaceBlock(TIntermDeclaration *node)
{
    ASSERT(mInnerPassRoot == nullptr);

    const TIntermSequence &sequence = *(node->getSequence());

    TIntermTyped *variableNode   = sequence.front()->getAsTyped();
    const TType &type            = variableNode->getType();
    const TInterfaceBlock *block = type.getInterfaceBlock();
    ASSERT(block);

    bool isBlockRowMajor = type.getLayoutQualifier().matrixPacking == EmpRowMajor;

    // Recreate the interface block's field list with its row-major fields converted to
    // column-major equivalents.  |newDeclarations| also collects the declarations of any structs
    // converted along the way, so they end up before the new block declaration.
    TIntermSequence newDeclarations;

    TFieldList *newFields = new TFieldList;
    for (const TField *field : block->fields())
    {
        TField *newField = nullptr;

        if (DoesFieldContainRowMajorMatrix(field, isBlockRowMajor))
        {
            newField = convertField(field, &newDeclarations);

            // Remember that this field was converted.
            mOuterPass.interfaceBlockFieldConverted[field] = true;
        }
        else
        {
            newField = DuplicateField(field);
        }

        newFields->push_back(newField);
    }

    // Create a new interface block with these fields.  The whole block is now column-major.
    TLayoutQualifier blockLayoutQualifier = type.getLayoutQualifier();
    blockLayoutQualifier.matrixPacking    = EmpColumnMajor;

    TInterfaceBlock *newInterfaceBlock =
        new TInterfaceBlock(mSymbolTable, block->name(), newFields, blockLayoutQualifier,
                            block->symbolType(), block->extensions());

    // Create a new declaration with the new type.  Declarations are separated at this point,
    // so there should be only one variable here.
    ASSERT(sequence.size() == 1);

    TType *newInterfaceBlockType =
        new TType(newInterfaceBlock, type.getQualifier(), blockLayoutQualifier);

    TIntermDeclaration *newDeclaration = new TIntermDeclaration;
    const TVariable *variable          = &variableNode->getAsSymbolNode()->variable();

    // Arrays of blocks keep their original array sizes.
    const TType *newType = newInterfaceBlockType;
    if (type.isArray())
    {
        TType *newArrayType = new TType(*newType);
        CopyArraySizes(&type, newArrayType);
        newType = newArrayType;
    }

    // If the interface block variable itself is temp, use an empty name.
    bool variableIsTemp = variable->symbolType() == SymbolType::Empty;
    const ImmutableString &variableName =
        variableIsTemp ? kEmptyImmutableString : variable->name();

    TVariable *newVariable = new TVariable(mSymbolTable, variableName, newType,
                                           variable->symbolType(), variable->extensions());

    newDeclaration->appendDeclarator(new TIntermSymbol(newVariable));

    // Later references to |variable| are redirected to |newVariable| through this map.
    mOuterPass.interfaceBlockMap[variable] = newVariable;

    newDeclarations.push_back(newDeclaration);

    // Replace the interface block definition with the new one, prepending any new struct
    // definitions.
    mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
                                    std::move(newDeclarations));
}
641 
convertNamelessInterfaceBlockField(TIntermSymbol * symbol)642     bool convertNamelessInterfaceBlockField(TIntermSymbol *symbol)
643     {
644         const TVariable *variable             = &symbol->variable();
645         const TInterfaceBlock *interfaceBlock = symbol->getType().getInterfaceBlock();
646 
647         // Find the variable corresponding to this interface block.  If the interface block
648         // is not rewritten, or this refers to a field that is not rewritten, there's
649         // nothing to do.
650         for (auto iter : *mInterfaceBlockMap)
651         {
652             // Skip other rewritten nameless interface block fields.
653             if (!iter.first->getType().isInterfaceBlock())
654             {
655                 continue;
656             }
657 
658             // Skip if this is not a field of this rewritten interface block.
659             if (iter.first->getType().getInterfaceBlock() != interfaceBlock)
660             {
661                 continue;
662             }
663 
664             const ImmutableString symbolName = symbol->getName();
665 
666             // Find which field it is
667             const TVector<TField *> fields = interfaceBlock->fields();
668             const size_t fieldIndex        = variable->getType().getInterfaceBlockFieldIndex();
669             ASSERT(fieldIndex < fields.size());
670 
671             const TField *field = fields[fieldIndex];
672             ASSERT(field->name() == symbolName);
673 
674             // If this field doesn't need a rewrite, there's nothing to do.
675             if (mInterfaceBlockFieldConvertedIn.count(field) == 0 ||
676                 !mInterfaceBlockFieldConvertedIn.at(field))
677             {
678                 break;
679             }
680 
681             // Create a new variable that references the replaced interface block.
682             TType *newType = new TType(variable->getType());
683             newType->setInterfaceBlockField(iter.second->getType().getInterfaceBlock(), fieldIndex);
684 
685             TVariable *newVariable = new TVariable(mSymbolTable, variable->name(), newType,
686                                                    variable->symbolType(), variable->extensions());
687 
688             (*mInterfaceBlockMap)[variable] = newVariable;
689 
690             return true;
691         }
692 
693         return false;
694     }
695 
// Creates the converted counterpart of |structure|, once.  Outer pass only.
//  - Every field goes through convertField(): matrices are transposed and made column-major, and
//    nested structs with matrices are converted recursively.
//  - The new struct is declared under an internal, unique name.
//  - Its declaration is appended to |newDeclarations| after any declarations added by the
//    recursive conversions.
//  - The result is cached in mOuterPass.structMap, so repeated calls return early.
// |structure| must already be tracked in mOuterPass.structMap (see visitDeclaration()).
void convertStruct(const TStructure *structure, TIntermSequence *newDeclarations)
{
    ASSERT(mInnerPassRoot == nullptr);

    ASSERT(mOuterPass.structMap.count(structure) != 0);
    // NOTE(review): this pointer into the map is held across the recursive convertField() calls
    // below.  That relies on those calls not inserting new keys; each nested convertStruct()
    // asserts its key already exists.
    StructConversionData *structData = &mOuterPass.structMap[structure];

    // Already converted; nothing more to do.
    if (structData->convertedStruct)
    {
        return;
    }

    TFieldList *newFields = new TFieldList;
    for (const TField *field : structure->fields())
    {
        newFields->push_back(convertField(field, newDeclarations));
    }

    // Create unique names for the converted structs.  We can't leave them nameless and have
    // a name autogenerated similar to temp variables, as nameless structs exist.  A fake
    // variable is created for the sole purpose of generating a temp name.
    TVariable *newStructTypeName =
        new TVariable(mSymbolTable, kEmptyImmutableString,
                      StaticType::GetBasic<EbtUInt, EbpUndefined>(), SymbolType::Empty);

    TStructure *newStruct = new TStructure(mSymbolTable, newStructTypeName->name(), newFields,
                                           SymbolType::AngleInternal);
    // The |true| argument presumably marks this type as the struct specifier (TODO confirm).
    TType *newType        = new TType(newStruct, true);
    TVariable *newStructVar =
        new TVariable(mSymbolTable, kEmptyImmutableString, newType, SymbolType::Empty);

    TIntermDeclaration *structDecl = new TIntermDeclaration;
    structDecl->appendDeclarator(new TIntermSymbol(newStructVar));

    newDeclarations->push_back(structDecl);

    structData->convertedStruct = newStruct;
}
734 
// Returns the counterpart of |field| for a converted (column-major) struct or block:
// - a struct-containing-matrices field becomes an instance of that struct's converted version
//   (converted on demand, possibly appending declarations to |newDeclarations|).  The new type
//   is marked column-major and keeps the field's array sizes;
// - a matrix field gets its transposed matrix type (TransposeMatrixType);
// - any other field is duplicated unchanged (DuplicateField).
TField *convertField(const TField *field, TIntermSequence *newDeclarations)
{
    ASSERT(mInnerPassRoot == nullptr);

    const TType *originalType = field->type();
    TType *convertedType      = nullptr;

    const bool holdsMatrixStruct = originalType->isStructureContainingMatrices();
    if (holdsMatrixStruct)
    {
        const TStructure *innerStruct = originalType->getStruct();
        convertStruct(innerStruct, newDeclarations);

        TStructure *convertedInner = mOuterPass.structMap[innerStruct].convertedStruct;
        convertedType              = new TType(convertedInner, false);
        SetColumnMajor(convertedType);
        CopyArraySizes(originalType, convertedType);
    }
    else if (originalType->isMatrix())
    {
        // Transposing also drops the matrix packing qualifier.
        convertedType = TransposeMatrixType(originalType);
    }

    return convertedType != nullptr
               ? new TField(convertedType, field->name(), field->line(), field->symbolType())
               : DuplicateField(field);
}
774 
determineAccess(TIntermNode * expression,TIntermNode * accessor,bool * isReadOut,bool * isWriteOut)775     void determineAccess(TIntermNode *expression,
776                          TIntermNode *accessor,
777                          bool *isReadOut,
778                          bool *isWriteOut)
779     {
780         // If passing to a function, look at whether the parameter is in, out or inout.
781         TIntermAggregate *functionCall = accessor->getAsAggregate();
782 
783         if (functionCall)
784         {
785             TIntermSequence *arguments = functionCall->getSequence();
786             for (size_t argIndex = 0; argIndex < arguments->size(); ++argIndex)
787             {
788                 if ((*arguments)[argIndex] == expression)
789                 {
790                     TQualifier qualifier = EvqParamIn;
791 
792                     // If the aggregate is not a function call, it's a constructor, and so every
793                     // argument is an input.
794                     const TFunction *function = functionCall->getFunction();
795                     if (function)
796                     {
797                         const TVariable *param = function->getParam(argIndex);
798                         qualifier              = param->getType().getQualifier();
799                     }
800 
801                     *isReadOut  = qualifier != EvqParamOut;
802                     *isWriteOut = qualifier == EvqParamOut || qualifier == EvqParamInOut;
803                     break;
804                 }
805             }
806             return;
807         }
808 
809         TIntermBinary *assignment = accessor->getAsBinaryNode();
810         if (assignment && IsAssignment(assignment->getOp()))
811         {
812             // If expression is on the right of assignment, it's being read from.
813             *isReadOut = assignment->getRight() == expression;
814             // If it's on the left of assignment, it's being written to.
815             *isWriteOut = assignment->getLeft() == expression;
816             return;
817         }
818 
819         // Any other usage is a read.
820         *isReadOut  = true;
821         *isWriteOut = false;
822     }
823 
// Rewrites one reference |symbol| to a converted interface block, together with the chain of
// index / field-selection / swizzle ancestors above it, so that it uses the converted block
// variable looked up in mInterfaceBlockMap.  Reads are rewritten through
// transformReadExpression() (transpose() / copy-to-original calls).  Writes are turned into
// separate statements through transformWriteExpression().  Any helper statements are inserted
// in the parent block of the outermost traverser.
transformExpression(TIntermSymbol * symbol)824     void transformExpression(TIntermSymbol *symbol)
825     {
826         // Walk up the parent chain while the nodes are EOpIndex* (whether array indexing or struct
827         // field selection) or swizzle and construct the replacement expression.  This traversal can
828         // lead to one of the following possibilities:
829         //
830         // - a.b[N].etc.s (struct, or struct array): copy function should be declared and used,
831         // - a.b[N].etc.M (matrix or matrix array): transpose() should be used,
832         // - a.b[N].etc.M[c] (a column): each element in column needs to be handled separately,
833         // - a.b[N].etc.M[c].yz (multiple elements): similar to whole column, but a subset of
834         //   elements,
835         // - a.b[N].etc.M[c][r] (an element): single element to handle.
836         // - a.b[N].etc.x (not struct or matrix): not modified
837         //
838         // primaryIndex will contain c, if any.  secondaryIndices will contain {0, ..., R-1}
839         // (if no [r] or swizzle), {r} (if [r]), or {1, 2} (corresponding to .yz) if any.
840         //
841         // In all cases, the base symbol is replaced.  |baseExpression| will contain everything up
842         // to (and not including) the last index/swizzle operations, i.e. a.b[N].etc.s/M/x.  Any
843         // non constant array subscript is assigned to a temp variable to avoid duplicating side
844         // effects.
845         //
846         // ---
847         //
848         // NOTE that due to the use of insertStatementsInParentBlock, cases like this will be
849         // mistranslated, and this bug is likely present in most transformations that use this
850         // feature:
851         //
852         //     if (x == 1 && a.b[x = 2].etc.M = value)
853         //
854         // which will translate to:
855         //
856         //     temp = (x = 2)
857         //     if (x == 1 && a.b[temp].etc.M = transpose(value))
858         //
859         // See http://anglebug.com/42262472.
860         //
861         TIntermTyped *baseExpression =
862             new TIntermSymbol(mInterfaceBlockMap->at(&symbol->variable()));
863         const TStructure *structure = nullptr;
864 
865         TIntermNode *primaryIndex = nullptr;
866         TIntermSequence secondaryIndices;
867 
868         // In some cases, it is necessary to prepend or append statements.  Those are captured in
869         // |prependStatements| and |appendStatements|.
870         TIntermSequence prependStatements;
871         TIntermSequence appendStatements;
872 
873         // If the expression is neither a struct nor a matrix, no modification is necessary.
874         // If it's a struct that doesn't have matrices, again there's no transformation necessary.
875         // If it's an interface block matrix field that didn't need to be transposed, no
876         // transformation is necessary.
877         //
878         // In all these cases, |baseExpression| contains all of the original expression.
879         //
880         // If the starting symbol itself is a field of a nameless interface block, it needs
881         // conversion if we reach here.
882         bool requiresTransformation = !symbol->getType().isInterfaceBlock();
883 
884         uint32_t accessorIndex         = 0;
885         TIntermTyped *previousAncestor = symbol;
886         while (IsIndexNode(getAncestorNode(accessorIndex), previousAncestor))
887         {
888             TIntermTyped *ancestor = getAncestorNode(accessorIndex)->getAsTyped();
889             ASSERT(ancestor);
890 
891             const TType &previousAncestorType = previousAncestor->getType();
892 
893             TIntermSequence indices;
894             TOperator op = GetIndex(mSymbolTable, ancestor, &indices, &prependStatements);
895 
896             bool opIsIndex     = op == EOpIndexDirect || op == EOpIndexIndirect;
897             bool isArrayIndex  = opIsIndex && previousAncestorType.isArray();
898             bool isMatrixIndex = opIsIndex && previousAncestorType.isMatrix();
899 
900             // If it's a direct index in a matrix, it's the primary index.
901             bool isMatrixPrimarySubscript = isMatrixIndex && !isArrayIndex;
902             ASSERT(!isMatrixPrimarySubscript ||
903                    (primaryIndex == nullptr && secondaryIndices.empty()));
904             // If primary index is seen and the ancestor is still an index, it must be a direct
905             // index as the secondary one.  Note that if primaryIndex is set, there can only ever be
906             // one more parent of interest, and that's subscripting the second dimension.
907             bool isMatrixSecondarySubscript = primaryIndex != nullptr;
908             ASSERT(!isMatrixSecondarySubscript || (opIsIndex && !isArrayIndex));
909 
910             if (requiresTransformation && isMatrixPrimarySubscript)
911             {
912                 ASSERT(indices.size() == 1);
913                 primaryIndex = indices.front();
914 
915                 // Default the secondary indices to include every row.  If there's a secondary
916                 // subscript provided, it will override this.
917                 const uint8_t rows = previousAncestorType.getRows();
918                 for (uint8_t r = 0; r < rows; ++r)
919                 {
920                     secondaryIndices.push_back(CreateIndexNode(r));
921                 }
922             }
923             else if (isMatrixSecondarySubscript)
924             {
925                 ASSERT(requiresTransformation);
926 
927                 secondaryIndices = indices;
928 
929                 // Indices after this point are not interesting.  There can't actually be any other
930                 // index nodes.  Unlike desktop GLSL, ESSL does not support swizzles on scalars
931                 // (like M[1][2].yyy).
932                 ++accessorIndex;
933                 break;
934             }
935             else
936             {
937                 // Replicate the expression otherwise.
938                 baseExpression =
939                     ReplicateIndexNode(mSymbolTable, ancestor, baseExpression, &indices);
940 
941                 const TType &ancestorType = ancestor->getType();
942                 structure                 = ancestorType.getStruct();
943 
944                 requiresTransformation =
945                     requiresTransformation ||
946                     IsConvertedField(ancestor, mInterfaceBlockFieldConvertedIn);
947 
948                 // If we reach a point where the expression is neither a matrix-containing struct
949                 // nor a matrix, there's no transformation required.  This can happen if we descend
950                 // through a struct marked with row-major but arrive at a member that doesn't
951                 // include a matrix.
952                 if (!ancestorType.isMatrix() && !ancestorType.isStructureContainingMatrices())
953                 {
954                     requiresTransformation = false;
955                 }
956             }
957 
958             previousAncestor = ancestor;
959             ++accessorIndex;
960         }
961 
962         TIntermNode *originalExpression =
963             accessorIndex == 0 ? symbol : getAncestorNode(accessorIndex - 1);
964         TIntermNode *accessor = getAncestorNode(accessorIndex);
965 
966         // if accessor is EOpArrayLength, we don't need to perform any transformations either.
967         // Note that this only applies to unsized arrays, as the RemoveArrayLengthMethod()
968         // transformation would have removed this operation otherwise.
969         TIntermUnary *accessorAsUnary = accessor->getAsUnaryNode();
970         if (requiresTransformation && accessorAsUnary && accessorAsUnary->getOp() == EOpArrayLength)
971         {
972             ASSERT(accessorAsUnary->getOperand() == originalExpression);
973             ASSERT(accessorAsUnary->getOperand()->getType().isUnsizedArray());
974 
975             requiresTransformation = false;
976 
977             // We need to replace the whole expression including the EOpArrayLength, to avoid
978             // confusing the replacement code as the original and new expressions don't have the
979             // same type (one is the transpose of the other).  This doesn't affect the .length()
980             // operation, so this replacement is ok, though it's not worth special-casing this in
981             // the node replacement algorithm.
982             //
983             // Note: the |if (!requiresTransformation)| immediately below will be entered after
984             // this.
985             originalExpression = accessor;
986             accessor           = getAncestorNode(accessorIndex + 1);
987             baseExpression     = new TIntermUnary(EOpArrayLength, baseExpression, nullptr);
988         }
989 
        // Untransformed case: only the base symbol changes (to the converted block), so the
        // rebuilt expression replaces the original one directly.
990         if (!requiresTransformation)
991         {
992             ASSERT(primaryIndex == nullptr);
993             queueReplacementWithParent(accessor, originalExpression, baseExpression,
994                                        OriginalNode::IS_DROPPED);
995 
996             RewriteRowMajorMatricesTraverser *traverser = mOuterTraverser ? mOuterTraverser : this;
997             traverser->insertStatementsInParentBlock(prependStatements, appendStatements);
998             return;
999         }
1000 
1001         ASSERT(structure == nullptr || primaryIndex == nullptr);
1002         ASSERT(structure != nullptr || baseExpression->getType().isMatrix());
1003 
1004         // At the end, we can determine if the expression is being read from or written to (or both,
1005         // if sent as an inout parameter to a function).  For the sake of the transformation, the
1006         // left-hand side of operations like += can be treated as "written to", without necessarily
1007         // "read from".
1008         bool isRead  = false;
1009         bool isWrite = false;
1010 
1011         determineAccess(originalExpression, accessor, &isRead, &isWrite);
1012 
1013         ASSERT(isRead || isWrite);
1014 
1015         TIntermTyped *readExpression = nullptr;
1016         if (isRead)
1017         {
1018             readExpression = transformReadExpression(
1019                 baseExpression, primaryIndex, &secondaryIndices, structure, &prependStatements);
1020 
1021             // If both read from and written to (i.e. passed to inout parameter), store the
1022             // expression in a temp variable and pass that to the function.
1023             if (isWrite)
1024             {
1025                 readExpression =
1026                     CopyToTempVariable(mSymbolTable, readExpression, &prependStatements);
1027             }
1028 
1029             // Replace the original expression with the transformed one.  Read transformations
1030             // always generate a single expression that can be used in place of the original (as
1031             // opposed to write transformations that can generate multiple statements).
1032             queueReplacementWithParent(accessor, originalExpression, readExpression,
1033                                        OriginalNode::IS_DROPPED);
1034         }
1035 
        // Write statements default to being appended (e.g. after a function call that writes an
        // out/inout argument); the assignment branch below redirects them to be prepended.
1036         TIntermSequence postTransformPrependStatements;
1037         TIntermSequence *writeStatements = &appendStatements;
1038         TOperator assignmentOperator     = EOpAssign;
1039 
1040         if (isWrite)
1041         {
1042             TIntermTyped *valueExpression = readExpression;
1043 
1044             if (!valueExpression)
1045             {
1046                 // If there's already a read expression, this was an inout parameter and
1047                 // |valueExpression| will contain the temp variable that was passed to the function
1048                 // instead.
1049                 //
1050                 // If not, then the modification is either through being passed as an out parameter
1051                 // to a function, or an assignment.  In the former case, create a temp variable to
1052                 // be passed to the function.  In the latter case, create a temp variable that holds
1053                 // the right hand side expression.
1054                 //
1055                 // In either case, use that temp value as the value to assign to |baseExpression|.
1056 
                // NOTE(review): |temp| gets the type of the written expression, but in the
                // assignment branch below it is initialized with the assignment's right-hand
                // side.  For compound assignments whose right-hand side has a different type
                // (e.g. M += 1.0, M *= s, non-square M *= N, or M[c] *= N) that declaration
                // would presumably not type-check.  TODO: confirm whether earlier passes rule
                // these out.
1057                 TVariable *temp = CreateTempVariable(
1058                     mSymbolTable, &originalExpression->getAsTyped()->getType(), EvqTemporary);
1059                 TIntermDeclaration *tempDecl = nullptr;
1060 
1061                 valueExpression = new TIntermSymbol(temp);
1062 
1063                 TIntermBinary *assignment = accessor->getAsBinaryNode();
1064                 if (assignment)
1065                 {
1066                     assignmentOperator = assignment->getOp();
1067                     ASSERT(IsAssignment(assignmentOperator));
1068 
1069                     // We are converting the assignment to the left-hand side of an expression in
1070                     // the form M=exp.  A subexpression of exp itself could require a
1071                     // transformation.  This complicates things as there would be two replacements:
1072                     //
1073                     // - Replace M=exp with temp (because the return value of the assignment could
1074                     //   be used)
1075                     // - Replace exp with exp2, where parent is M=exp
1076                     //
1077                     // The second replacement however is ineffective as the whole of M=exp is
1078                     // already transformed.  What's worse, M=exp is transformed without taking exp's
1079                     // transformations into account.  To address this issue, this same traverser is
1080                     // called on the right-hand side expression, with a special flag such that it
1081                     // only processes that expression.
1082                     //
1083                     RewriteRowMajorMatricesTraverser *outerTraverser =
1084                         mOuterTraverser ? mOuterTraverser : this;
1085                     RewriteRowMajorMatricesTraverser rhsTraverser(
1086                         mSymbolTable, outerTraverser, mInterfaceBlockMap,
1087                         mInterfaceBlockFieldConvertedIn, mStructMapOut, mCopyFunctionDefinitionsOut,
1088                         assignment);
1089                     getRootNode()->traverse(&rhsTraverser);
1090                     bool valid = rhsTraverser.updateTree(mCompiler, getRootNode());
1091                     ASSERT(valid);
1092 
1093                     tempDecl = CreateTempInitDeclarationNode(temp, assignment->getRight());
1094 
1095                     // Replace the whole assignment expression with the right-hand side as a read
1096                     // expression, in case the result of the assignment is used.  For example, this
1097                     // transforms:
1098                     //
1099                     //     if ((M += exp) == X)
1100                     //     {
1101                     //         // use M
1102                     //     }
1103                     //
1104                     // to:
1105                     //
1106                     //     temp = exp;
1107                     //     M += transform(temp);
1108                     //     if (transform(M) == X)
1109                     //     {
1110                     //         // use M
1111                     //     }
1112                     //
1113                     // Note that in this case the assignment to M must be prepended in the parent
1114                     // block.  In contrast, when sent to a function, the assignment to M should be
1115                     // done after the current function call is done.
1116                     //
1117                     // If the read from M itself (to replace assignment) needs to generate extra
1118                     // statements, they should be appended after the statements that write to M.
1119                     // These statements are stored in postTransformPrependStatements and appended to
1120                     // prependStatements in the end.
1121                     //
1122                     writeStatements = &prependStatements;
1123 
1124                     TIntermTyped *assignmentResultExpression = transformReadExpression(
1125                         baseExpression->deepCopy(), primaryIndex, &secondaryIndices, structure,
1126                         &postTransformPrependStatements);
1127 
1128                     // Replace the whole assignment, instead of just the right hand side.
1129                     TIntermNode *accessorParent = getAncestorNode(accessorIndex + 1);
1130                     queueReplacementWithParent(accessorParent, accessor, assignmentResultExpression,
1131                                                OriginalNode::IS_DROPPED);
1132                 }
1133                 else
1134                 {
1135                     tempDecl = CreateTempDeclarationNode(temp);
1136 
1137                     // Replace the write expression (a function call argument) with the temp
1138                     // variable.
1139                     queueReplacementWithParent(accessor, originalExpression, valueExpression,
1140                                                OriginalNode::IS_DROPPED);
1141                 }
1142                 prependStatements.push_back(tempDecl);
1143             }
1144 
            // |baseExpression| was already consumed by the read replacement above, so the write
            // statements need their own copy.
1145             if (isRead)
1146             {
1147                 baseExpression = baseExpression->deepCopy();
1148             }
1149             transformWriteExpression(baseExpression, primaryIndex, &secondaryIndices, structure,
1150                                      valueExpression, assignmentOperator, writeStatements);
1151         }
1152 
1153         prependStatements.insert(prependStatements.end(), postTransformPrependStatements.begin(),
1154                                  postTransformPrependStatements.end());
1155 
1156         RewriteRowMajorMatricesTraverser *traverser = mOuterTraverser ? mOuterTraverser : this;
1157         traverser->insertStatementsInParentBlock(prependStatements, appendStatements);
1158     }
1159 
transformReadExpression(TIntermTyped * baseExpression,TIntermNode * primaryIndex,TIntermSequence * secondaryIndices,const TStructure * structure,TIntermSequence * prependStatements)1160     TIntermTyped *transformReadExpression(TIntermTyped *baseExpression,
1161                                           TIntermNode *primaryIndex,
1162                                           TIntermSequence *secondaryIndices,
1163                                           const TStructure *structure,
1164                                           TIntermSequence *prependStatements)
1165     {
1166         const TType &baseExpressionType = baseExpression->getType();
1167 
1168         if (structure)
1169         {
1170             ASSERT(primaryIndex == nullptr && secondaryIndices->empty());
1171             ASSERT(mStructMapOut->count(structure) != 0);
1172             ASSERT((*mStructMapOut)[structure].convertedStruct != nullptr);
1173 
1174             // Declare copy-from-converted-to-original-struct function (if not already).
1175             declareStructCopyToOriginal(structure);
1176 
1177             const TFunction *copyToOriginal = (*mStructMapOut)[structure].copyToOriginal;
1178 
1179             if (baseExpressionType.isArray())
1180             {
1181                 // If base expression is an array, transform every element.
1182                 TransformArrayHelper transformHelper(baseExpression);
1183 
1184                 TIntermTyped *element = nullptr;
1185                 while ((element = transformHelper.getNextElement(nullptr, nullptr)) != nullptr)
1186                 {
1187                     TIntermTyped *transformedElement =
1188                         CreateStructCopyCall(copyToOriginal, element);
1189                     transformHelper.accumulateForRead(mSymbolTable, transformedElement,
1190                                                       prependStatements);
1191                 }
1192                 return transformHelper.constructReadTransformExpression();
1193             }
1194             else
1195             {
1196                 // If not reading an array, the result is simply a call to this function with the
1197                 // base expression.
1198                 return CreateStructCopyCall(copyToOriginal, baseExpression);
1199             }
1200         }
1201 
1202         // If not indexed, the result is transpose(exp)
1203         if (primaryIndex == nullptr)
1204         {
1205             ASSERT(secondaryIndices->empty());
1206 
1207             if (baseExpressionType.isArray())
1208             {
1209                 // If array, transpose every element.
1210                 TransformArrayHelper transformHelper(baseExpression);
1211 
1212                 TIntermTyped *element = nullptr;
1213                 while ((element = transformHelper.getNextElement(nullptr, nullptr)) != nullptr)
1214                 {
1215                     TIntermTyped *transformedElement = CreateTransposeCall(mSymbolTable, element);
1216                     transformHelper.accumulateForRead(mSymbolTable, transformedElement,
1217                                                       prependStatements);
1218                 }
1219                 return transformHelper.constructReadTransformExpression();
1220             }
1221             else
1222             {
1223                 return CreateTransposeCall(mSymbolTable, baseExpression);
1224             }
1225         }
1226 
1227         // If indexed the result is a vector (or just one element) where the primary and secondary
1228         // indices are swapped.
1229         ASSERT(!secondaryIndices->empty());
1230 
1231         TOperator primaryIndexOp          = GetIndexOp(primaryIndex);
1232         TIntermTyped *primaryIndexAsTyped = primaryIndex->getAsTyped();
1233 
1234         TIntermSequence transposedColumn;
1235         for (TIntermNode *secondaryIndex : *secondaryIndices)
1236         {
1237             TOperator secondaryIndexOp          = GetIndexOp(secondaryIndex);
1238             TIntermTyped *secondaryIndexAsTyped = secondaryIndex->getAsTyped();
1239 
1240             TIntermBinary *colIndexed = new TIntermBinary(
1241                 secondaryIndexOp, baseExpression->deepCopy(), secondaryIndexAsTyped->deepCopy());
1242             TIntermBinary *colRowIndexed =
1243                 new TIntermBinary(primaryIndexOp, colIndexed, primaryIndexAsTyped->deepCopy());
1244 
1245             transposedColumn.push_back(colRowIndexed);
1246         }
1247 
1248         if (secondaryIndices->size() == 1)
1249         {
1250             // If only one element, return that directly.
1251             return transposedColumn.front()->getAsTyped();
1252         }
1253 
1254         // Otherwise create a constructor with the appropriate dimension.
1255         TType *vecType = new TType(baseExpressionType.getBasicType(), secondaryIndices->size());
1256         return TIntermAggregate::CreateConstructor(*vecType, &transposedColumn);
1257     }
1258 
transformWriteExpression(TIntermTyped * baseExpression,TIntermNode * primaryIndex,TIntermSequence * secondaryIndices,const TStructure * structure,TIntermTyped * valueExpression,TOperator assignmentOperator,TIntermSequence * writeStatements)1259     void transformWriteExpression(TIntermTyped *baseExpression,
1260                                   TIntermNode *primaryIndex,
1261                                   TIntermSequence *secondaryIndices,
1262                                   const TStructure *structure,
1263                                   TIntermTyped *valueExpression,
1264                                   TOperator assignmentOperator,
1265                                   TIntermSequence *writeStatements)
1266     {
1267         const TType &baseExpressionType = baseExpression->getType();
1268 
1269         if (structure)
1270         {
1271             ASSERT(primaryIndex == nullptr && secondaryIndices->empty());
1272             ASSERT(mStructMapOut->count(structure) != 0);
1273             ASSERT((*mStructMapOut)[structure].convertedStruct != nullptr);
1274 
1275             // Declare copy-to-converted-from-original-struct function (if not already).
1276             declareStructCopyFromOriginal(structure);
1277 
1278             // The result is call to this function with the value expression assigned to base
1279             // expression.
1280             const TFunction *copyFromOriginal = (*mStructMapOut)[structure].copyFromOriginal;
1281 
1282             if (baseExpressionType.isArray())
1283             {
1284                 // If array, assign every element.
1285                 TransformArrayHelper transformHelper(baseExpression);
1286 
1287                 TIntermTyped *element      = nullptr;
1288                 TIntermTyped *valueElement = nullptr;
1289                 while ((element = transformHelper.getNextElement(valueExpression, &valueElement)) !=
1290                        nullptr)
1291                 {
1292                     TIntermTyped *functionCall =
1293                         CreateStructCopyCall(copyFromOriginal, valueElement);
1294                     writeStatements->push_back(new TIntermBinary(EOpAssign, element, functionCall));
1295                 }
1296             }
1297             else
1298             {
1299                 TIntermTyped *functionCall =
1300                     CreateStructCopyCall(copyFromOriginal, valueExpression->deepCopy());
1301                 writeStatements->push_back(
1302                     new TIntermBinary(EOpAssign, baseExpression, functionCall));
1303             }
1304 
1305             return;
1306         }
1307 
1308         // If not indexed, the result is transpose(exp)
1309         if (primaryIndex == nullptr)
1310         {
1311             ASSERT(secondaryIndices->empty());
1312 
1313             if (baseExpressionType.isArray())
1314             {
1315                 // If array, assign every element.
1316                 TransformArrayHelper transformHelper(baseExpression);
1317 
1318                 TIntermTyped *element      = nullptr;
1319                 TIntermTyped *valueElement = nullptr;
1320                 while ((element = transformHelper.getNextElement(valueExpression, &valueElement)) !=
1321                        nullptr)
1322                 {
1323                     TIntermTyped *valueTransposed = CreateTransposeCall(mSymbolTable, valueElement);
1324                     writeStatements->push_back(
1325                         new TIntermBinary(EOpAssign, element, valueTransposed));
1326                 }
1327             }
1328             else
1329             {
1330                 TIntermTyped *valueTransposed =
1331                     CreateTransposeCall(mSymbolTable, valueExpression->deepCopy());
1332                 writeStatements->push_back(
1333                     new TIntermBinary(assignmentOperator, baseExpression, valueTransposed));
1334             }
1335 
1336             return;
1337         }
1338 
1339         // If indexed, create one assignment per secondary index.  If the right-hand side is a
1340         // scalar, it's used with every assignment.  If it's a vector, the assignment is
1341         // per-component.  The right-hand side cannot be a matrix as that would imply left-hand
1342         // side being a matrix too, which is covered above where |primaryIndex == nullptr|.
1343         ASSERT(!secondaryIndices->empty());
1344 
1345         bool isValueExpressionScalar = valueExpression->getType().getNominalSize() == 1;
1346         ASSERT(isValueExpressionScalar || valueExpression->getType().getNominalSize() ==
1347                                               static_cast<int>(secondaryIndices->size()));
1348 
1349         TOperator primaryIndexOp          = GetIndexOp(primaryIndex);
1350         TIntermTyped *primaryIndexAsTyped = primaryIndex->getAsTyped();
1351 
1352         for (TIntermNode *secondaryIndex : *secondaryIndices)
1353         {
1354             TOperator secondaryIndexOp          = GetIndexOp(secondaryIndex);
1355             TIntermTyped *secondaryIndexAsTyped = secondaryIndex->getAsTyped();
1356 
1357             TIntermBinary *colIndexed = new TIntermBinary(
1358                 secondaryIndexOp, baseExpression->deepCopy(), secondaryIndexAsTyped->deepCopy());
1359             TIntermBinary *colRowIndexed =
1360                 new TIntermBinary(primaryIndexOp, colIndexed, primaryIndexAsTyped->deepCopy());
1361 
1362             TIntermTyped *valueExpressionIndexed = valueExpression->deepCopy();
1363             if (!isValueExpressionScalar)
1364             {
1365                 valueExpressionIndexed = new TIntermBinary(secondaryIndexOp, valueExpressionIndexed,
1366                                                            secondaryIndexAsTyped->deepCopy());
1367             }
1368 
1369             writeStatements->push_back(
1370                 new TIntermBinary(assignmentOperator, colRowIndexed, valueExpressionIndexed));
1371         }
1372     }
1373 
getCopyStructFieldFunction(const TType * fromFieldType,const TType * toFieldType,bool isCopyToOriginal)1374     const TFunction *getCopyStructFieldFunction(const TType *fromFieldType,
1375                                                 const TType *toFieldType,
1376                                                 bool isCopyToOriginal)
1377     {
1378         ASSERT(fromFieldType->getStruct());
1379         ASSERT(toFieldType->getStruct());
1380 
1381         // If copying from or to the original struct, the "to" field struct could require
1382         // conversion to or from the "from" field struct.  |isCopyToOriginal| tells us if we
1383         // should expect to find toField or fromField in mStructMapOut, if true or false
1384         // respectively.
1385         const TFunction *fieldCopyFunction = nullptr;
1386         if (isCopyToOriginal)
1387         {
1388             const TStructure *toFieldStruct = toFieldType->getStruct();
1389 
1390             auto iter = mStructMapOut->find(toFieldStruct);
1391             if (iter != mStructMapOut->end())
1392             {
1393                 declareStructCopyToOriginal(toFieldStruct);
1394                 fieldCopyFunction = iter->second.copyToOriginal;
1395             }
1396         }
1397         else
1398         {
1399             const TStructure *fromFieldStruct = fromFieldType->getStruct();
1400 
1401             auto iter = mStructMapOut->find(fromFieldStruct);
1402             if (iter != mStructMapOut->end())
1403             {
1404                 declareStructCopyFromOriginal(fromFieldStruct);
1405                 fieldCopyFunction = iter->second.copyFromOriginal;
1406             }
1407         }
1408 
1409         return fieldCopyFunction;
1410     }
1411 
addFieldCopy(TIntermBlock * body,TIntermTyped * to,TIntermTyped * from,bool isCopyToOriginal)1412     void addFieldCopy(TIntermBlock *body,
1413                       TIntermTyped *to,
1414                       TIntermTyped *from,
1415                       bool isCopyToOriginal)
1416     {
1417         const TType &fromType = from->getType();
1418         const TType &toType   = to->getType();
1419 
1420         TIntermTyped *rhs = from;
1421 
1422         if (fromType.getStruct())
1423         {
1424             const TFunction *fieldCopyFunction =
1425                 getCopyStructFieldFunction(&fromType, &toType, isCopyToOriginal);
1426 
1427             if (fieldCopyFunction)
1428             {
1429                 rhs = CreateStructCopyCall(fieldCopyFunction, from);
1430             }
1431         }
1432         else if (fromType.isMatrix())
1433         {
1434             rhs = CreateTransposeCall(mSymbolTable, from);
1435         }
1436 
1437         body->appendStatement(new TIntermBinary(EOpAssign, to, rhs));
1438     }
1439 
declareStructCopy(const TStructure * from,const TStructure * to,bool isCopyToOriginal)1440     TFunction *declareStructCopy(const TStructure *from,
1441                                  const TStructure *to,
1442                                  bool isCopyToOriginal)
1443     {
1444         TType *fromType = new TType(from, true);
1445         TType *toType   = new TType(to, true);
1446 
1447         // Create the parameter and return value variables.
1448         TVariable *fromVar = new TVariable(mSymbolTable, ImmutableString("from"), fromType,
1449                                            SymbolType::AngleInternal);
1450         TVariable *toVar =
1451             new TVariable(mSymbolTable, ImmutableString("to"), toType, SymbolType::AngleInternal);
1452 
1453         TIntermSymbol *fromSymbol = new TIntermSymbol(fromVar);
1454         TIntermSymbol *toSymbol   = new TIntermSymbol(toVar);
1455 
1456         // Create the function body as statements are generated.
1457         TIntermBlock *body = new TIntermBlock;
1458 
1459         // Declare the result variable.
1460         TIntermDeclaration *toDecl = new TIntermDeclaration();
1461         toDecl->appendDeclarator(toSymbol);
1462         body->appendStatement(toDecl);
1463 
1464         // Iterate over fields of the struct and copy one by one, transposing the matrices.  If a
1465         // struct is encountered that requires a transformation, this function is recursively
1466         // called.  As a result, it is important that the copy functions are placed in the code in
1467         // order.
1468         const TFieldList &fromFields = from->fields();
1469         const TFieldList &toFields   = to->fields();
1470         ASSERT(fromFields.size() == toFields.size());
1471 
1472         for (size_t fieldIndex = 0; fieldIndex < fromFields.size(); ++fieldIndex)
1473         {
1474             TIntermTyped *fieldIndexNode = CreateIndexNode(static_cast<int>(fieldIndex));
1475 
1476             TIntermTyped *fromField =
1477                 new TIntermBinary(EOpIndexDirectStruct, fromSymbol->deepCopy(), fieldIndexNode);
1478             TIntermTyped *toField = new TIntermBinary(EOpIndexDirectStruct, toSymbol->deepCopy(),
1479                                                       fieldIndexNode->deepCopy());
1480 
1481             const TType *fromFieldType = fromFields[fieldIndex]->type();
1482             bool isStructOrMatrix      = fromFieldType->getStruct() || fromFieldType->isMatrix();
1483 
1484             if (fromFieldType->isArray() && isStructOrMatrix)
1485             {
1486                 // If struct or matrix array, we need to copy element by element.
1487                 TransformArrayHelper transformHelper(toField);
1488 
1489                 TIntermTyped *toElement   = nullptr;
1490                 TIntermTyped *fromElement = nullptr;
1491                 while ((toElement = transformHelper.getNextElement(fromField, &fromElement)) !=
1492                        nullptr)
1493                 {
1494                     addFieldCopy(body, toElement, fromElement, isCopyToOriginal);
1495                 }
1496             }
1497             else
1498             {
1499                 addFieldCopy(body, toField, fromField, isCopyToOriginal);
1500             }
1501         }
1502 
1503         // Add return statement.
1504         body->appendStatement(new TIntermBranch(EOpReturn, toSymbol->deepCopy()));
1505 
1506         // Declare the function
1507         TFunction *copyFunction = new TFunction(mSymbolTable, kEmptyImmutableString,
1508                                                 SymbolType::AngleInternal, toType, true);
1509         copyFunction->addParameter(fromVar);
1510 
1511         TIntermFunctionDefinition *functionDef =
1512             CreateInternalFunctionDefinitionNode(*copyFunction, body);
1513         mCopyFunctionDefinitionsOut->push_back(functionDef);
1514 
1515         return copyFunction;
1516     }
1517 
declareStructCopyFromOriginal(const TStructure * structure)1518     void declareStructCopyFromOriginal(const TStructure *structure)
1519     {
1520         StructConversionData *structData = &(*mStructMapOut)[structure];
1521         if (structData->copyFromOriginal)
1522         {
1523             return;
1524         }
1525 
1526         structData->copyFromOriginal =
1527             declareStructCopy(structure, structData->convertedStruct, false);
1528     }
1529 
declareStructCopyToOriginal(const TStructure * structure)1530     void declareStructCopyToOriginal(const TStructure *structure)
1531     {
1532         StructConversionData *structData = &(*mStructMapOut)[structure];
1533         if (structData->copyToOriginal)
1534         {
1535             return;
1536         }
1537 
1538         structData->copyToOriginal =
1539             declareStructCopy(structData->convertedStruct, structure, true);
1540     }
1541 
1542     TCompiler *mCompiler;
1543 
1544     // This traverser can call itself to transform a subexpression before moving on.  However, it
1545     // needs to accumulate conversion functions in inner passes.  The fields below marked with Out
1546     // or In are inherited from the outer pass (for inner passes), or point to storage fields in
1547     // mOuterPass (for the outer pass).  The latter should not be used by the inner passes as they
1548     // would be empty, so they are placed inside a struct to make them explicit.
1549     struct
1550     {
1551         StructMap structMap;
1552         InterfaceBlockMap interfaceBlockMap;
1553         InterfaceBlockFieldConverted interfaceBlockFieldConverted;
1554         TIntermSequence copyFunctionDefinitions;
1555     } mOuterPass;
1556 
1557     // A map from structures with matrices to their converted version.
1558     StructMap *mStructMapOut;
1559     // A map from interface block instances with row-major matrices to their converted variable.  If
1560     // an interface block is nameless, its fields are placed in this map instead.  When a variable
1561     // in this map is encountered, it signals the start of an expression that may need conversion,
1562     // which is either "interfaceBlock.field..." or "field..." if nameless.
1563     InterfaceBlockMap *mInterfaceBlockMap;
1564     // A map from interface block fields to whether they need to be converted.  If a field was
1565     // already column-major, it shouldn't be transposed.
1566     const InterfaceBlockFieldConverted &mInterfaceBlockFieldConvertedIn;
1567 
1568     TIntermSequence *mCopyFunctionDefinitionsOut;
1569 
1570     // If set, it's an inner pass and this will point to the outer pass traverser.  All statement
1571     // insertions are stored in the outer traverser and applied all at once at the end.  This prevents
1572     // the inner passes from adding statements, which would invalidate the outer traverser's statement
1573     // position tracking.
1574     RewriteRowMajorMatricesTraverser *mOuterTraverser;
1575 
1576     // If set, it's an inner pass that should only process the right-hand side of this particular
1577     // node.
1578     TIntermBinary *mInnerPassRoot;
1579     bool mIsProcessingInnerPassSubtree;
1580 };
1581 
1582 }  // anonymous namespace
1583 
RewriteRowMajorMatrices(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)1584 bool RewriteRowMajorMatrices(TCompiler *compiler, TIntermBlock *root, TSymbolTable *symbolTable)
1585 {
1586     RewriteRowMajorMatricesTraverser traverser(compiler, symbolTable);
1587     root->traverse(&traverser);
1588     if (!traverser.updateTree(compiler, root))
1589     {
1590         return false;
1591     }
1592 
1593     size_t firstFunctionIndex = FindFirstFunctionDefinitionIndex(root);
1594     root->insertChildNodes(firstFunctionIndex, *traverser.getStructCopyFunctions());
1595 
1596     return compiler->validateAST(root);
1597 }
1598 }  // namespace sh
1599