• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2019 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RewriteRowMajorMatrices: Rewrite row-major matrices as column-major.
7 //
8 
9 #include "compiler/translator/tree_ops/RewriteRowMajorMatrices.h"
10 
11 #include "compiler/translator/Compiler.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 #include "compiler/translator/tree_util/ReplaceVariable.h"
18 
19 namespace sh
20 {
21 namespace
22 {
23 // Only structs with matrices are tracked.  If layout(row_major) is applied to a struct that doesn't
24 // have matrices, it's silently dropped.  This is also used to avoid creating duplicates for inner
25 // structs that don't have matrices.
26 struct StructConversionData
27 {
28     // The converted struct with every matrix transposed.
29     TStructure *convertedStruct = nullptr;
30 
31     // The copy-from and copy-to functions copying from a struct to its converted version and back.
32     TFunction *copyFromOriginal = nullptr;
33     TFunction *copyToOriginal   = nullptr;
34 };
35 
DoesFieldContainRowMajorMatrix(const TField * field,bool isBlockRowMajor)36 bool DoesFieldContainRowMajorMatrix(const TField *field, bool isBlockRowMajor)
37 {
38     TLayoutMatrixPacking matrixPacking = field->type()->getLayoutQualifier().matrixPacking;
39 
40     // The field is row major if either explicitly specified as such, or if it inherits it from the
41     // block layout qualifier.
42     if (matrixPacking == EmpColumnMajor || (matrixPacking == EmpUnspecified && !isBlockRowMajor))
43     {
44         return false;
45     }
46 
47     // The field is qualified with row_major, but if it's not a matrix or a struct containing
48     // matrices, that's a useless qualifier.
49     const TType *type = field->type();
50     return type->isMatrix() || type->isStructureContainingMatrices();
51 }
52 
// Creates a new field with its own copy of |field|'s type, keeping the original name, source
// location and symbol type.
TField *DuplicateField(const TField *field)
{
    TType *typeCopy = new TType(*field->type());
    return new TField(typeCopy, field->name(), field->line(), field->symbolType());
}
57 
SetColumnMajor(TType * type)58 void SetColumnMajor(TType *type)
59 {
60     TLayoutQualifier layoutQualifier = type->getLayoutQualifier();
61     layoutQualifier.matrixPacking    = EmpColumnMajor;
62     type->setLayoutQualifier(layoutQualifier);
63 }
64 
// Returns a new type equal to |type| but column-major and with its primary and secondary sizes
// swapped.  The new primary size is |type|'s row count and the new secondary size is its column
// count.  Everything else (including array sizes) is copied from |type|.
TType *TransposeMatrixType(const TType *type)
{
    const unsigned char newPrimarySize   = static_cast<unsigned char>(type->getRows());
    const unsigned char newSecondarySize = static_cast<unsigned char>(type->getCols());

    TType *transposed = new TType(*type);
    SetColumnMajor(transposed);
    transposed->setPrimarySize(newPrimarySize);
    transposed->setSecondarySize(newSecondarySize);

    return transposed;
}
76 
CopyArraySizes(const TType * from,TType * to)77 void CopyArraySizes(const TType *from, TType *to)
78 {
79     if (from->isArray())
80     {
81         to->makeArrays(from->getArraySizes());
82     }
83 }
84 
85 // Determine if the node is an index node (array index or struct field selection).  For the purposes
86 // of this transformation, swizzle nodes are considered index nodes too.
IsIndexNode(TIntermNode * node,TIntermNode * child)87 bool IsIndexNode(TIntermNode *node, TIntermNode *child)
88 {
89     if (node->getAsSwizzleNode())
90     {
91         return true;
92     }
93 
94     TIntermBinary *binaryNode = node->getAsBinaryNode();
95     if (binaryNode == nullptr || child != binaryNode->getLeft())
96     {
97         return false;
98     }
99 
100     TOperator op = binaryNode->getOp();
101 
102     return op == EOpIndexDirect || op == EOpIndexDirectInterfaceBlock ||
103            op == EOpIndexDirectStruct || op == EOpIndexIndirect;
104 }
105 
CopyToTempVariable(TSymbolTable * symbolTable,TIntermTyped * node,TIntermSequence * prependStatements)106 TIntermSymbol *CopyToTempVariable(TSymbolTable *symbolTable,
107                                   TIntermTyped *node,
108                                   TIntermSequence *prependStatements)
109 {
110     TVariable *temp              = CreateTempVariable(symbolTable, &node->getType());
111     TIntermDeclaration *tempDecl = CreateTempInitDeclarationNode(temp, node);
112     prependStatements->push_back(tempDecl);
113 
114     return new TIntermSymbol(temp);
115 }
116 
CreateStructCopyCall(const TFunction * copyFunc,TIntermTyped * expression)117 TIntermAggregate *CreateStructCopyCall(const TFunction *copyFunc, TIntermTyped *expression)
118 {
119     return TIntermAggregate::CreateFunctionCall(*copyFunc, new TIntermSequence({expression}));
120 }
121 
CreateTransposeCall(TSymbolTable * symbolTable,TIntermTyped * expression)122 TIntermTyped *CreateTransposeCall(TSymbolTable *symbolTable, TIntermTyped *expression)
123 {
124     return CreateBuiltInFunctionCallNode("transpose", new TIntermSequence({expression}),
125                                          *symbolTable, 300);
126 }
127 
GetIndex(TSymbolTable * symbolTable,TIntermNode * node,TIntermSequence * indices,TIntermSequence * prependStatements)128 TOperator GetIndex(TSymbolTable *symbolTable,
129                    TIntermNode *node,
130                    TIntermSequence *indices,
131                    TIntermSequence *prependStatements)
132 {
133     // Swizzle nodes are converted EOpIndexDirect for simplicity, with one index per swizzle
134     // channel.
135     TIntermSwizzle *asSwizzle = node->getAsSwizzleNode();
136     if (asSwizzle)
137     {
138         for (int channel : asSwizzle->getSwizzleOffsets())
139         {
140             indices->push_back(CreateIndexNode(channel));
141         }
142         return EOpIndexDirect;
143     }
144 
145     TIntermBinary *binaryNode = node->getAsBinaryNode();
146     ASSERT(binaryNode);
147 
148     TOperator op = binaryNode->getOp();
149     ASSERT(op == EOpIndexDirect || op == EOpIndexDirectInterfaceBlock ||
150            op == EOpIndexDirectStruct || op == EOpIndexIndirect);
151 
152     TIntermTyped *rhs = binaryNode->getRight()->deepCopy();
153     if (rhs->getAsConstantUnion() == nullptr)
154     {
155         rhs = CopyToTempVariable(symbolTable, rhs, prependStatements);
156     }
157 
158     indices->push_back(rhs);
159     return op;
160 }
161 
ReplicateIndexNode(TSymbolTable * symbolTable,TIntermNode * node,TIntermTyped * lhs,TIntermSequence * indices)162 TIntermTyped *ReplicateIndexNode(TSymbolTable *symbolTable,
163                                  TIntermNode *node,
164                                  TIntermTyped *lhs,
165                                  TIntermSequence *indices)
166 {
167     TIntermSwizzle *asSwizzle = node->getAsSwizzleNode();
168     if (asSwizzle)
169     {
170         return new TIntermSwizzle(lhs, asSwizzle->getSwizzleOffsets());
171     }
172 
173     TIntermBinary *binaryNode = node->getAsBinaryNode();
174     ASSERT(binaryNode);
175 
176     ASSERT(indices->size() == 1);
177     TIntermTyped *rhs = indices->front()->getAsTyped();
178 
179     return new TIntermBinary(binaryNode->getOp(), lhs, rhs);
180 }
181 
GetIndexOp(TIntermNode * node)182 TOperator GetIndexOp(TIntermNode *node)
183 {
184     return node->getAsConstantUnion() ? EOpIndexDirect : EOpIndexIndirect;
185 }
186 
IsConvertedField(TIntermTyped * indexNode,const std::unordered_map<const TField *,bool> & convertedFields)187 bool IsConvertedField(TIntermTyped *indexNode,
188                       const std::unordered_map<const TField *, bool> &convertedFields)
189 {
190     TIntermBinary *asBinary = indexNode->getAsBinaryNode();
191     if (asBinary == nullptr)
192     {
193         return false;
194     }
195 
196     if (asBinary->getOp() != EOpIndexDirectInterfaceBlock)
197     {
198         return false;
199     }
200 
201     const TInterfaceBlock *interfaceBlock = asBinary->getLeft()->getType().getInterfaceBlock();
202     ASSERT(interfaceBlock);
203 
204     TIntermConstantUnion *fieldIndexNode = asBinary->getRight()->getAsConstantUnion();
205     ASSERT(fieldIndexNode);
206     ASSERT(fieldIndexNode->getConstantValue() != nullptr);
207 
208     int fieldIndex      = fieldIndexNode->getConstantValue()->getIConst();
209     const TField *field = interfaceBlock->fields()[fieldIndex];
210 
211     return convertedFields.count(field) > 0 && convertedFields.at(field);
212 }
213 
214 // A helper class to transform expressions of array type.  Iterates over every element of the
215 // array.
216 class TransformArrayHelper
217 {
218   public:
TransformArrayHelper(TIntermTyped * baseExpression)219     TransformArrayHelper(TIntermTyped *baseExpression)
220         : mBaseExpression(baseExpression),
221           mBaseExpressionType(baseExpression->getType()),
222           mArrayIndices(mBaseExpressionType.getArraySizes().size(), 0)
223     {}
224 
getNextElement(TIntermTyped * valueExpression,TIntermTyped ** valueElementOut)225     TIntermTyped *getNextElement(TIntermTyped *valueExpression, TIntermTyped **valueElementOut)
226     {
227         const TSpan<const unsigned int> &arraySizes = mBaseExpressionType.getArraySizes();
228 
229         // If the last index overflows, element enumeration is done.
230         if (mArrayIndices.back() >= arraySizes.back())
231         {
232             return nullptr;
233         }
234 
235         TIntermTyped *element = getCurrentElement(mBaseExpression);
236         if (valueExpression)
237         {
238             *valueElementOut = getCurrentElement(valueExpression);
239         }
240 
241         incrementIndices(arraySizes);
242         return element;
243     }
244 
accumulateForRead(TSymbolTable * symbolTable,TIntermTyped * transformedElement,TIntermSequence * prependStatements)245     void accumulateForRead(TSymbolTable *symbolTable,
246                            TIntermTyped *transformedElement,
247                            TIntermSequence *prependStatements)
248     {
249         TIntermTyped *temp = CopyToTempVariable(symbolTable, transformedElement, prependStatements);
250         mReadTransformConstructorArgs.push_back(temp);
251     }
252 
constructReadTransformExpression()253     TIntermTyped *constructReadTransformExpression()
254     {
255         const TSpan<const unsigned int> &baseTypeArraySizes = mBaseExpressionType.getArraySizes();
256         TVector<unsigned int> arraySizes(baseTypeArraySizes.begin(), baseTypeArraySizes.end());
257         TIntermTyped *firstElement = mReadTransformConstructorArgs.front()->getAsTyped();
258         const TType &baseType      = firstElement->getType();
259 
260         // If N dimensions, acc[0] == size[0] and acc[i] == size[i] * acc[i-1].
261         // The last value is unused, and is not present.
262         TVector<unsigned int> accumulatedArraySizes(arraySizes.size() - 1);
263 
264         accumulatedArraySizes[0] = arraySizes[0];
265         for (size_t index = 1; index + 1 < arraySizes.size(); ++index)
266         {
267             accumulatedArraySizes[index] = accumulatedArraySizes[index - 1] * arraySizes[index];
268         }
269 
270         return constructReadTransformExpressionHelper(arraySizes, accumulatedArraySizes, baseType,
271                                                       0);
272     }
273 
274   private:
getCurrentElement(TIntermTyped * expression)275     TIntermTyped *getCurrentElement(TIntermTyped *expression)
276     {
277         TIntermTyped *element = expression->deepCopy();
278         for (auto it = mArrayIndices.rbegin(); it != mArrayIndices.rend(); ++it)
279         {
280             unsigned int index = *it;
281             element            = new TIntermBinary(EOpIndexDirect, element, CreateIndexNode(index));
282         }
283         return element;
284     }
285 
incrementIndices(const TSpan<const unsigned int> & arraySizes)286     void incrementIndices(const TSpan<const unsigned int> &arraySizes)
287     {
288         // Assume mArrayIndices is an N digit number, where digit i is in the range
289         // [0, arraySizes[i]).  This function increments this number.  Last digit is the most
290         // significant digit.
291         for (size_t digitIndex = 0; digitIndex < arraySizes.size(); ++digitIndex)
292         {
293             ++mArrayIndices[digitIndex];
294             if (mArrayIndices[digitIndex] < arraySizes[digitIndex])
295             {
296                 break;
297             }
298             if (digitIndex + 1 != arraySizes.size())
299             {
300                 // This digit has now overflown and is reset to 0, carry will be added to the next
301                 // digit.  The most significant digit will keep the overflow though, to make it
302                 // clear we have exhausted the range.
303                 mArrayIndices[digitIndex] = 0;
304             }
305         }
306     }
307 
constructReadTransformExpressionHelper(const TVector<unsigned int> & arraySizes,const TVector<unsigned int> & accumulatedArraySizes,const TType & baseType,size_t elementsOffset)308     TIntermTyped *constructReadTransformExpressionHelper(
309         const TVector<unsigned int> &arraySizes,
310         const TVector<unsigned int> &accumulatedArraySizes,
311         const TType &baseType,
312         size_t elementsOffset)
313     {
314         ASSERT(!arraySizes.empty());
315 
316         TType *transformedType = new TType(baseType);
317         transformedType->makeArrays(arraySizes);
318 
319         // If one dimensional, create the constructor with the given elements.
320         if (arraySizes.size() == 1)
321         {
322             ASSERT(accumulatedArraySizes.size() == 0);
323 
324             auto sliceStart = mReadTransformConstructorArgs.begin() + elementsOffset;
325             TIntermSequence slice(sliceStart, sliceStart + arraySizes[0]);
326 
327             return TIntermAggregate::CreateConstructor(*transformedType, &slice);
328         }
329 
330         // If not, create constructors for every column recursively.
331         TVector<unsigned int> subArraySizes(arraySizes.begin(), arraySizes.end() - 1);
332         TVector<unsigned int> subArrayAccumulatedSizes(accumulatedArraySizes.begin(),
333                                                        accumulatedArraySizes.end() - 1);
334 
335         TIntermSequence constructorArgs;
336         unsigned int colStride = accumulatedArraySizes.back();
337         for (size_t col = 0; col < arraySizes.back(); ++col)
338         {
339             size_t colElementsOffset = elementsOffset + col * colStride;
340 
341             constructorArgs.push_back(constructReadTransformExpressionHelper(
342                 subArraySizes, subArrayAccumulatedSizes, baseType, colElementsOffset));
343         }
344 
345         return TIntermAggregate::CreateConstructor(*transformedType, &constructorArgs);
346     }
347 
348     TIntermTyped *mBaseExpression;
349     const TType &mBaseExpressionType;
350     TVector<unsigned int> mArrayIndices;
351 
352     TIntermSequence mReadTransformConstructorArgs;
353 };
354 
355 // Traverser that:
356 //
357 // 1. Converts |layout(row_major) matCxR M| to |layout(column_major) matRxC Mt|.
358 // 2. Converts |layout(row_major) S s| to |layout(column_major) St st|, where S is a struct that
359 //    contains matrices, and St is a new struct with the transformation in 1 applied to matrix
360 //    members (recursively).
361 // 3. When read from, the following transformations are applied:
362 //
363 //            M       -> transpose(Mt)
364 //            M[c]    -> gvecN(Mt[0][c], Mt[1][c], ..., Mt[N-1][c])
365 //            M[c][r] -> Mt[r][c]
366 //            M[c].yz -> gvec2(Mt[1][c], Mt[2][c])
367 //            MArr    -> MType[D1]..[DN](transpose(MtArr[0]...[0]), ...)
368 //            s       -> copy_St_to_S(st)
369 //            sArr    -> SType[D1]...[DN](copy_St_to_S(stArr[0]..[0]), ...)
370 //            (matrix reads through struct are transformed similarly to M)
371 //
372 // 4. When written to, the following transformations are applied:
373 //
374 //      M = exp       -> Mt = transpose(exp)
375 //      M[c] = exp    -> temp = exp
376 //                       Mt[0][c] = temp[0]
377 //                       Mt[1][c] = temp[1]
378 //                       ...
379 //                       Mt[N-1][c] = temp[N-1]
380 //      M[c][r] = exp -> Mt[r][c] = exp
381 //      M[c].yz = exp -> temp = exp
382 //                       Mt[1][c] = temp[0]
383 //                       Mt[2][c] = temp[1]
384 //      MArr = exp    -> temp = exp
385 //                       MtArr = MtType[D1]..[DN](transpose(temp[0]...[0]), ...)
386 //      s = exp       -> st = copy_S_to_St(exp)
387 //      sArr = exp    -> temp = exp
388 //                       stArr = StType[D1]...[DN](copy_S_to_St(temp[0]..[0]), ...)
389 //      (matrix writes through struct are transformed similarly to M)
390 //
391 // 5. If any of the above is passed to an `inout` parameter, both transformations are applied:
392 //
393 //            f(M[c]) -> temp = gvecN(Mt[0][c], Mt[1][c], ..., Mt[N-1][c])
394 //                       f(temp)
395 //                       Mt[0][c] = temp[0]
396 //                       Mt[1][c] = temp[1]
397 //                       ...
398 //                       Mt[N-1][c] = temp[N-1]
399 //
400 //               f(s) -> temp = copy_St_to_S(st)
401 //                       f(temp)
402 //                       st = copy_S_to_St(temp)
403 //
404 //    If passed to an `out` parameter, the `temp` parameter is simply not initialized.
405 //
406 // 6. If the expression leading to the matrix or struct has array subscripts, temp values are
407 //    created for them to avoid duplicating side effects.
408 //
409 class RewriteRowMajorMatricesTraverser : public TIntermTraverser
410 {
411   public:
// Creates the top-level (outer) traverser.  Its struct map, interface block map, converted-field
// map and copy-function output all point into this traverser's own mOuterPass state.  It has no
// outer traverser and no inner pass root.
RewriteRowMajorMatricesTraverser(TCompiler *compiler, TSymbolTable *symbolTable)
    : TIntermTraverser(true, true, true, symbolTable),
      mCompiler(compiler),
      mStructMapOut(&mOuterPass.structMap),
      mInterfaceBlockMap(&mOuterPass.interfaceBlockMap),
      mInterfaceBlockFieldConvertedIn(mOuterPass.interfaceBlockFieldConverted),
      mCopyFunctionDefinitionsOut(&mOuterPass.copyFunctionDefinitions),
      mOuterTraverser(nullptr),
      mInnerPassRoot(nullptr),
      mIsProcessingInnerPassSubtree(false)
{}
423 
visitDeclaration(Visit visit,TIntermDeclaration * node)424     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
425     {
426         // No need to process declarations in inner passes.
427         if (mInnerPassRoot != nullptr)
428         {
429             return true;
430         }
431 
432         if (visit != PreVisit)
433         {
434             return true;
435         }
436 
437         const TIntermSequence &sequence = *(node->getSequence());
438 
439         TIntermTyped *variable = sequence.front()->getAsTyped();
440         const TType &type      = variable->getType();
441 
442         // If it's a struct declaration that has matrices, remember it.  If a row-major instance
443         // of it is created, it will have to be converted.
444         if (type.isStructSpecifier() && type.isStructureContainingMatrices())
445         {
446             const TStructure *structure = type.getStruct();
447             ASSERT(structure);
448 
449             ASSERT(mOuterPass.structMap.count(structure) == 0);
450 
451             StructConversionData structData;
452             mOuterPass.structMap[structure] = structData;
453 
454             return false;
455         }
456 
457         // If it's an interface block, it may have to be converted if it contains any row-major
458         // fields.
459         if (type.isInterfaceBlock() && type.getInterfaceBlock()->containsMatrices())
460         {
461             const TInterfaceBlock *block = type.getInterfaceBlock();
462             ASSERT(block);
463             bool isBlockRowMajor = type.getLayoutQualifier().matrixPacking == EmpRowMajor;
464 
465             const TFieldList &fields = block->fields();
466             bool anyRowMajor         = isBlockRowMajor;
467 
468             for (const TField *field : fields)
469             {
470                 if (DoesFieldContainRowMajorMatrix(field, isBlockRowMajor))
471                 {
472                     anyRowMajor = true;
473                     break;
474                 }
475             }
476 
477             if (anyRowMajor)
478             {
479                 convertInterfaceBlock(node);
480             }
481 
482             return false;
483         }
484 
485         return true;
486     }
487 
visitSymbol(TIntermSymbol * symbol)488     void visitSymbol(TIntermSymbol *symbol) override
489     {
490         // If in inner pass, only process if the symbol is under that root.
491         if (mInnerPassRoot != nullptr && !mIsProcessingInnerPassSubtree)
492         {
493             return;
494         }
495 
496         const TVariable *variable = &symbol->variable();
497         bool needsRewrite         = mInterfaceBlockMap->count(variable) != 0;
498 
499         // If it's a field of a nameless interface block, it may still need conversion.
500         if (!needsRewrite)
501         {
502             // Nameless interface block field symbols have the interface block pointer set, but are
503             // not interface blocks.
504             if (symbol->getType().getInterfaceBlock() && !variable->getType().isInterfaceBlock())
505             {
506                 needsRewrite = convertNamelessInterfaceBlockField(symbol);
507             }
508         }
509 
510         if (needsRewrite)
511         {
512             transformExpression(symbol);
513         }
514     }
515 
visitBinary(Visit visit,TIntermBinary * node)516     bool visitBinary(Visit visit, TIntermBinary *node) override
517     {
518         if (node == mInnerPassRoot)
519         {
520             // We only want to process the right-hand side of an assignment in inner passes.  When
521             // visit is InVisit, the left-hand side is already processed, and the right-hand side is
522             // next.  Set a flag to mark this duration.
523             mIsProcessingInnerPassSubtree = visit == InVisit;
524         }
525 
526         return true;
527     }
528 
getStructCopyFunctions()529     TIntermSequence *getStructCopyFunctions() { return &mOuterPass.copyFunctionDefinitions; }
530 
531   private:
532     typedef std::unordered_map<const TStructure *, StructConversionData> StructMap;
533     typedef std::unordered_map<const TVariable *, TVariable *> InterfaceBlockMap;
534     typedef std::unordered_map<const TField *, bool> InterfaceBlockFieldConverted;
535 
RewriteRowMajorMatricesTraverser(TSymbolTable * symbolTable,RewriteRowMajorMatricesTraverser * outerTraverser,InterfaceBlockMap * interfaceBlockMap,const InterfaceBlockFieldConverted & interfaceBlockFieldConverted,StructMap * structMap,TIntermSequence * copyFunctionDefinitions,TIntermBinary * innerPassRoot)536     RewriteRowMajorMatricesTraverser(
537         TSymbolTable *symbolTable,
538         RewriteRowMajorMatricesTraverser *outerTraverser,
539         InterfaceBlockMap *interfaceBlockMap,
540         const InterfaceBlockFieldConverted &interfaceBlockFieldConverted,
541         StructMap *structMap,
542         TIntermSequence *copyFunctionDefinitions,
543         TIntermBinary *innerPassRoot)
544         : TIntermTraverser(true, true, true, symbolTable),
545           mStructMapOut(structMap),
546           mInterfaceBlockMap(interfaceBlockMap),
547           mInterfaceBlockFieldConvertedIn(interfaceBlockFieldConverted),
548           mCopyFunctionDefinitionsOut(copyFunctionDefinitions),
549           mOuterTraverser(outerTraverser),
550           mInnerPassRoot(innerPassRoot),
551           mIsProcessingInnerPassSubtree(false)
552     {}
553 
convertInterfaceBlock(TIntermDeclaration * node)554     void convertInterfaceBlock(TIntermDeclaration *node)
555     {
556         ASSERT(mInnerPassRoot == nullptr);
557 
558         const TIntermSequence &sequence = *(node->getSequence());
559 
560         TIntermTyped *variableNode   = sequence.front()->getAsTyped();
561         const TType &type            = variableNode->getType();
562         const TInterfaceBlock *block = type.getInterfaceBlock();
563         ASSERT(block);
564 
565         bool isBlockRowMajor = type.getLayoutQualifier().matrixPacking == EmpRowMajor;
566 
567         // Recreate the struct with its row-major fields converted to column-major equivalents.
568         TIntermSequence newDeclarations;
569 
570         TFieldList *newFields = new TFieldList;
571         for (const TField *field : block->fields())
572         {
573             TField *newField = nullptr;
574 
575             if (DoesFieldContainRowMajorMatrix(field, isBlockRowMajor))
576             {
577                 newField = convertField(field, &newDeclarations);
578 
579                 // Remember that this field was converted.
580                 mOuterPass.interfaceBlockFieldConverted[field] = true;
581             }
582             else
583             {
584                 newField = DuplicateField(field);
585             }
586 
587             newFields->push_back(newField);
588         }
589 
590         // Create a new interface block with these fields.
591         TLayoutQualifier blockLayoutQualifier = type.getLayoutQualifier();
592         blockLayoutQualifier.matrixPacking    = EmpColumnMajor;
593 
594         TInterfaceBlock *newInterfaceBlock =
595             new TInterfaceBlock(mSymbolTable, block->name(), newFields, blockLayoutQualifier,
596                                 block->symbolType(), block->extension());
597 
598         // Create a new declaration with the new type.  Declarations are separated at this point,
599         // so there should be only one variable here.
600         ASSERT(sequence.size() == 1);
601 
602         TType *newInterfaceBlockType =
603             new TType(newInterfaceBlock, type.getQualifier(), blockLayoutQualifier);
604 
605         TIntermDeclaration *newDeclaration = new TIntermDeclaration;
606         const TVariable *variable          = &variableNode->getAsSymbolNode()->variable();
607 
608         const TType *newType = newInterfaceBlockType;
609         if (type.isArray())
610         {
611             TType *newArrayType = new TType(*newType);
612             CopyArraySizes(&type, newArrayType);
613             newType = newArrayType;
614         }
615 
616         // If the interface block variable itself is temp, use an empty name.
617         bool variableIsTemp = variable->symbolType() == SymbolType::Empty;
618         const ImmutableString &variableName =
619             variableIsTemp ? kEmptyImmutableString : variable->name();
620 
621         TVariable *newVariable = new TVariable(mSymbolTable, variableName, newType,
622                                                variable->symbolType(), variable->extension());
623 
624         newDeclaration->appendDeclarator(new TIntermSymbol(newVariable));
625 
626         mOuterPass.interfaceBlockMap[variable] = newVariable;
627 
628         newDeclarations.push_back(newDeclaration);
629 
630         // Replace the interface block definition with the new one, prepending any new struct
631         // definitions.
632         mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node, newDeclarations);
633     }
634 
convertNamelessInterfaceBlockField(TIntermSymbol * symbol)635     bool convertNamelessInterfaceBlockField(TIntermSymbol *symbol)
636     {
637         const TVariable *variable             = &symbol->variable();
638         const TInterfaceBlock *interfaceBlock = symbol->getType().getInterfaceBlock();
639 
640         // Find the variable corresponding to this interface block.  If the interface block
641         // is not rewritten, or this refers to a field that is not rewritten, there's
642         // nothing to do.
643         for (auto iter : *mInterfaceBlockMap)
644         {
645             // Skip other rewritten nameless interface block fields.
646             if (!iter.first->getType().isInterfaceBlock())
647             {
648                 continue;
649             }
650 
651             // Skip if this is not a field of this rewritten interface block.
652             if (iter.first->getType().getInterfaceBlock() != interfaceBlock)
653             {
654                 continue;
655             }
656 
657             const ImmutableString symbolName = symbol->getName();
658 
659             // Find which field it is
660             const TVector<TField *> fields = interfaceBlock->fields();
661             for (size_t fieldIndex = 0; fieldIndex < fields.size(); ++fieldIndex)
662             {
663                 const TField *field = fields[fieldIndex];
664                 if (field->name() != symbolName)
665                 {
666                     continue;
667                 }
668 
669                 // If this field doesn't need a rewrite, there's nothing to do.
670                 if (mInterfaceBlockFieldConvertedIn.count(field) == 0 ||
671                     !mInterfaceBlockFieldConvertedIn.at(field))
672                 {
673                     break;
674                 }
675 
676                 // Create a new variable that references the replaced interface block.
677                 TType *newType = new TType(variable->getType());
678                 newType->setInterfaceBlock(iter.second->getType().getInterfaceBlock());
679 
680                 TVariable *newVariable =
681                     new TVariable(mSymbolTable, variable->name(), newType, variable->symbolType(),
682                                   variable->extension());
683 
684                 (*mInterfaceBlockMap)[variable] = newVariable;
685 
686                 return true;
687             }
688 
689             break;
690         }
691 
692         return false;
693     }
694 
convertStruct(const TStructure * structure,TIntermSequence * newDeclarations)695     void convertStruct(const TStructure *structure, TIntermSequence *newDeclarations)
696     {
697         ASSERT(mInnerPassRoot == nullptr);
698 
699         ASSERT(mOuterPass.structMap.count(structure) != 0);
700         StructConversionData *structData = &mOuterPass.structMap[structure];
701 
702         if (structData->convertedStruct)
703         {
704             return;
705         }
706 
707         TFieldList *newFields = new TFieldList;
708         for (const TField *field : structure->fields())
709         {
710             newFields->push_back(convertField(field, newDeclarations));
711         }
712 
713         // Create unique names for the converted structs.  We can't leave them nameless and have
714         // a name autogenerated similar to temp variables, as nameless structs exist.  A fake
715         // variable is created for the sole purpose of generating a temp name.
716         TVariable *newStructTypeName =
717             new TVariable(mSymbolTable, kEmptyImmutableString, StaticType::GetBasic<EbtUInt>(),
718                           SymbolType::Empty);
719 
720         TStructure *newStruct = new TStructure(mSymbolTable, newStructTypeName->name(), newFields,
721                                                SymbolType::AngleInternal);
722         TType *newType        = new TType(newStruct, true);
723         TVariable *newStructVar =
724             new TVariable(mSymbolTable, kEmptyImmutableString, newType, SymbolType::Empty);
725 
726         TIntermDeclaration *structDecl = new TIntermDeclaration;
727         structDecl->appendDeclarator(new TIntermSymbol(newStructVar));
728 
729         newDeclarations->push_back(structDecl);
730 
731         structData->convertedStruct = newStruct;
732     }
733 
// Returns the converted counterpart of |field|:
//
// - A field whose type is a struct containing matrices gets the converted struct's type. The
//   struct is converted and declared into |newDeclarations| first if needed. The new type is
//   marked column-major and keeps the original array sizes.
// - A matrix field gets the transposed matrix type.
// - Any other field is duplicated unchanged.
TField *convertField(const TField *field, TIntermSequence *newDeclarations)
{
    ASSERT(mInnerPassRoot == nullptr);

    const TType *originalType = field->type();
    TType *replacementType    = nullptr;

    if (originalType->isStructureContainingMatrices())
    {
        // Make sure the nested struct has a converted version, then use it for this field.
        const TStructure *nestedStruct = originalType->getStruct();
        convertStruct(nestedStruct, newDeclarations);

        replacementType = new TType(mOuterPass.structMap[nestedStruct].convertedStruct, false);
        SetColumnMajor(replacementType);
        CopyArraySizes(originalType, replacementType);
    }
    else if (originalType->isMatrix())
    {
        // Swap the matrix dimensions; the matrix packing qualifier is dropped in the process.
        replacementType = TransposeMatrixType(originalType);
    }

    if (replacementType == nullptr)
    {
        return DuplicateField(field);
    }

    return new TField(replacementType, field->name(), field->line(), field->symbolType());
}
773 
determineAccess(TIntermNode * expression,TIntermNode * accessor,bool * isReadOut,bool * isWriteOut)774     void determineAccess(TIntermNode *expression,
775                          TIntermNode *accessor,
776                          bool *isReadOut,
777                          bool *isWriteOut)
778     {
779         // If passing to a function, look at whether the parameter is in, out or inout.
780         TIntermAggregate *functionCall = accessor->getAsAggregate();
781 
782         if (functionCall)
783         {
784             TIntermSequence *arguments = functionCall->getSequence();
785             for (size_t argIndex = 0; argIndex < arguments->size(); ++argIndex)
786             {
787                 if ((*arguments)[argIndex] == expression)
788                 {
789                     TQualifier qualifier = EvqIn;
790 
791                     // If the aggregate is not a function call, it's a constructor, and so every
792                     // argument is an input.
793                     const TFunction *function = functionCall->getFunction();
794                     if (function)
795                     {
796                         const TVariable *param = function->getParam(argIndex);
797                         qualifier              = param->getType().getQualifier();
798                     }
799 
800                     *isReadOut  = qualifier != EvqOut;
801                     *isWriteOut = qualifier == EvqOut || qualifier == EvqInOut;
802                     break;
803                 }
804             }
805             return;
806         }
807 
808         TIntermBinary *assignment = accessor->getAsBinaryNode();
809         if (assignment && IsAssignment(assignment->getOp()))
810         {
811             // If expression is on the right of assignment, it's being read from.
812             *isReadOut = assignment->getRight() == expression;
813             // If it's on the left of assignment, it's being written to.
814             *isWriteOut = assignment->getLeft() == expression;
815             return;
816         }
817 
818         // Any other usage is a read.
819         *isReadOut  = true;
820         *isWriteOut = false;
821     }
822 
// Rewrites one use of an original interface block variable.
//
// |symbol| is replaced by the variable that |mInterfaceBlockMap| maps it to. The chain of
// index, field-selection and swizzle ancestors above it is replicated on top of that
// replacement. If the expression reached this way is a converted matrix, or a struct containing
// matrices, the use is also rewritten according to how it is accessed (see determineAccess()):
//
// - reads are replaced in place by the expression built by transformReadExpression();
// - writes become extra statements built by transformWriteExpression();
// - inout uses go through a temp variable that is initialized from the read and then written
//   back.
//
// Statements that must precede or follow the use are inserted in the parent block by the
// outermost traverser (|mOuterTraverser| if set, otherwise this one).
transformExpression(TIntermSymbol * symbol)823     void transformExpression(TIntermSymbol *symbol)
824     {
825         // Walk up the parent chain while the nodes are EOpIndex* (whether array indexing or struct
826         // field selection) or swizzle and construct the replacement expression.  This traversal can
827         // lead to one of the following possibilities:
828         //
829         // - a.b[N].etc.s (struct, or struct array): copy function should be declared and used,
830         // - a.b[N].etc.M (matrix or matrix array): transpose() should be used,
831         // - a.b[N].etc.M[c] (a column): each element in column needs to be handled separately,
832         // - a.b[N].etc.M[c].yz (multiple elements): similar to whole column, but a subset of
833         //   elements,
834         // - a.b[N].etc.M[c][r] (an element): single element to handle.
835         // - a.b[N].etc.x (not struct or matrix): not modified
836         //
837         // primaryIndex will contain c, if any.  secondaryIndices will contain {0, ..., R-1}
838         // (if no [r] or swizzle), {r} (if [r]), or {1, 2} (corresponding to .yz) if any.
839         //
840         // In all cases, the base symbol is replaced.  |baseExpression| will contain everything up
841         // to (and not including) the last index/swizzle operations, i.e. a.b[N].etc.s/M/x.  Any
842         // non constant array subscript is assigned to a temp variable to avoid duplicating side
843         // effects.
844         //
845         // ---
846         //
847         // NOTE that due to the use of insertStatementsInParentBlock, cases like this will be
848         // mistranslated, and this bug is likely present in most transformations that use this
849         // feature:
850         //
851         //     if (x == 1 && a.b[x = 2].etc.M = value)
852         //
853         // which will translate to:
854         //
855         //     temp = (x = 2)
856         //     if (x == 1 && a.b[temp].etc.M = transpose(value))
857         //
858         // See http://anglebug.com/3829.
859         //
860         TIntermTyped *baseExpression =
861             new TIntermSymbol(mInterfaceBlockMap->at(&symbol->variable()));
862         const TStructure *structure = nullptr;
863 
864         TIntermNode *primaryIndex = nullptr;
865         TIntermSequence secondaryIndices;
866 
867         // In some cases, it is necessary to prepend or append statements.  Those are captured in
868         // |prependStatements| and |appendStatements|.
869         TIntermSequence prependStatements;
870         TIntermSequence appendStatements;
871 
872         // If the expression is neither a struct nor a matrix, no modification is necessary.
873         // If it's a struct that doesn't have matrices, again there's no transformation necessary.
874         // If it's an interface block matrix field that didn't need to be transposed, no
875         // transformation is necessary.
876         //
877         // In all these cases, |baseExpression| contains all of the original expression.
878         //
879         // If the starting symbol itself is a field of a nameless interface block, it needs
880         // conversion if we reach here.
881         bool requiresTransformation = !symbol->getType().isInterfaceBlock();
882 
883         uint32_t accessorIndex         = 0;
884         TIntermTyped *previousAncestor = symbol;
        // Walk up through the index/field-selection/swizzle ancestors of |symbol|.
885         while (IsIndexNode(getAncestorNode(accessorIndex), previousAncestor))
886         {
887             TIntermTyped *ancestor = getAncestorNode(accessorIndex)->getAsTyped();
888             ASSERT(ancestor);
889 
890             const TType &previousAncestorType = previousAncestor->getType();
891 
892             TIntermSequence indices;
893             TOperator op = GetIndex(mSymbolTable, ancestor, &indices, &prependStatements);
894 
895             bool opIsIndex     = op == EOpIndexDirect || op == EOpIndexIndirect;
896             bool isArrayIndex  = opIsIndex && previousAncestorType.isArray();
897             bool isMatrixIndex = opIsIndex && previousAncestorType.isMatrix();
898 
899             // If it's a direct index in a matrix, it's the primary index.
900             bool isMatrixPrimarySubscript = isMatrixIndex && !isArrayIndex;
901             ASSERT(!isMatrixPrimarySubscript ||
902                    (primaryIndex == nullptr && secondaryIndices.empty()));
903             // If primary index is seen and the ancestor is still an index, it must be a direct
904             // index as the secondary one.  Note that if primaryIndex is set, there can only ever be
905             // one more parent of interest, and that's subscripting the second dimension.
906             bool isMatrixSecondarySubscript = primaryIndex != nullptr;
907             ASSERT(!isMatrixSecondarySubscript || (opIsIndex && !isArrayIndex));
908 
909             if (requiresTransformation && isMatrixPrimarySubscript)
910             {
911                 ASSERT(indices.size() == 1);
912                 primaryIndex = indices.front();
913 
914                 // Default the secondary indices to include every row.  If there's a secondary
915                 // subscript provided, it will override this.
916                 int rows = previousAncestorType.getRows();
917                 for (int r = 0; r < rows; ++r)
918                 {
919                     secondaryIndices.push_back(CreateIndexNode(r));
920                 }
921             }
922             else if (isMatrixSecondarySubscript)
923             {
924                 ASSERT(requiresTransformation);
925 
926                 secondaryIndices = indices;
927 
928                 // Indices after this point are not interesting.  There can't actually be any other
929                 // index nodes other than desktop GLSL's swizzles on scalars, like M[1][2].yyy.
930                 ++accessorIndex;
931                 break;
932             }
933             else
934             {
935                 // Replicate the expression otherwise.
936                 baseExpression =
937                     ReplicateIndexNode(mSymbolTable, ancestor, baseExpression, &indices);
938 
939                 const TType &ancestorType = ancestor->getType();
940                 structure                 = ancestorType.getStruct();
941 
942                 requiresTransformation =
943                     requiresTransformation ||
944                     IsConvertedField(ancestor, mInterfaceBlockFieldConvertedIn);
945 
946                 // If we reach a point where the expression is neither a matrix-containing struct
947                 // nor a matrix, there's no transformation required.  This can happen if we descend
948                 // through a struct marked with row-major but arrive at a member that doesn't
949                 // include a matrix.
950                 if (!ancestorType.isMatrix() && !ancestorType.isStructureContainingMatrices())
951                 {
952                     requiresTransformation = false;
953                 }
954             }
955 
956             previousAncestor = ancestor;
957             ++accessorIndex;
958         }
959 
        // |originalExpression| is the outermost node consumed by the walk above (or |symbol|
        // itself); |accessor| is the node that uses it.
960         TIntermNode *originalExpression =
961             accessorIndex == 0 ? symbol : getAncestorNode(accessorIndex - 1);
962         TIntermNode *accessor = getAncestorNode(accessorIndex);
963 
964         // If |accessor| is EOpArrayLength, we don't need to perform any transformations either.
965         // Note that this only applies to unsized arrays, as the RemoveArrayLengthMethod()
966         // transformation would have removed this operation otherwise.
967         TIntermUnary *accessorAsUnary = accessor->getAsUnaryNode();
968         if (requiresTransformation && accessorAsUnary && accessorAsUnary->getOp() == EOpArrayLength)
969         {
970             ASSERT(accessorAsUnary->getOperand() == originalExpression);
971             ASSERT(accessorAsUnary->getOperand()->getType().isUnsizedArray());
972 
973             requiresTransformation = false;
974 
975             // We need to replace the whole expression including the EOpArrayLength, to avoid
976             // confusing the replacement code as the original and new expressions don't have the
977             // same type (one is the transpose of the other).  This doesn't affect the .length()
978             // operation, so this replacement is ok, though it's not worth special-casing this in
979             // the node replacement algorithm.
980             //
981             // Note: the |if (!requiresTransformation)| immediately below will be entered after
982             // this.
983             originalExpression = accessor;
984             accessor           = getAncestorNode(accessorIndex + 1);
985             baseExpression     = new TIntermUnary(EOpArrayLength, baseExpression, nullptr);
986         }
987 
        // No layout conversion needed: only swap the original chain for the one rebuilt on the
        // replacement variable, and flush any statements GetIndex() produced.
988         if (!requiresTransformation)
989         {
990             ASSERT(primaryIndex == nullptr);
991             queueReplacementWithParent(accessor, originalExpression, baseExpression,
992                                        OriginalNode::IS_DROPPED);
993 
994             RewriteRowMajorMatricesTraverser *traverser = mOuterTraverser ? mOuterTraverser : this;
995             traverser->insertStatementsInParentBlock(prependStatements, appendStatements);
996             return;
997         }
998 
999         ASSERT(structure == nullptr || primaryIndex == nullptr);
1000         ASSERT(structure != nullptr || baseExpression->getType().isMatrix());
1001 
1002         // At the end, we can determine if the expression is being read from or written to (or both,
1003         // if sent as an inout parameter to a function).  For the sake of the transformation, the
1004         // left-hand side of operations like += can be treated as "written to", without necessarily
1005         // "read from".
1006         bool isRead  = false;
1007         bool isWrite = false;
1008 
1009         determineAccess(originalExpression, accessor, &isRead, &isWrite);
1010 
1011         ASSERT(isRead || isWrite);
1012 
1013         TIntermTyped *readExpression = nullptr;
1014         if (isRead)
1015         {
1016             readExpression = transformReadExpression(
1017                 baseExpression, primaryIndex, &secondaryIndices, structure, &prependStatements);
1018 
1019             // If both read from and written to (i.e. passed to inout parameter), store the
1020             // expression in a temp variable and pass that to the function.
1021             if (isWrite)
1022             {
1023                 readExpression =
1024                     CopyToTempVariable(mSymbolTable, readExpression, &prependStatements);
1025             }
1026 
1027             // Replace the original expression with the transformed one.  Read transformations
1028             // always generate a single expression that can be used in place of the original (as
1029             // opposed to write transformations that can generate multiple statements).
1030             queueReplacementWithParent(accessor, originalExpression, readExpression,
1031                                        OriginalNode::IS_DROPPED);
1032         }
1033 
        // By default, write-back statements run after the statement containing the use (e.g.
        // after the call that fills an out/inout argument); assignments switch this to
        // |prependStatements| below.
1034         TIntermSequence postTransformPrependStatements;
1035         TIntermSequence *writeStatements = &appendStatements;
1036         TOperator assignmentOperator     = EOpAssign;
1037 
1038         if (isWrite)
1039         {
1040             TIntermTyped *valueExpression = readExpression;
1041 
1042             if (!valueExpression)
1043             {
1044                 // This block is only entered without a read expression.  (With one, this was an
1045                 // inout parameter and |valueExpression| already holds the temp variable that was
1046                 // passed to the function instead.)
1047                 //
1048                 // Here, the modification is either through being passed as an out parameter
1049                 // to a function, or an assignment.  In the former case, create a temp variable to
1050                 // be passed to the function.  In the latter case, create a temp variable that holds
1051                 // the right hand side expression.
1052                 //
1053                 // In either case, use that temp value as the value to assign to |baseExpression|.
1054 
                // NOTE(review): |temp| takes the type of the left-hand side (|originalExpression|)
                // but, for assignments, is initialized from the right-hand side below.  For
                // compound assignments whose operands have different types (e.g. matrix += scalar
                // or vector *= matrix) these types differ; confirm such uses cannot reach this
                // point, or type the temp after the right-hand side and adapt
                // transformWriteExpression() accordingly.
1055                 TVariable *temp =
1056                     CreateTempVariable(mSymbolTable, &originalExpression->getAsTyped()->getType());
1057                 TIntermDeclaration *tempDecl = nullptr;
1058 
1059                 valueExpression = new TIntermSymbol(temp);
1060 
1061                 TIntermBinary *assignment = accessor->getAsBinaryNode();
1062                 if (assignment)
1063                 {
1064                     assignmentOperator = assignment->getOp();
1065                     ASSERT(IsAssignment(assignmentOperator));
1066 
1067                     // We are converting the assignment to the left-hand side of an expression in
1068                     // the form M=exp.  A subexpression of exp itself could require a
1069                     // transformation.  This complicates things as there would be two replacements:
1070                     //
1071                     // - Replace M=exp with temp (because the return value of the assignment could
1072                     //   be used)
1073                     // - Replace exp with exp2, where parent is M=exp
1074                     //
1075                     // The second replacement however is ineffective as the whole of M=exp is
1076                     // already transformed.  What's worse, M=exp is transformed without taking exp's
1077                     // transformations into account.  To address this issue, this same traverser is
1078                     // called on the right-hand side expression, with a special flag such that it
1079                     // only processes that expression.
1080                     //
1081                     RewriteRowMajorMatricesTraverser *outerTraverser =
1082                         mOuterTraverser ? mOuterTraverser : this;
1083                     RewriteRowMajorMatricesTraverser rhsTraverser(
1084                         mSymbolTable, outerTraverser, mInterfaceBlockMap,
1085                         mInterfaceBlockFieldConvertedIn, mStructMapOut, mCopyFunctionDefinitionsOut,
1086                         assignment);
1087                     getRootNode()->traverse(&rhsTraverser);
1088                     bool valid = rhsTraverser.updateTree(mCompiler, getRootNode());
1089                     ASSERT(valid);
1090 
1091                     tempDecl = CreateTempInitDeclarationNode(temp, assignment->getRight());
1092 
1093                     // Replace the whole assignment expression with the right-hand side as a read
1094                     // expression, in case the result of the assignment is used.  For example, this
1095                     // transforms:
1096                     //
1097                     //     if ((M += exp) == X)
1098                     //     {
1099                     //         // use M
1100                     //     }
1101                     //
1102                     // to:
1103                     //
1104                     //     temp = exp;
1105                     //     M += transform(temp);
1106                     //     if (transform(M) == X)
1107                     //     {
1108                     //         // use M
1109                     //     }
1110                     //
1111                     // Note that in this case the assignment to M must be prepended in the parent
1112                     // block.  In contrast, when sent to a function, the assignment to M should be
1113                     // done after the current function call is done.
1114                     //
1115                     // If the read from M itself (to replace the assignment) needs to generate extra
1116                     // statements, they should be appended after the statements that write to M.
1117                     // These statements are stored in postTransformPrependStatements and appended to
1118                     // prependStatements in the end.
1119                     //
1120                     writeStatements = &prependStatements;
1121 
1122                     TIntermTyped *assignmentResultExpression = transformReadExpression(
1123                         baseExpression->deepCopy(), primaryIndex, &secondaryIndices, structure,
1124                         &postTransformPrependStatements);
1125 
1126                     // Replace the whole assignment, instead of just the right hand side.
1127                     TIntermNode *accessorParent = getAncestorNode(accessorIndex + 1);
1128                     queueReplacementWithParent(accessorParent, accessor, assignmentResultExpression,
1129                                                OriginalNode::IS_DROPPED);
1130                 }
1131                 else
1132                 {
1133                     tempDecl = CreateTempDeclarationNode(temp);
1134 
1135                     // Replace the write expression (a function call argument) with the temp
1136                     // variable.
1137                     queueReplacementWithParent(accessor, originalExpression, valueExpression,
1138                                                OriginalNode::IS_DROPPED);
1139                 }
1140                 prependStatements.push_back(tempDecl);
1141             }
1142 
            // |baseExpression| was already handed to transformReadExpression() above when reading;
            // a fresh copy is used for the write statements (presumably so the read and write
            // trees don't share nodes).
1143             if (isRead)
1144             {
1145                 baseExpression = baseExpression->deepCopy();
1146             }
1147             transformWriteExpression(baseExpression, primaryIndex, &secondaryIndices, structure,
1148                                      valueExpression, assignmentOperator, writeStatements);
1149         }
1150 
1151         prependStatements.insert(prependStatements.end(), postTransformPrependStatements.begin(),
1152                                  postTransformPrependStatements.end());
1153 
1154         RewriteRowMajorMatricesTraverser *traverser = mOuterTraverser ? mOuterTraverser : this;
1155         traverser->insertStatementsInParentBlock(prependStatements, appendStatements);
1156     }
1157 
transformReadExpression(TIntermTyped * baseExpression,TIntermNode * primaryIndex,TIntermSequence * secondaryIndices,const TStructure * structure,TIntermSequence * prependStatements)1158     TIntermTyped *transformReadExpression(TIntermTyped *baseExpression,
1159                                           TIntermNode *primaryIndex,
1160                                           TIntermSequence *secondaryIndices,
1161                                           const TStructure *structure,
1162                                           TIntermSequence *prependStatements)
1163     {
1164         const TType &baseExpressionType = baseExpression->getType();
1165 
1166         if (structure)
1167         {
1168             ASSERT(primaryIndex == nullptr && secondaryIndices->empty());
1169             ASSERT(mStructMapOut->count(structure) != 0);
1170             ASSERT((*mStructMapOut)[structure].convertedStruct != nullptr);
1171 
1172             // Declare copy-from-converted-to-original-struct function (if not already).
1173             declareStructCopyToOriginal(structure);
1174 
1175             const TFunction *copyToOriginal = (*mStructMapOut)[structure].copyToOriginal;
1176 
1177             if (baseExpressionType.isArray())
1178             {
1179                 // If base expression is an array, transform every element.
1180                 TransformArrayHelper transformHelper(baseExpression);
1181 
1182                 TIntermTyped *element = nullptr;
1183                 while ((element = transformHelper.getNextElement(nullptr, nullptr)) != nullptr)
1184                 {
1185                     TIntermTyped *transformedElement =
1186                         CreateStructCopyCall(copyToOriginal, element);
1187                     transformHelper.accumulateForRead(mSymbolTable, transformedElement,
1188                                                       prependStatements);
1189                 }
1190                 return transformHelper.constructReadTransformExpression();
1191             }
1192             else
1193             {
1194                 // If not reading an array, the result is simply a call to this function with the
1195                 // base expression.
1196                 return CreateStructCopyCall(copyToOriginal, baseExpression);
1197             }
1198         }
1199 
1200         // If not indexed, the result is transpose(exp)
1201         if (primaryIndex == nullptr)
1202         {
1203             ASSERT(secondaryIndices->empty());
1204 
1205             if (baseExpressionType.isArray())
1206             {
1207                 // If array, transpose every element.
1208                 TransformArrayHelper transformHelper(baseExpression);
1209 
1210                 TIntermTyped *element = nullptr;
1211                 while ((element = transformHelper.getNextElement(nullptr, nullptr)) != nullptr)
1212                 {
1213                     TIntermTyped *transformedElement = CreateTransposeCall(mSymbolTable, element);
1214                     transformHelper.accumulateForRead(mSymbolTable, transformedElement,
1215                                                       prependStatements);
1216                 }
1217                 return transformHelper.constructReadTransformExpression();
1218             }
1219             else
1220             {
1221                 return CreateTransposeCall(mSymbolTable, baseExpression);
1222             }
1223         }
1224 
1225         // If indexed the result is a vector (or just one element) where the primary and secondary
1226         // indices are swapped.
1227         ASSERT(!secondaryIndices->empty());
1228 
1229         TOperator primaryIndexOp          = GetIndexOp(primaryIndex);
1230         TIntermTyped *primaryIndexAsTyped = primaryIndex->getAsTyped();
1231 
1232         TIntermSequence transposedColumn;
1233         for (TIntermNode *secondaryIndex : *secondaryIndices)
1234         {
1235             TOperator secondaryIndexOp          = GetIndexOp(secondaryIndex);
1236             TIntermTyped *secondaryIndexAsTyped = secondaryIndex->getAsTyped();
1237 
1238             TIntermBinary *colIndexed = new TIntermBinary(
1239                 secondaryIndexOp, baseExpression->deepCopy(), secondaryIndexAsTyped->deepCopy());
1240             TIntermBinary *colRowIndexed =
1241                 new TIntermBinary(primaryIndexOp, colIndexed, primaryIndexAsTyped->deepCopy());
1242 
1243             transposedColumn.push_back(colRowIndexed);
1244         }
1245 
1246         if (secondaryIndices->size() == 1)
1247         {
1248             // If only one element, return that directly.
1249             return transposedColumn.front()->getAsTyped();
1250         }
1251 
1252         // Otherwise create a constructor with the appropriate dimension.
1253         TType *vecType = new TType(baseExpressionType.getBasicType(), secondaryIndices->size());
1254         return TIntermAggregate::CreateConstructor(*vecType, &transposedColumn);
1255     }
1256 
transformWriteExpression(TIntermTyped * baseExpression,TIntermNode * primaryIndex,TIntermSequence * secondaryIndices,const TStructure * structure,TIntermTyped * valueExpression,TOperator assignmentOperator,TIntermSequence * writeStatements)1257     void transformWriteExpression(TIntermTyped *baseExpression,
1258                                   TIntermNode *primaryIndex,
1259                                   TIntermSequence *secondaryIndices,
1260                                   const TStructure *structure,
1261                                   TIntermTyped *valueExpression,
1262                                   TOperator assignmentOperator,
1263                                   TIntermSequence *writeStatements)
1264     {
1265         const TType &baseExpressionType = baseExpression->getType();
1266 
1267         if (structure)
1268         {
1269             ASSERT(primaryIndex == nullptr && secondaryIndices->empty());
1270             ASSERT(mStructMapOut->count(structure) != 0);
1271             ASSERT((*mStructMapOut)[structure].convertedStruct != nullptr);
1272 
1273             // Declare copy-to-converted-from-original-struct function (if not already).
1274             declareStructCopyFromOriginal(structure);
1275 
1276             // The result is call to this function with the value expression assigned to base
1277             // expression.
1278             const TFunction *copyFromOriginal = (*mStructMapOut)[structure].copyFromOriginal;
1279 
1280             if (baseExpressionType.isArray())
1281             {
1282                 // If array, assign every element.
1283                 TransformArrayHelper transformHelper(baseExpression);
1284 
1285                 TIntermTyped *element      = nullptr;
1286                 TIntermTyped *valueElement = nullptr;
1287                 while ((element = transformHelper.getNextElement(valueExpression, &valueElement)) !=
1288                        nullptr)
1289                 {
1290                     TIntermTyped *functionCall =
1291                         CreateStructCopyCall(copyFromOriginal, valueElement);
1292                     writeStatements->push_back(new TIntermBinary(EOpAssign, element, functionCall));
1293                 }
1294             }
1295             else
1296             {
1297                 TIntermTyped *functionCall =
1298                     CreateStructCopyCall(copyFromOriginal, valueExpression->deepCopy());
1299                 writeStatements->push_back(
1300                     new TIntermBinary(EOpAssign, baseExpression, functionCall));
1301             }
1302 
1303             return;
1304         }
1305 
1306         // If not indexed, the result is transpose(exp)
1307         if (primaryIndex == nullptr)
1308         {
1309             ASSERT(secondaryIndices->empty());
1310 
1311             if (baseExpressionType.isArray())
1312             {
1313                 // If array, assign every element.
1314                 TransformArrayHelper transformHelper(baseExpression);
1315 
1316                 TIntermTyped *element      = nullptr;
1317                 TIntermTyped *valueElement = nullptr;
1318                 while ((element = transformHelper.getNextElement(valueExpression, &valueElement)) !=
1319                        nullptr)
1320                 {
1321                     TIntermTyped *valueTransposed = CreateTransposeCall(mSymbolTable, valueElement);
1322                     writeStatements->push_back(
1323                         new TIntermBinary(EOpAssign, element, valueTransposed));
1324                 }
1325             }
1326             else
1327             {
1328                 TIntermTyped *valueTransposed =
1329                     CreateTransposeCall(mSymbolTable, valueExpression->deepCopy());
1330                 writeStatements->push_back(
1331                     new TIntermBinary(assignmentOperator, baseExpression, valueTransposed));
1332             }
1333 
1334             return;
1335         }
1336 
1337         // If indexed, create one assignment per secondary index.  If the right-hand side is a
1338         // scalar, it's used with every assignment.  If it's a vector, the assignment is
1339         // per-component.  The right-hand side cannot be a matrix as that would imply left-hand
1340         // side being a matrix too, which is covered above where |primaryIndex == nullptr|.
1341         ASSERT(!secondaryIndices->empty());
1342 
1343         bool isValueExpressionScalar = valueExpression->getType().getNominalSize() == 1;
1344         ASSERT(isValueExpressionScalar || valueExpression->getType().getNominalSize() ==
1345                                               static_cast<int>(secondaryIndices->size()));
1346 
1347         TOperator primaryIndexOp          = GetIndexOp(primaryIndex);
1348         TIntermTyped *primaryIndexAsTyped = primaryIndex->getAsTyped();
1349 
1350         for (TIntermNode *secondaryIndex : *secondaryIndices)
1351         {
1352             TOperator secondaryIndexOp          = GetIndexOp(secondaryIndex);
1353             TIntermTyped *secondaryIndexAsTyped = secondaryIndex->getAsTyped();
1354 
1355             TIntermBinary *colIndexed = new TIntermBinary(
1356                 secondaryIndexOp, baseExpression->deepCopy(), secondaryIndexAsTyped->deepCopy());
1357             TIntermBinary *colRowIndexed =
1358                 new TIntermBinary(primaryIndexOp, colIndexed, primaryIndexAsTyped->deepCopy());
1359 
1360             TIntermTyped *valueExpressionIndexed = valueExpression->deepCopy();
1361             if (!isValueExpressionScalar)
1362             {
1363                 valueExpressionIndexed = new TIntermBinary(secondaryIndexOp, valueExpressionIndexed,
1364                                                            secondaryIndexAsTyped->deepCopy());
1365             }
1366 
1367             writeStatements->push_back(
1368                 new TIntermBinary(assignmentOperator, colRowIndexed, valueExpressionIndexed));
1369         }
1370     }
1371 
getCopyStructFieldFunction(const TType * fromFieldType,const TType * toFieldType,bool isCopyToOriginal)1372     const TFunction *getCopyStructFieldFunction(const TType *fromFieldType,
1373                                                 const TType *toFieldType,
1374                                                 bool isCopyToOriginal)
1375     {
1376         ASSERT(fromFieldType->getStruct());
1377         ASSERT(toFieldType->getStruct());
1378 
1379         // If copying from or to the original struct, the "to" field struct could require
1380         // conversion to or from the "from" field struct.  |isCopyToOriginal| tells us if we
1381         // should expect to find toField or fromField in mStructMapOut, if true or false
1382         // respectively.
1383         const TFunction *fieldCopyFunction = nullptr;
1384         if (isCopyToOriginal)
1385         {
1386             const TStructure *toFieldStruct = toFieldType->getStruct();
1387 
1388             auto iter = mStructMapOut->find(toFieldStruct);
1389             if (iter != mStructMapOut->end())
1390             {
1391                 declareStructCopyToOriginal(toFieldStruct);
1392                 fieldCopyFunction = iter->second.copyToOriginal;
1393             }
1394         }
1395         else
1396         {
1397             const TStructure *fromFieldStruct = fromFieldType->getStruct();
1398 
1399             auto iter = mStructMapOut->find(fromFieldStruct);
1400             if (iter != mStructMapOut->end())
1401             {
1402                 declareStructCopyFromOriginal(fromFieldStruct);
1403                 fieldCopyFunction = iter->second.copyFromOriginal;
1404             }
1405         }
1406 
1407         return fieldCopyFunction;
1408     }
1409 
addFieldCopy(TIntermBlock * body,TIntermTyped * to,TIntermTyped * from,bool isCopyToOriginal)1410     void addFieldCopy(TIntermBlock *body,
1411                       TIntermTyped *to,
1412                       TIntermTyped *from,
1413                       bool isCopyToOriginal)
1414     {
1415         const TType &fromType = from->getType();
1416         const TType &toType   = to->getType();
1417 
1418         TIntermTyped *rhs = from;
1419 
1420         if (fromType.getStruct())
1421         {
1422             const TFunction *fieldCopyFunction =
1423                 getCopyStructFieldFunction(&fromType, &toType, isCopyToOriginal);
1424 
1425             if (fieldCopyFunction)
1426             {
1427                 rhs = CreateStructCopyCall(fieldCopyFunction, from);
1428             }
1429         }
1430         else if (fromType.isMatrix())
1431         {
1432             rhs = CreateTransposeCall(mSymbolTable, from);
1433         }
1434 
1435         body->appendStatement(new TIntermBinary(EOpAssign, to, rhs));
1436     }
1437 
declareStructCopy(const TStructure * from,const TStructure * to,bool isCopyToOriginal)1438     TFunction *declareStructCopy(const TStructure *from,
1439                                  const TStructure *to,
1440                                  bool isCopyToOriginal)
1441     {
1442         TType *fromType = new TType(from, true);
1443         TType *toType   = new TType(to, true);
1444 
1445         // Create the parameter and return value variables.
1446         TVariable *fromVar = new TVariable(mSymbolTable, ImmutableString("from"), fromType,
1447                                            SymbolType::AngleInternal);
1448         TVariable *toVar =
1449             new TVariable(mSymbolTable, ImmutableString("to"), toType, SymbolType::AngleInternal);
1450 
1451         TIntermSymbol *fromSymbol = new TIntermSymbol(fromVar);
1452         TIntermSymbol *toSymbol   = new TIntermSymbol(toVar);
1453 
1454         // Create the function body as statements are generated.
1455         TIntermBlock *body = new TIntermBlock;
1456 
1457         // Declare the result variable.
1458         TIntermDeclaration *toDecl = new TIntermDeclaration();
1459         toDecl->appendDeclarator(toSymbol);
1460         body->appendStatement(toDecl);
1461 
1462         // Iterate over fields of the struct and copy one by one, transposing the matrices.  If a
1463         // struct is encountered that requires a transformation, this function is recursively
1464         // called.  As a result, it is important that the copy functions are placed in the code in
1465         // order.
1466         const TFieldList &fromFields = from->fields();
1467         const TFieldList &toFields   = to->fields();
1468         ASSERT(fromFields.size() == toFields.size());
1469 
1470         for (size_t fieldIndex = 0; fieldIndex < fromFields.size(); ++fieldIndex)
1471         {
1472             TIntermTyped *fieldIndexNode = CreateIndexNode(static_cast<int>(fieldIndex));
1473 
1474             TIntermTyped *fromField =
1475                 new TIntermBinary(EOpIndexDirectStruct, fromSymbol->deepCopy(), fieldIndexNode);
1476             TIntermTyped *toField = new TIntermBinary(EOpIndexDirectStruct, toSymbol->deepCopy(),
1477                                                       fieldIndexNode->deepCopy());
1478 
1479             const TType *fromFieldType = fromFields[fieldIndex]->type();
1480             bool isStructOrMatrix      = fromFieldType->getStruct() || fromFieldType->isMatrix();
1481 
1482             if (fromFieldType->isArray() && isStructOrMatrix)
1483             {
1484                 // If struct or matrix array, we need to copy element by element.
1485                 TransformArrayHelper transformHelper(toField);
1486 
1487                 TIntermTyped *toElement   = nullptr;
1488                 TIntermTyped *fromElement = nullptr;
1489                 while ((toElement = transformHelper.getNextElement(fromField, &fromElement)) !=
1490                        nullptr)
1491                 {
1492                     addFieldCopy(body, toElement, fromElement, isCopyToOriginal);
1493                 }
1494             }
1495             else
1496             {
1497                 addFieldCopy(body, toField, fromField, isCopyToOriginal);
1498             }
1499         }
1500 
1501         // Add return statement.
1502         body->appendStatement(new TIntermBranch(EOpReturn, toSymbol->deepCopy()));
1503 
1504         // Declare the function
1505         TFunction *copyFunction = new TFunction(mSymbolTable, kEmptyImmutableString,
1506                                                 SymbolType::AngleInternal, toType, true);
1507         copyFunction->addParameter(fromVar);
1508 
1509         TIntermFunctionDefinition *functionDef =
1510             CreateInternalFunctionDefinitionNode(*copyFunction, body);
1511         mCopyFunctionDefinitionsOut->push_back(functionDef);
1512 
1513         return copyFunction;
1514     }
1515 
declareStructCopyFromOriginal(const TStructure * structure)1516     void declareStructCopyFromOriginal(const TStructure *structure)
1517     {
1518         StructConversionData *structData = &(*mStructMapOut)[structure];
1519         if (structData->copyFromOriginal)
1520         {
1521             return;
1522         }
1523 
1524         structData->copyFromOriginal =
1525             declareStructCopy(structure, structData->convertedStruct, false);
1526     }
1527 
declareStructCopyToOriginal(const TStructure * structure)1528     void declareStructCopyToOriginal(const TStructure *structure)
1529     {
1530         StructConversionData *structData = &(*mStructMapOut)[structure];
1531         if (structData->copyToOriginal)
1532         {
1533             return;
1534         }
1535 
1536         structData->copyToOriginal =
1537             declareStructCopy(structData->convertedStruct, structure, true);
1538     }
1539 
1540     TCompiler *mCompiler;
1541 
1542     // This traverser can call itself to transform a subexpression before moving on.  However, it
1543     // needs to accumulate conversion functions in inner passes.  The fields below marked with Out
1544     // or In are inherited from the outer pass (for inner passes), or point to storage fields in
1545     // mOuterPass (for the outer pass).  The latter should not be used by the inner passes as they
1546     // would be empty, so they are placed inside a struct to make them explicit.
1547     struct
1548     {
1549         StructMap structMap;
1550         InterfaceBlockMap interfaceBlockMap;
1551         InterfaceBlockFieldConverted interfaceBlockFieldConverted;
1552         TIntermSequence copyFunctionDefinitions;
1553     } mOuterPass;
1554 
1555     // A map from structures with matrices to their converted version.
1556     StructMap *mStructMapOut;
1557     // A map from interface block instances with row-major matrices to their converted variable.  If
1558     // an interface block is nameless, its fields are placed in this map instead.  When a variable
1559     // in this map is encountered, it signals the start of an expression that may need conversion,
1560     // which is either "interfaceBlock.field..." or "field..." if nameless.
1561     InterfaceBlockMap *mInterfaceBlockMap;
1562     // A map from interface block fields to whether they need to be converted.  If a field was
1563     // already column-major, it shouldn't be transposed.
1564     const InterfaceBlockFieldConverted &mInterfaceBlockFieldConvertedIn;
1565 
1566     TIntermSequence *mCopyFunctionDefinitionsOut;
1567 
1568     // If set, it's an inner pass and this will point to the outer pass traverser.  All statement
1569     // insertions are stored in the outer traverser and applied at once in the end.  This prevents
1570     // the inner passes from adding statements, which would invalidate the outer traverser's statement
1571     // position tracking.
1572     RewriteRowMajorMatricesTraverser *mOuterTraverser;
1573 
1574     // If set, it's an inner pass that should only process the right-hand side of this particular
1575     // node.
1576     TIntermBinary *mInnerPassRoot;
1577     bool mIsProcessingInnerPassSubtree;
1578 };
1579 
1580 }  // anonymous namespace
1581 
RewriteRowMajorMatrices(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)1582 bool RewriteRowMajorMatrices(TCompiler *compiler, TIntermBlock *root, TSymbolTable *symbolTable)
1583 {
1584     RewriteRowMajorMatricesTraverser traverser(compiler, symbolTable);
1585     root->traverse(&traverser);
1586     if (!traverser.updateTree(compiler, root))
1587     {
1588         return false;
1589     }
1590 
1591     size_t firstFunctionIndex = FindFirstFunctionDefinitionIndex(root);
1592     root->insertChildNodes(firstFunctionIndex, *traverser.getStructCopyFunctions());
1593 
1594     return compiler->validateAST(root);
1595 }
1596 }  // namespace sh
1597