1 //
2 // Copyright 2019 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RewriteRowMajorMatrices: Rewrite row-major matrices as column-major.
7 //
8
9 #include "compiler/translator/tree_ops/glsl/apple/RewriteRowMajorMatrices.h"
10
11 #include "common/span.h"
12 #include "compiler/translator/Compiler.h"
13 #include "compiler/translator/ImmutableStringBuilder.h"
14 #include "compiler/translator/StaticType.h"
15 #include "compiler/translator/SymbolTable.h"
16 #include "compiler/translator/tree_util/IntermNode_util.h"
17 #include "compiler/translator/tree_util/IntermTraverse.h"
18 #include "compiler/translator/tree_util/ReplaceVariable.h"
19
20 namespace sh
21 {
22 namespace
23 {
24 // Only structs with matrices are tracked. If layout(row_major) is applied to a struct that doesn't
25 // have matrices, it's silently dropped. This is also used to avoid creating duplicates for inner
26 // structs that don't have matrices.
27 struct StructConversionData
28 {
29 // The converted struct with every matrix transposed.
30 TStructure *convertedStruct = nullptr;
31
32 // The copy-from and copy-to functions copying from a struct to its converted version and back.
33 TFunction *copyFromOriginal = nullptr;
34 TFunction *copyToOriginal = nullptr;
35 };
36
DoesFieldContainRowMajorMatrix(const TField * field,bool isBlockRowMajor)37 bool DoesFieldContainRowMajorMatrix(const TField *field, bool isBlockRowMajor)
38 {
39 TLayoutMatrixPacking matrixPacking = field->type()->getLayoutQualifier().matrixPacking;
40
41 // The field is row major if either explicitly specified as such, or if it inherits it from the
42 // block layout qualifier.
43 if (matrixPacking == EmpColumnMajor || (matrixPacking == EmpUnspecified && !isBlockRowMajor))
44 {
45 return false;
46 }
47
48 // The field is qualified with row_major, but if it's not a matrix or a struct containing
49 // matrices, that's a useless qualifier.
50 const TType *type = field->type();
51 return type->isMatrix() || type->isStructureContainingMatrices();
52 }
53
// Creates a new TField with the same name, line and symbol type as |field|, holding its own
// freshly allocated copy of the field's type.
TField *DuplicateField(const TField *field)
{
    TType *typeCopy = new TType(*field->type());
    return new TField(typeCopy, field->name(), field->line(), field->symbolType());
}
58
SetColumnMajor(TType * type)59 void SetColumnMajor(TType *type)
60 {
61 TLayoutQualifier layoutQualifier = type->getLayoutQualifier();
62 layoutQualifier.matrixPacking = EmpColumnMajor;
63 type->setLayoutQualifier(layoutQualifier);
64 }
65
// Returns a new type equal to |type| but with its matrix dimensions swapped (primary size becomes
// the row count, secondary size becomes the column count) and its packing set to column_major.
TType *TransposeMatrixType(const TType *type)
{
    TType *transposed = new TType(*type);
    transposed->setPrimarySize(type->getRows());
    transposed->setSecondarySize(type->getCols());
    SetColumnMajor(transposed);
    return transposed;
}
77
CopyArraySizes(const TType * from,TType * to)78 void CopyArraySizes(const TType *from, TType *to)
79 {
80 if (from->isArray())
81 {
82 to->makeArrays(from->getArraySizes());
83 }
84 }
85
86 // Determine if the node is an index node (array index or struct field selection). For the purposes
87 // of this transformation, swizzle nodes are considered index nodes too.
IsIndexNode(TIntermNode * node,TIntermNode * child)88 bool IsIndexNode(TIntermNode *node, TIntermNode *child)
89 {
90 if (node->getAsSwizzleNode())
91 {
92 return true;
93 }
94
95 TIntermBinary *binaryNode = node->getAsBinaryNode();
96 if (binaryNode == nullptr || child != binaryNode->getLeft())
97 {
98 return false;
99 }
100
101 TOperator op = binaryNode->getOp();
102
103 return op == EOpIndexDirect || op == EOpIndexDirectInterfaceBlock ||
104 op == EOpIndexDirectStruct || op == EOpIndexIndirect;
105 }
106
CopyToTempVariable(TSymbolTable * symbolTable,TIntermTyped * node,TIntermSequence * prependStatements)107 TIntermSymbol *CopyToTempVariable(TSymbolTable *symbolTable,
108 TIntermTyped *node,
109 TIntermSequence *prependStatements)
110 {
111 TVariable *temp = CreateTempVariable(symbolTable, &node->getType(), EvqTemporary);
112 TIntermDeclaration *tempDecl = CreateTempInitDeclarationNode(temp, node);
113 prependStatements->push_back(tempDecl);
114
115 return new TIntermSymbol(temp);
116 }
117
CreateStructCopyCall(const TFunction * copyFunc,TIntermTyped * expression)118 TIntermAggregate *CreateStructCopyCall(const TFunction *copyFunc, TIntermTyped *expression)
119 {
120 TIntermSequence args = {expression};
121 return TIntermAggregate::CreateFunctionCall(*copyFunc, &args);
122 }
123
CreateTransposeCall(TSymbolTable * symbolTable,TIntermTyped * expression)124 TIntermTyped *CreateTransposeCall(TSymbolTable *symbolTable, TIntermTyped *expression)
125 {
126 TIntermSequence args = {expression};
127 return CreateBuiltInFunctionCallNode("transpose", &args, *symbolTable, 300);
128 }
129
GetIndex(TSymbolTable * symbolTable,TIntermNode * node,TIntermSequence * indices,TIntermSequence * prependStatements)130 TOperator GetIndex(TSymbolTable *symbolTable,
131 TIntermNode *node,
132 TIntermSequence *indices,
133 TIntermSequence *prependStatements)
134 {
135 // Swizzle nodes are converted EOpIndexDirect for simplicity, with one index per swizzle
136 // channel.
137 TIntermSwizzle *asSwizzle = node->getAsSwizzleNode();
138 if (asSwizzle)
139 {
140 for (int channel : asSwizzle->getSwizzleOffsets())
141 {
142 indices->push_back(CreateIndexNode(channel));
143 }
144 return EOpIndexDirect;
145 }
146
147 TIntermBinary *binaryNode = node->getAsBinaryNode();
148 ASSERT(binaryNode);
149
150 TOperator op = binaryNode->getOp();
151 ASSERT(op == EOpIndexDirect || op == EOpIndexDirectInterfaceBlock ||
152 op == EOpIndexDirectStruct || op == EOpIndexIndirect);
153
154 TIntermTyped *rhs = binaryNode->getRight()->deepCopy();
155 if (rhs->getAsConstantUnion() == nullptr)
156 {
157 rhs = CopyToTempVariable(symbolTable, rhs, prependStatements);
158 }
159
160 indices->push_back(rhs);
161 return op;
162 }
163
ReplicateIndexNode(TSymbolTable * symbolTable,TIntermNode * node,TIntermTyped * lhs,TIntermSequence * indices)164 TIntermTyped *ReplicateIndexNode(TSymbolTable *symbolTable,
165 TIntermNode *node,
166 TIntermTyped *lhs,
167 TIntermSequence *indices)
168 {
169 TIntermSwizzle *asSwizzle = node->getAsSwizzleNode();
170 if (asSwizzle)
171 {
172 return new TIntermSwizzle(lhs, asSwizzle->getSwizzleOffsets());
173 }
174
175 TIntermBinary *binaryNode = node->getAsBinaryNode();
176 ASSERT(binaryNode);
177
178 ASSERT(indices->size() == 1);
179 TIntermTyped *rhs = indices->front()->getAsTyped();
180
181 return new TIntermBinary(binaryNode->getOp(), lhs, rhs);
182 }
183
GetIndexOp(TIntermNode * node)184 TOperator GetIndexOp(TIntermNode *node)
185 {
186 return node->getAsConstantUnion() ? EOpIndexDirect : EOpIndexIndirect;
187 }
188
IsConvertedField(TIntermTyped * indexNode,const angle::HashMap<const TField *,bool> & convertedFields)189 bool IsConvertedField(TIntermTyped *indexNode,
190 const angle::HashMap<const TField *, bool> &convertedFields)
191 {
192 TIntermBinary *asBinary = indexNode->getAsBinaryNode();
193 if (asBinary == nullptr)
194 {
195 return false;
196 }
197
198 if (asBinary->getOp() != EOpIndexDirectInterfaceBlock)
199 {
200 return false;
201 }
202
203 const TInterfaceBlock *interfaceBlock = asBinary->getLeft()->getType().getInterfaceBlock();
204 ASSERT(interfaceBlock);
205
206 TIntermConstantUnion *fieldIndexNode = asBinary->getRight()->getAsConstantUnion();
207 ASSERT(fieldIndexNode);
208 ASSERT(fieldIndexNode->getConstantValue() != nullptr);
209
210 int fieldIndex = fieldIndexNode->getConstantValue()->getIConst();
211 const TField *field = interfaceBlock->fields()[fieldIndex];
212
213 return convertedFields.count(field) > 0 && convertedFields.at(field);
214 }
215
216 // A helper class to transform expressions of array type. Iterates over every element of the
217 // array.
218 class TransformArrayHelper
219 {
220 public:
TransformArrayHelper(TIntermTyped * baseExpression)221 TransformArrayHelper(TIntermTyped *baseExpression)
222 : mBaseExpression(baseExpression),
223 mBaseExpressionType(baseExpression->getType()),
224 mArrayIndices(mBaseExpressionType.getArraySizes().size(), 0)
225 {}
226
getNextElement(TIntermTyped * valueExpression,TIntermTyped ** valueElementOut)227 TIntermTyped *getNextElement(TIntermTyped *valueExpression, TIntermTyped **valueElementOut)
228 {
229 const angle::Span<const unsigned int> &arraySizes = mBaseExpressionType.getArraySizes();
230
231 // If the last index overflows, element enumeration is done.
232 if (mArrayIndices.back() >= arraySizes.back())
233 {
234 return nullptr;
235 }
236
237 TIntermTyped *element = getCurrentElement(mBaseExpression);
238 if (valueExpression)
239 {
240 *valueElementOut = getCurrentElement(valueExpression);
241 }
242
243 incrementIndices(arraySizes);
244 return element;
245 }
246
accumulateForRead(TSymbolTable * symbolTable,TIntermTyped * transformedElement,TIntermSequence * prependStatements)247 void accumulateForRead(TSymbolTable *symbolTable,
248 TIntermTyped *transformedElement,
249 TIntermSequence *prependStatements)
250 {
251 TIntermTyped *temp = CopyToTempVariable(symbolTable, transformedElement, prependStatements);
252 mReadTransformConstructorArgs.push_back(temp);
253 }
254
constructReadTransformExpression()255 TIntermTyped *constructReadTransformExpression()
256 {
257 const angle::Span<const unsigned int> &baseTypeArraySizes =
258 mBaseExpressionType.getArraySizes();
259 TVector<unsigned int> arraySizes(baseTypeArraySizes.begin(), baseTypeArraySizes.end());
260 TIntermTyped *firstElement = mReadTransformConstructorArgs.front()->getAsTyped();
261 const TType &baseType = firstElement->getType();
262
263 // If N dimensions, acc[0] == size[0] and acc[i] == size[i] * acc[i-1].
264 // The last value is unused, and is not present.
265 TVector<unsigned int> accumulatedArraySizes(arraySizes.size() - 1);
266
267 if (accumulatedArraySizes.size() > 0)
268 {
269 accumulatedArraySizes[0] = arraySizes[0];
270 }
271 for (size_t index = 1; index + 1 < arraySizes.size(); ++index)
272 {
273 accumulatedArraySizes[index] = accumulatedArraySizes[index - 1] * arraySizes[index];
274 }
275
276 return constructReadTransformExpressionHelper(arraySizes, accumulatedArraySizes, baseType,
277 0);
278 }
279
280 private:
getCurrentElement(TIntermTyped * expression)281 TIntermTyped *getCurrentElement(TIntermTyped *expression)
282 {
283 TIntermTyped *element = expression->deepCopy();
284 for (auto it = mArrayIndices.rbegin(); it != mArrayIndices.rend(); ++it)
285 {
286 unsigned int index = *it;
287 element = new TIntermBinary(EOpIndexDirect, element, CreateIndexNode(index));
288 }
289 return element;
290 }
291
incrementIndices(const angle::Span<const unsigned int> & arraySizes)292 void incrementIndices(const angle::Span<const unsigned int> &arraySizes)
293 {
294 // Assume mArrayIndices is an N digit number, where digit i is in the range
295 // [0, arraySizes[i]). This function increments this number. Last digit is the most
296 // significant digit.
297 for (size_t digitIndex = 0; digitIndex < arraySizes.size(); ++digitIndex)
298 {
299 ++mArrayIndices[digitIndex];
300 if (mArrayIndices[digitIndex] < arraySizes[digitIndex])
301 {
302 break;
303 }
304 if (digitIndex + 1 != arraySizes.size())
305 {
306 // This digit has now overflown and is reset to 0, carry will be added to the next
307 // digit. The most significant digit will keep the overflow though, to make it
308 // clear we have exhausted the range.
309 mArrayIndices[digitIndex] = 0;
310 }
311 }
312 }
313
constructReadTransformExpressionHelper(const TVector<unsigned int> & arraySizes,const TVector<unsigned int> & accumulatedArraySizes,const TType & baseType,size_t elementsOffset)314 TIntermTyped *constructReadTransformExpressionHelper(
315 const TVector<unsigned int> &arraySizes,
316 const TVector<unsigned int> &accumulatedArraySizes,
317 const TType &baseType,
318 size_t elementsOffset)
319 {
320 ASSERT(!arraySizes.empty());
321
322 TType *transformedType = new TType(baseType);
323 transformedType->makeArrays(arraySizes);
324
325 // If one dimensional, create the constructor with the given elements.
326 if (arraySizes.size() == 1)
327 {
328 ASSERT(accumulatedArraySizes.size() == 0);
329
330 auto sliceStart = mReadTransformConstructorArgs.begin() + elementsOffset;
331 TIntermSequence slice(sliceStart, sliceStart + arraySizes[0]);
332
333 return TIntermAggregate::CreateConstructor(*transformedType, &slice);
334 }
335
336 // If not, create constructors for every column recursively.
337 TVector<unsigned int> subArraySizes(arraySizes.begin(), arraySizes.end() - 1);
338 TVector<unsigned int> subArrayAccumulatedSizes(accumulatedArraySizes.begin(),
339 accumulatedArraySizes.end() - 1);
340
341 TIntermSequence constructorArgs;
342 unsigned int colStride = accumulatedArraySizes.back();
343 for (size_t col = 0; col < arraySizes.back(); ++col)
344 {
345 size_t colElementsOffset = elementsOffset + col * colStride;
346
347 constructorArgs.push_back(constructReadTransformExpressionHelper(
348 subArraySizes, subArrayAccumulatedSizes, baseType, colElementsOffset));
349 }
350
351 return TIntermAggregate::CreateConstructor(*transformedType, &constructorArgs);
352 }
353
354 TIntermTyped *mBaseExpression;
355 const TType &mBaseExpressionType;
356 TVector<unsigned int> mArrayIndices;
357
358 TIntermSequence mReadTransformConstructorArgs;
359 };
360
361 // Traverser that:
362 //
363 // 1. Converts |layout(row_major) matCxR M| to |layout(column_major) matRxC Mt|.
364 // 2. Converts |layout(row_major) S s| to |layout(column_major) St st|, where S is a struct that
365 // contains matrices, and St is a new struct with the transformation in 1 applied to matrix
366 // members (recursively).
367 // 3. When read from, the following transformations are applied:
368 //
369 // M -> transpose(Mt)
370 // M[c] -> gvecN(Mt[0][c], Mt[1][c], ..., Mt[N-1][c])
371 // M[c][r] -> Mt[r][c]
372 // M[c].yz -> gvec2(Mt[1][c], Mt[2][c])
373 // MArr -> MType[D1]..[DN](transpose(MtArr[0]...[0]), ...)
374 // s -> copy_St_to_S(st)
375 // sArr -> SType[D1]...[DN](copy_St_to_S(stArr[0]..[0]), ...)
376 // (matrix reads through struct are transformed similarly to M)
377 //
378 // 4. When written to, the following transformations are applied:
379 //
380 // M = exp -> Mt = transpose(exp)
381 // M[c] = exp -> temp = exp
382 // Mt[0][c] = temp[0]
383 // Mt[1][c] = temp[1]
384 // ...
385 // Mt[N-1][c] = temp[N-1]
386 // M[c][r] = exp -> Mt[r][c] = exp
387 // M[c].yz = exp -> temp = exp
388 // Mt[1][c] = temp[0]
389 // Mt[2][c] = temp[1]
//     MArr = exp     -> temp = exp
//                       MtArr = MtType[D1]..[DN](transpose(temp[0]...[0]), ...)
//     s = exp        -> st = copy_S_to_St(exp)
//     sArr = exp     -> temp = exp
//                       stArr = StType[D1]...[DN](copy_S_to_St(temp[0]..[0]), ...)
395 // (matrix writes through struct are transformed similarly to M)
396 //
397 // 5. If any of the above is passed to an `inout` parameter, both transformations are applied:
398 //
399 // f(M[c]) -> temp = gvecN(Mt[0][c], Mt[1][c], ..., Mt[N-1][c])
400 // f(temp)
401 // Mt[0][c] = temp[0]
402 // Mt[1][c] = temp[1]
403 // ...
404 // Mt[N-1][c] = temp[N-1]
405 //
406 // f(s) -> temp = copy_St_to_S(st)
407 // f(temp)
408 // st = copy_S_to_St(temp)
409 //
410 // If passed to an `out` parameter, the `temp` parameter is simply not initialized.
411 //
412 // 6. If the expression leading to the matrix or struct has array subscripts, temp values are
413 // created for them to avoid duplicating side effects.
414 //
415 class RewriteRowMajorMatricesTraverser : public TIntermTraverser
416 {
417 public:
RewriteRowMajorMatricesTraverser(TCompiler * compiler,TSymbolTable * symbolTable)418 RewriteRowMajorMatricesTraverser(TCompiler *compiler, TSymbolTable *symbolTable)
419 : TIntermTraverser(true, true, true, symbolTable),
420 mCompiler(compiler),
421 mStructMapOut(&mOuterPass.structMap),
422 mInterfaceBlockMap(&mOuterPass.interfaceBlockMap),
423 mInterfaceBlockFieldConvertedIn(mOuterPass.interfaceBlockFieldConverted),
424 mCopyFunctionDefinitionsOut(&mOuterPass.copyFunctionDefinitions),
425 mOuterTraverser(nullptr),
426 mInnerPassRoot(nullptr),
427 mIsProcessingInnerPassSubtree(false)
428 {}
429
visitDeclaration(Visit visit,TIntermDeclaration * node)430 bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
431 {
432 // No need to process declarations in inner passes.
433 if (mInnerPassRoot != nullptr)
434 {
435 return true;
436 }
437
438 if (visit != PreVisit)
439 {
440 return true;
441 }
442
443 const TIntermSequence &sequence = *(node->getSequence());
444
445 TIntermTyped *variable = sequence.front()->getAsTyped();
446 const TType &type = variable->getType();
447
448 // If it's a struct declaration that has matrices, remember it. If a row-major instance
449 // of it is created, it will have to be converted.
450 if (type.isStructSpecifier() && type.isStructureContainingMatrices())
451 {
452 const TStructure *structure = type.getStruct();
453 ASSERT(structure);
454
455 ASSERT(mOuterPass.structMap.count(structure) == 0);
456
457 StructConversionData structData;
458 mOuterPass.structMap[structure] = structData;
459
460 return false;
461 }
462
463 // If it's an interface block, it may have to be converted if it contains any row-major
464 // fields.
465 if (type.isInterfaceBlock() && type.getInterfaceBlock()->containsMatrices())
466 {
467 const TInterfaceBlock *block = type.getInterfaceBlock();
468 ASSERT(block);
469 bool isBlockRowMajor = type.getLayoutQualifier().matrixPacking == EmpRowMajor;
470
471 const TFieldList &fields = block->fields();
472 bool anyRowMajor = isBlockRowMajor;
473
474 for (const TField *field : fields)
475 {
476 if (DoesFieldContainRowMajorMatrix(field, isBlockRowMajor))
477 {
478 anyRowMajor = true;
479 break;
480 }
481 }
482
483 if (anyRowMajor)
484 {
485 convertInterfaceBlock(node);
486 }
487
488 return false;
489 }
490
491 return true;
492 }
493
visitSymbol(TIntermSymbol * symbol)494 void visitSymbol(TIntermSymbol *symbol) override
495 {
496 // If in inner pass, only process if the symbol is under that root.
497 if (mInnerPassRoot != nullptr && !mIsProcessingInnerPassSubtree)
498 {
499 return;
500 }
501
502 const TVariable *variable = &symbol->variable();
503 bool needsRewrite = mInterfaceBlockMap->count(variable) != 0;
504
505 // If it's a field of a nameless interface block, it may still need conversion.
506 if (!needsRewrite)
507 {
508 // Nameless interface block field symbols have the interface block pointer set, but are
509 // not interface blocks.
510 if (symbol->getType().getInterfaceBlock() && !variable->getType().isInterfaceBlock())
511 {
512 needsRewrite = convertNamelessInterfaceBlockField(symbol);
513 }
514 }
515
516 if (needsRewrite)
517 {
518 transformExpression(symbol);
519 }
520 }
521
visitBinary(Visit visit,TIntermBinary * node)522 bool visitBinary(Visit visit, TIntermBinary *node) override
523 {
524 if (node == mInnerPassRoot)
525 {
526 // We only want to process the right-hand side of an assignment in inner passes. When
527 // visit is InVisit, the left-hand side is already processed, and the right-hand side is
528 // next. Set a flag to mark this duration.
529 mIsProcessingInnerPassSubtree = visit == InVisit;
530 }
531
532 return true;
533 }
534
getStructCopyFunctions()535 TIntermSequence *getStructCopyFunctions() { return &mOuterPass.copyFunctionDefinitions; }
536
537 private:
538 typedef angle::HashMap<const TStructure *, StructConversionData> StructMap;
539 typedef angle::HashMap<const TVariable *, TVariable *> InterfaceBlockMap;
540 typedef angle::HashMap<const TField *, bool> InterfaceBlockFieldConverted;
541
// Inner-pass constructor: shares the conversion state owned by the outer pass and restricts
// processing to the right-hand side of |innerPassRoot|.
//
// mCompiler is inherited from |outerTraverser| so it is never left uninitialized when an inner
// traverser is used. It is listed first, matching its position in the public constructor's
// initializer list.
RewriteRowMajorMatricesTraverser(
    TSymbolTable *symbolTable,
    RewriteRowMajorMatricesTraverser *outerTraverser,
    InterfaceBlockMap *interfaceBlockMap,
    const InterfaceBlockFieldConverted &interfaceBlockFieldConverted,
    StructMap *structMap,
    TIntermSequence *copyFunctionDefinitions,
    TIntermBinary *innerPassRoot)
    : TIntermTraverser(true, true, true, symbolTable),
      mCompiler(outerTraverser != nullptr ? outerTraverser->mCompiler : nullptr),
      mStructMapOut(structMap),
      mInterfaceBlockMap(interfaceBlockMap),
      mInterfaceBlockFieldConvertedIn(interfaceBlockFieldConverted),
      mCopyFunctionDefinitionsOut(copyFunctionDefinitions),
      mOuterTraverser(outerTraverser),
      mInnerPassRoot(innerPassRoot),
      mIsProcessingInnerPassSubtree(false)
{}
559
convertInterfaceBlock(TIntermDeclaration * node)560 void convertInterfaceBlock(TIntermDeclaration *node)
561 {
562 ASSERT(mInnerPassRoot == nullptr);
563
564 const TIntermSequence &sequence = *(node->getSequence());
565
566 TIntermTyped *variableNode = sequence.front()->getAsTyped();
567 const TType &type = variableNode->getType();
568 const TInterfaceBlock *block = type.getInterfaceBlock();
569 ASSERT(block);
570
571 bool isBlockRowMajor = type.getLayoutQualifier().matrixPacking == EmpRowMajor;
572
573 // Recreate the struct with its row-major fields converted to column-major equivalents.
574 TIntermSequence newDeclarations;
575
576 TFieldList *newFields = new TFieldList;
577 for (const TField *field : block->fields())
578 {
579 TField *newField = nullptr;
580
581 if (DoesFieldContainRowMajorMatrix(field, isBlockRowMajor))
582 {
583 newField = convertField(field, &newDeclarations);
584
585 // Remember that this field was converted.
586 mOuterPass.interfaceBlockFieldConverted[field] = true;
587 }
588 else
589 {
590 newField = DuplicateField(field);
591 }
592
593 newFields->push_back(newField);
594 }
595
596 // Create a new interface block with these fields.
597 TLayoutQualifier blockLayoutQualifier = type.getLayoutQualifier();
598 blockLayoutQualifier.matrixPacking = EmpColumnMajor;
599
600 TInterfaceBlock *newInterfaceBlock =
601 new TInterfaceBlock(mSymbolTable, block->name(), newFields, blockLayoutQualifier,
602 block->symbolType(), block->extensions());
603
604 // Create a new declaration with the new type. Declarations are separated at this point,
605 // so there should be only one variable here.
606 ASSERT(sequence.size() == 1);
607
608 TType *newInterfaceBlockType =
609 new TType(newInterfaceBlock, type.getQualifier(), blockLayoutQualifier);
610
611 TIntermDeclaration *newDeclaration = new TIntermDeclaration;
612 const TVariable *variable = &variableNode->getAsSymbolNode()->variable();
613
614 const TType *newType = newInterfaceBlockType;
615 if (type.isArray())
616 {
617 TType *newArrayType = new TType(*newType);
618 CopyArraySizes(&type, newArrayType);
619 newType = newArrayType;
620 }
621
622 // If the interface block variable itself is temp, use an empty name.
623 bool variableIsTemp = variable->symbolType() == SymbolType::Empty;
624 const ImmutableString &variableName =
625 variableIsTemp ? kEmptyImmutableString : variable->name();
626
627 TVariable *newVariable = new TVariable(mSymbolTable, variableName, newType,
628 variable->symbolType(), variable->extensions());
629
630 newDeclaration->appendDeclarator(new TIntermSymbol(newVariable));
631
632 mOuterPass.interfaceBlockMap[variable] = newVariable;
633
634 newDeclarations.push_back(newDeclaration);
635
636 // Replace the interface block definition with the new one, prepending any new struct
637 // definitions.
638 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
639 std::move(newDeclarations));
640 }
641
convertNamelessInterfaceBlockField(TIntermSymbol * symbol)642 bool convertNamelessInterfaceBlockField(TIntermSymbol *symbol)
643 {
644 const TVariable *variable = &symbol->variable();
645 const TInterfaceBlock *interfaceBlock = symbol->getType().getInterfaceBlock();
646
647 // Find the variable corresponding to this interface block. If the interface block
648 // is not rewritten, or this refers to a field that is not rewritten, there's
649 // nothing to do.
650 for (auto iter : *mInterfaceBlockMap)
651 {
652 // Skip other rewritten nameless interface block fields.
653 if (!iter.first->getType().isInterfaceBlock())
654 {
655 continue;
656 }
657
658 // Skip if this is not a field of this rewritten interface block.
659 if (iter.first->getType().getInterfaceBlock() != interfaceBlock)
660 {
661 continue;
662 }
663
664 const ImmutableString symbolName = symbol->getName();
665
666 // Find which field it is
667 const TVector<TField *> fields = interfaceBlock->fields();
668 const size_t fieldIndex = variable->getType().getInterfaceBlockFieldIndex();
669 ASSERT(fieldIndex < fields.size());
670
671 const TField *field = fields[fieldIndex];
672 ASSERT(field->name() == symbolName);
673
674 // If this field doesn't need a rewrite, there's nothing to do.
675 if (mInterfaceBlockFieldConvertedIn.count(field) == 0 ||
676 !mInterfaceBlockFieldConvertedIn.at(field))
677 {
678 break;
679 }
680
681 // Create a new variable that references the replaced interface block.
682 TType *newType = new TType(variable->getType());
683 newType->setInterfaceBlockField(iter.second->getType().getInterfaceBlock(), fieldIndex);
684
685 TVariable *newVariable = new TVariable(mSymbolTable, variable->name(), newType,
686 variable->symbolType(), variable->extensions());
687
688 (*mInterfaceBlockMap)[variable] = newVariable;
689
690 return true;
691 }
692
693 return false;
694 }
695
convertStruct(const TStructure * structure,TIntermSequence * newDeclarations)696 void convertStruct(const TStructure *structure, TIntermSequence *newDeclarations)
697 {
698 ASSERT(mInnerPassRoot == nullptr);
699
700 ASSERT(mOuterPass.structMap.count(structure) != 0);
701 StructConversionData *structData = &mOuterPass.structMap[structure];
702
703 if (structData->convertedStruct)
704 {
705 return;
706 }
707
708 TFieldList *newFields = new TFieldList;
709 for (const TField *field : structure->fields())
710 {
711 newFields->push_back(convertField(field, newDeclarations));
712 }
713
714 // Create unique names for the converted structs. We can't leave them nameless and have
715 // a name autogenerated similar to temp variables, as nameless structs exist. A fake
716 // variable is created for the sole purpose of generating a temp name.
717 TVariable *newStructTypeName =
718 new TVariable(mSymbolTable, kEmptyImmutableString,
719 StaticType::GetBasic<EbtUInt, EbpUndefined>(), SymbolType::Empty);
720
721 TStructure *newStruct = new TStructure(mSymbolTable, newStructTypeName->name(), newFields,
722 SymbolType::AngleInternal);
723 TType *newType = new TType(newStruct, true);
724 TVariable *newStructVar =
725 new TVariable(mSymbolTable, kEmptyImmutableString, newType, SymbolType::Empty);
726
727 TIntermDeclaration *structDecl = new TIntermDeclaration;
728 structDecl->appendDeclarator(new TIntermSymbol(newStructVar));
729
730 newDeclarations->push_back(structDecl);
731
732 structData->convertedStruct = newStruct;
733 }
734
// Returns the column-major replacement for |field|:
// - a struct containing matrices becomes a column-major instance of its converted struct (built on
//   demand, with declarations appended to |newDeclarations|), keeping the field's array sizes;
// - a matrix becomes its transposed, column-major counterpart;
// - anything else is duplicated unchanged.
TField *convertField(const TField *field, TIntermSequence *newDeclarations)
{
    ASSERT(mInnerPassRoot == nullptr);

    const TType *originalType = field->type();
    TType *convertedType      = nullptr;

    if (originalType->isStructureContainingMatrices())
    {
        const TStructure *originalStruct = originalType->getStruct();
        convertStruct(originalStruct, newDeclarations);

        convertedType = new TType(mOuterPass.structMap[originalStruct].convertedStruct, false);
        SetColumnMajor(convertedType);
        CopyArraySizes(originalType, convertedType);
    }
    else if (originalType->isMatrix())
    {
        convertedType = TransposeMatrixType(originalType);
    }

    if (convertedType == nullptr)
    {
        return DuplicateField(field);
    }
    return new TField(convertedType, field->name(), field->line(), field->symbolType());
}
774
determineAccess(TIntermNode * expression,TIntermNode * accessor,bool * isReadOut,bool * isWriteOut)775 void determineAccess(TIntermNode *expression,
776 TIntermNode *accessor,
777 bool *isReadOut,
778 bool *isWriteOut)
779 {
780 // If passing to a function, look at whether the parameter is in, out or inout.
781 TIntermAggregate *functionCall = accessor->getAsAggregate();
782
783 if (functionCall)
784 {
785 TIntermSequence *arguments = functionCall->getSequence();
786 for (size_t argIndex = 0; argIndex < arguments->size(); ++argIndex)
787 {
788 if ((*arguments)[argIndex] == expression)
789 {
790 TQualifier qualifier = EvqParamIn;
791
792 // If the aggregate is not a function call, it's a constructor, and so every
793 // argument is an input.
794 const TFunction *function = functionCall->getFunction();
795 if (function)
796 {
797 const TVariable *param = function->getParam(argIndex);
798 qualifier = param->getType().getQualifier();
799 }
800
801 *isReadOut = qualifier != EvqParamOut;
802 *isWriteOut = qualifier == EvqParamOut || qualifier == EvqParamInOut;
803 break;
804 }
805 }
806 return;
807 }
808
809 TIntermBinary *assignment = accessor->getAsBinaryNode();
810 if (assignment && IsAssignment(assignment->getOp()))
811 {
812 // If expression is on the right of assignment, it's being read from.
813 *isReadOut = assignment->getRight() == expression;
814 // If it's on the left of assignment, it's being written to.
815 *isWriteOut = assignment->getLeft() == expression;
816 return;
817 }
818
819 // Any other usage is a read.
820 *isReadOut = true;
821 *isWriteOut = false;
822 }
823
transformExpression(TIntermSymbol * symbol)824 void transformExpression(TIntermSymbol *symbol)
825 {
826 // Walk up the parent chain while the nodes are EOpIndex* (whether array indexing or struct
827 // field selection) or swizzle and construct the replacement expression. This traversal can
828 // lead to one of the following possibilities:
829 //
830 // - a.b[N].etc.s (struct, or struct array): copy function should be declared and used,
831 // - a.b[N].etc.M (matrix or matrix array): transpose() should be used,
832 // - a.b[N].etc.M[c] (a column): each element in column needs to be handled separately,
833 // - a.b[N].etc.M[c].yz (multiple elements): similar to whole column, but a subset of
834 // elements,
835 // - a.b[N].etc.M[c][r] (an element): single element to handle.
836 // - a.b[N].etc.x (not struct or matrix): not modified
837 //
838 // primaryIndex will contain c, if any. secondaryIndices will contain {0, ..., R-1}
839 // (if no [r] or swizzle), {r} (if [r]), or {1, 2} (corresponding to .yz) if any.
840 //
841 // In all cases, the base symbol is replaced. |baseExpression| will contain everything up
842 // to (and not including) the last index/swizzle operations, i.e. a.b[N].etc.s/M/x. Any
843 // non constant array subscript is assigned to a temp variable to avoid duplicating side
844 // effects.
845 //
846 // ---
847 //
848 // NOTE that due to the use of insertStatementsInParentBlock, cases like this will be
849 // mistranslated, and this bug is likely present in most transformations that use this
850 // feature:
851 //
852 // if (x == 1 && a.b[x = 2].etc.M = value)
853 //
854 // which will translate to:
855 //
856 // temp = (x = 2)
857 // if (x == 1 && a.b[temp].etc.M = transpose(value))
858 //
859 // See http://anglebug.com/42262472.
860 //
861 TIntermTyped *baseExpression =
862 new TIntermSymbol(mInterfaceBlockMap->at(&symbol->variable()));
863 const TStructure *structure = nullptr;
864
865 TIntermNode *primaryIndex = nullptr;
866 TIntermSequence secondaryIndices;
867
868 // In some cases, it is necessary to prepend or append statements. Those are captured in
869 // |prependStatements| and |appendStatements|.
870 TIntermSequence prependStatements;
871 TIntermSequence appendStatements;
872
873 // If the expression is neither a struct or matrix, no modification is necessary.
874 // If it's a struct that doesn't have matrices, again there's no transformation necessary.
875 // If it's an interface block matrix field that didn't need to be transposed, no
876 // transpformation is necessary.
877 //
878 // In all these cases, |baseExpression| contains all of the original expression.
879 //
880 // If the starting symbol itself is a field of a nameless interface block, it needs
881 // conversion if we reach here.
882 bool requiresTransformation = !symbol->getType().isInterfaceBlock();
883
884 uint32_t accessorIndex = 0;
885 TIntermTyped *previousAncestor = symbol;
886 while (IsIndexNode(getAncestorNode(accessorIndex), previousAncestor))
887 {
888 TIntermTyped *ancestor = getAncestorNode(accessorIndex)->getAsTyped();
889 ASSERT(ancestor);
890
891 const TType &previousAncestorType = previousAncestor->getType();
892
893 TIntermSequence indices;
894 TOperator op = GetIndex(mSymbolTable, ancestor, &indices, &prependStatements);
895
896 bool opIsIndex = op == EOpIndexDirect || op == EOpIndexIndirect;
897 bool isArrayIndex = opIsIndex && previousAncestorType.isArray();
898 bool isMatrixIndex = opIsIndex && previousAncestorType.isMatrix();
899
900 // If it's a direct index in a matrix, it's the primary index.
901 bool isMatrixPrimarySubscript = isMatrixIndex && !isArrayIndex;
902 ASSERT(!isMatrixPrimarySubscript ||
903 (primaryIndex == nullptr && secondaryIndices.empty()));
904 // If primary index is seen and the ancestor is still an index, it must be a direct
905 // index as the secondary one. Note that if primaryIndex is set, there can only ever be
906 // one more parent of interest, and that's subscripting the second dimension.
907 bool isMatrixSecondarySubscript = primaryIndex != nullptr;
908 ASSERT(!isMatrixSecondarySubscript || (opIsIndex && !isArrayIndex));
909
910 if (requiresTransformation && isMatrixPrimarySubscript)
911 {
912 ASSERT(indices.size() == 1);
913 primaryIndex = indices.front();
914
915 // Default the secondary indices to include every row. If there's a secondary
916 // subscript provided, it will override this.
917 const uint8_t rows = previousAncestorType.getRows();
918 for (uint8_t r = 0; r < rows; ++r)
919 {
920 secondaryIndices.push_back(CreateIndexNode(r));
921 }
922 }
923 else if (isMatrixSecondarySubscript)
924 {
925 ASSERT(requiresTransformation);
926
927 secondaryIndices = indices;
928
929 // Indices after this point are not interesting. There can't actually be any other
930 // index nodes. Unlike desktop GLSL, ESSL does not support swizzles on scalars
931 // (like M[1][2].yyy).
932 ++accessorIndex;
933 break;
934 }
935 else
936 {
937 // Replicate the expression otherwise.
938 baseExpression =
939 ReplicateIndexNode(mSymbolTable, ancestor, baseExpression, &indices);
940
941 const TType &ancestorType = ancestor->getType();
942 structure = ancestorType.getStruct();
943
944 requiresTransformation =
945 requiresTransformation ||
946 IsConvertedField(ancestor, mInterfaceBlockFieldConvertedIn);
947
948 // If we reach a point where the expression is neither a matrix-containing struct
949 // nor a matrix, there's no transformation required. This can happen if we decend
950 // through a struct marked with row-major but arrive at a member that doesn't
951 // include a matrix.
952 if (!ancestorType.isMatrix() && !ancestorType.isStructureContainingMatrices())
953 {
954 requiresTransformation = false;
955 }
956 }
957
958 previousAncestor = ancestor;
959 ++accessorIndex;
960 }
961
962 TIntermNode *originalExpression =
963 accessorIndex == 0 ? symbol : getAncestorNode(accessorIndex - 1);
964 TIntermNode *accessor = getAncestorNode(accessorIndex);
965
966 // if accessor is EOpArrayLength, we don't need to perform any transformations either.
967 // Note that this only applies to unsized arrays, as the RemoveArrayLengthMethod()
968 // transformation would have removed this operation otherwise.
969 TIntermUnary *accessorAsUnary = accessor->getAsUnaryNode();
970 if (requiresTransformation && accessorAsUnary && accessorAsUnary->getOp() == EOpArrayLength)
971 {
972 ASSERT(accessorAsUnary->getOperand() == originalExpression);
973 ASSERT(accessorAsUnary->getOperand()->getType().isUnsizedArray());
974
975 requiresTransformation = false;
976
977 // We need to replace the whole expression including the EOpArrayLength, to avoid
978 // confusing the replacement code as the original and new expressions don't have the
979 // same type (one is the transpose of the other). This doesn't affect the .length()
980 // operation, so this replacement is ok, though it's not worth special-casing this in
981 // the node replacement algorithm.
982 //
983 // Note: the |if (!requiresTransformation)| immediately below will be entered after
984 // this.
985 originalExpression = accessor;
986 accessor = getAncestorNode(accessorIndex + 1);
987 baseExpression = new TIntermUnary(EOpArrayLength, baseExpression, nullptr);
988 }
989
990 if (!requiresTransformation)
991 {
992 ASSERT(primaryIndex == nullptr);
993 queueReplacementWithParent(accessor, originalExpression, baseExpression,
994 OriginalNode::IS_DROPPED);
995
996 RewriteRowMajorMatricesTraverser *traverser = mOuterTraverser ? mOuterTraverser : this;
997 traverser->insertStatementsInParentBlock(prependStatements, appendStatements);
998 return;
999 }
1000
1001 ASSERT(structure == nullptr || primaryIndex == nullptr);
1002 ASSERT(structure != nullptr || baseExpression->getType().isMatrix());
1003
1004 // At the end, we can determine if the expression is being read from or written to (or both,
1005 // if sent as an inout parameter to a function). For the sake of the transformation, the
1006 // left-hand side of operations like += can be treated as "written to", without necessarily
1007 // "read from".
1008 bool isRead = false;
1009 bool isWrite = false;
1010
1011 determineAccess(originalExpression, accessor, &isRead, &isWrite);
1012
1013 ASSERT(isRead || isWrite);
1014
1015 TIntermTyped *readExpression = nullptr;
1016 if (isRead)
1017 {
1018 readExpression = transformReadExpression(
1019 baseExpression, primaryIndex, &secondaryIndices, structure, &prependStatements);
1020
1021 // If both read from and written to (i.e. passed to inout parameter), store the
1022 // expression in a temp variable and pass that to the function.
1023 if (isWrite)
1024 {
1025 readExpression =
1026 CopyToTempVariable(mSymbolTable, readExpression, &prependStatements);
1027 }
1028
1029 // Replace the original expression with the transformed one. Read transformations
1030 // always generate a single expression that can be used in place of the original (as
1031 // oppposed to write transformations that can generate multiple statements).
1032 queueReplacementWithParent(accessor, originalExpression, readExpression,
1033 OriginalNode::IS_DROPPED);
1034 }
1035
1036 TIntermSequence postTransformPrependStatements;
1037 TIntermSequence *writeStatements = &appendStatements;
1038 TOperator assignmentOperator = EOpAssign;
1039
1040 if (isWrite)
1041 {
1042 TIntermTyped *valueExpression = readExpression;
1043
1044 if (!valueExpression)
1045 {
1046 // If there's already a read expression, this was an inout parameter and
1047 // |valueExpression| will contain the temp variable that was passed to the function
1048 // instead.
1049 //
1050 // If not, then the modification is either through being passed as an out parameter
1051 // to a function, or an assignment. In the former case, create a temp variable to
1052 // be passed to the function. In the latter case, create a temp variable that holds
1053 // the right hand side expression.
1054 //
1055 // In either case, use that temp value as the value to assign to |baseExpression|.
1056
1057 TVariable *temp = CreateTempVariable(
1058 mSymbolTable, &originalExpression->getAsTyped()->getType(), EvqTemporary);
1059 TIntermDeclaration *tempDecl = nullptr;
1060
1061 valueExpression = new TIntermSymbol(temp);
1062
1063 TIntermBinary *assignment = accessor->getAsBinaryNode();
1064 if (assignment)
1065 {
1066 assignmentOperator = assignment->getOp();
1067 ASSERT(IsAssignment(assignmentOperator));
1068
1069 // We are converting the assignment to the left-hand side of an expression in
1070 // the form M=exp. A subexpression of exp itself could require a
1071 // transformation. This complicates things as there would be two replacements:
1072 //
1073 // - Replace M=exp with temp (because the return value of the assignment could
1074 // be used)
1075 // - Replace exp with exp2, where parent is M=exp
1076 //
1077 // The second replacement however is ineffective as the whole of M=exp is
1078 // already transformed. What's worse, M=exp is transformed without taking exp's
1079 // transformations into account. To address this issue, this same traverser is
1080 // called on the right-hand side expression, with a special flag such that it
1081 // only processes that expression.
1082 //
1083 RewriteRowMajorMatricesTraverser *outerTraverser =
1084 mOuterTraverser ? mOuterTraverser : this;
1085 RewriteRowMajorMatricesTraverser rhsTraverser(
1086 mSymbolTable, outerTraverser, mInterfaceBlockMap,
1087 mInterfaceBlockFieldConvertedIn, mStructMapOut, mCopyFunctionDefinitionsOut,
1088 assignment);
1089 getRootNode()->traverse(&rhsTraverser);
1090 bool valid = rhsTraverser.updateTree(mCompiler, getRootNode());
1091 ASSERT(valid);
1092
1093 tempDecl = CreateTempInitDeclarationNode(temp, assignment->getRight());
1094
1095 // Replace the whole assignment expression with the right-hand side as a read
1096 // expression, in case the result of the assignment is used. For example, this
1097 // transforms:
1098 //
1099 // if ((M += exp) == X)
1100 // {
1101 // // use M
1102 // }
1103 //
1104 // to:
1105 //
1106 // temp = exp;
1107 // M += transform(temp);
1108 // if (transform(M) == X)
1109 // {
1110 // // use M
1111 // }
1112 //
1113 // Note that in this case the assignment to M must be prepended in the parent
1114 // block. In contrast, when sent to a function, the assignment to M should be
1115 // done after the current function call is done.
1116 //
1117 // If the read from M itself (to replace assigmnet) needs to generate extra
1118 // statements, they should be appended after the statements that write to M.
1119 // These statements are stored in postTransformPrependStatements and appended to
1120 // prependStatements in the end.
1121 //
1122 writeStatements = &prependStatements;
1123
1124 TIntermTyped *assignmentResultExpression = transformReadExpression(
1125 baseExpression->deepCopy(), primaryIndex, &secondaryIndices, structure,
1126 &postTransformPrependStatements);
1127
1128 // Replace the whole assignment, instead of just the right hand side.
1129 TIntermNode *accessorParent = getAncestorNode(accessorIndex + 1);
1130 queueReplacementWithParent(accessorParent, accessor, assignmentResultExpression,
1131 OriginalNode::IS_DROPPED);
1132 }
1133 else
1134 {
1135 tempDecl = CreateTempDeclarationNode(temp);
1136
1137 // Replace the write expression (a function call argument) with the temp
1138 // variable.
1139 queueReplacementWithParent(accessor, originalExpression, valueExpression,
1140 OriginalNode::IS_DROPPED);
1141 }
1142 prependStatements.push_back(tempDecl);
1143 }
1144
1145 if (isRead)
1146 {
1147 baseExpression = baseExpression->deepCopy();
1148 }
1149 transformWriteExpression(baseExpression, primaryIndex, &secondaryIndices, structure,
1150 valueExpression, assignmentOperator, writeStatements);
1151 }
1152
1153 prependStatements.insert(prependStatements.end(), postTransformPrependStatements.begin(),
1154 postTransformPrependStatements.end());
1155
1156 RewriteRowMajorMatricesTraverser *traverser = mOuterTraverser ? mOuterTraverser : this;
1157 traverser->insertStatementsInParentBlock(prependStatements, appendStatements);
1158 }
1159
transformReadExpression(TIntermTyped * baseExpression,TIntermNode * primaryIndex,TIntermSequence * secondaryIndices,const TStructure * structure,TIntermSequence * prependStatements)1160 TIntermTyped *transformReadExpression(TIntermTyped *baseExpression,
1161 TIntermNode *primaryIndex,
1162 TIntermSequence *secondaryIndices,
1163 const TStructure *structure,
1164 TIntermSequence *prependStatements)
1165 {
1166 const TType &baseExpressionType = baseExpression->getType();
1167
1168 if (structure)
1169 {
1170 ASSERT(primaryIndex == nullptr && secondaryIndices->empty());
1171 ASSERT(mStructMapOut->count(structure) != 0);
1172 ASSERT((*mStructMapOut)[structure].convertedStruct != nullptr);
1173
1174 // Declare copy-from-converted-to-original-struct function (if not already).
1175 declareStructCopyToOriginal(structure);
1176
1177 const TFunction *copyToOriginal = (*mStructMapOut)[structure].copyToOriginal;
1178
1179 if (baseExpressionType.isArray())
1180 {
1181 // If base expression is an array, transform every element.
1182 TransformArrayHelper transformHelper(baseExpression);
1183
1184 TIntermTyped *element = nullptr;
1185 while ((element = transformHelper.getNextElement(nullptr, nullptr)) != nullptr)
1186 {
1187 TIntermTyped *transformedElement =
1188 CreateStructCopyCall(copyToOriginal, element);
1189 transformHelper.accumulateForRead(mSymbolTable, transformedElement,
1190 prependStatements);
1191 }
1192 return transformHelper.constructReadTransformExpression();
1193 }
1194 else
1195 {
1196 // If not reading an array, the result is simply a call to this function with the
1197 // base expression.
1198 return CreateStructCopyCall(copyToOriginal, baseExpression);
1199 }
1200 }
1201
1202 // If not indexed, the result is transpose(exp)
1203 if (primaryIndex == nullptr)
1204 {
1205 ASSERT(secondaryIndices->empty());
1206
1207 if (baseExpressionType.isArray())
1208 {
1209 // If array, transpose every element.
1210 TransformArrayHelper transformHelper(baseExpression);
1211
1212 TIntermTyped *element = nullptr;
1213 while ((element = transformHelper.getNextElement(nullptr, nullptr)) != nullptr)
1214 {
1215 TIntermTyped *transformedElement = CreateTransposeCall(mSymbolTable, element);
1216 transformHelper.accumulateForRead(mSymbolTable, transformedElement,
1217 prependStatements);
1218 }
1219 return transformHelper.constructReadTransformExpression();
1220 }
1221 else
1222 {
1223 return CreateTransposeCall(mSymbolTable, baseExpression);
1224 }
1225 }
1226
1227 // If indexed the result is a vector (or just one element) where the primary and secondary
1228 // indices are swapped.
1229 ASSERT(!secondaryIndices->empty());
1230
1231 TOperator primaryIndexOp = GetIndexOp(primaryIndex);
1232 TIntermTyped *primaryIndexAsTyped = primaryIndex->getAsTyped();
1233
1234 TIntermSequence transposedColumn;
1235 for (TIntermNode *secondaryIndex : *secondaryIndices)
1236 {
1237 TOperator secondaryIndexOp = GetIndexOp(secondaryIndex);
1238 TIntermTyped *secondaryIndexAsTyped = secondaryIndex->getAsTyped();
1239
1240 TIntermBinary *colIndexed = new TIntermBinary(
1241 secondaryIndexOp, baseExpression->deepCopy(), secondaryIndexAsTyped->deepCopy());
1242 TIntermBinary *colRowIndexed =
1243 new TIntermBinary(primaryIndexOp, colIndexed, primaryIndexAsTyped->deepCopy());
1244
1245 transposedColumn.push_back(colRowIndexed);
1246 }
1247
1248 if (secondaryIndices->size() == 1)
1249 {
1250 // If only one element, return that directly.
1251 return transposedColumn.front()->getAsTyped();
1252 }
1253
1254 // Otherwise create a constructor with the appropriate dimension.
1255 TType *vecType = new TType(baseExpressionType.getBasicType(), secondaryIndices->size());
1256 return TIntermAggregate::CreateConstructor(*vecType, &transposedColumn);
1257 }
1258
transformWriteExpression(TIntermTyped * baseExpression,TIntermNode * primaryIndex,TIntermSequence * secondaryIndices,const TStructure * structure,TIntermTyped * valueExpression,TOperator assignmentOperator,TIntermSequence * writeStatements)1259 void transformWriteExpression(TIntermTyped *baseExpression,
1260 TIntermNode *primaryIndex,
1261 TIntermSequence *secondaryIndices,
1262 const TStructure *structure,
1263 TIntermTyped *valueExpression,
1264 TOperator assignmentOperator,
1265 TIntermSequence *writeStatements)
1266 {
1267 const TType &baseExpressionType = baseExpression->getType();
1268
1269 if (structure)
1270 {
1271 ASSERT(primaryIndex == nullptr && secondaryIndices->empty());
1272 ASSERT(mStructMapOut->count(structure) != 0);
1273 ASSERT((*mStructMapOut)[structure].convertedStruct != nullptr);
1274
1275 // Declare copy-to-converted-from-original-struct function (if not already).
1276 declareStructCopyFromOriginal(structure);
1277
1278 // The result is call to this function with the value expression assigned to base
1279 // expression.
1280 const TFunction *copyFromOriginal = (*mStructMapOut)[structure].copyFromOriginal;
1281
1282 if (baseExpressionType.isArray())
1283 {
1284 // If array, assign every element.
1285 TransformArrayHelper transformHelper(baseExpression);
1286
1287 TIntermTyped *element = nullptr;
1288 TIntermTyped *valueElement = nullptr;
1289 while ((element = transformHelper.getNextElement(valueExpression, &valueElement)) !=
1290 nullptr)
1291 {
1292 TIntermTyped *functionCall =
1293 CreateStructCopyCall(copyFromOriginal, valueElement);
1294 writeStatements->push_back(new TIntermBinary(EOpAssign, element, functionCall));
1295 }
1296 }
1297 else
1298 {
1299 TIntermTyped *functionCall =
1300 CreateStructCopyCall(copyFromOriginal, valueExpression->deepCopy());
1301 writeStatements->push_back(
1302 new TIntermBinary(EOpAssign, baseExpression, functionCall));
1303 }
1304
1305 return;
1306 }
1307
1308 // If not indexed, the result is transpose(exp)
1309 if (primaryIndex == nullptr)
1310 {
1311 ASSERT(secondaryIndices->empty());
1312
1313 if (baseExpressionType.isArray())
1314 {
1315 // If array, assign every element.
1316 TransformArrayHelper transformHelper(baseExpression);
1317
1318 TIntermTyped *element = nullptr;
1319 TIntermTyped *valueElement = nullptr;
1320 while ((element = transformHelper.getNextElement(valueExpression, &valueElement)) !=
1321 nullptr)
1322 {
1323 TIntermTyped *valueTransposed = CreateTransposeCall(mSymbolTable, valueElement);
1324 writeStatements->push_back(
1325 new TIntermBinary(EOpAssign, element, valueTransposed));
1326 }
1327 }
1328 else
1329 {
1330 TIntermTyped *valueTransposed =
1331 CreateTransposeCall(mSymbolTable, valueExpression->deepCopy());
1332 writeStatements->push_back(
1333 new TIntermBinary(assignmentOperator, baseExpression, valueTransposed));
1334 }
1335
1336 return;
1337 }
1338
1339 // If indexed, create one assignment per secondary index. If the right-hand side is a
1340 // scalar, it's used with every assignment. If it's a vector, the assignment is
1341 // per-component. The right-hand side cannot be a matrix as that would imply left-hand
1342 // side being a matrix too, which is covered above where |primaryIndex == nullptr|.
1343 ASSERT(!secondaryIndices->empty());
1344
1345 bool isValueExpressionScalar = valueExpression->getType().getNominalSize() == 1;
1346 ASSERT(isValueExpressionScalar || valueExpression->getType().getNominalSize() ==
1347 static_cast<int>(secondaryIndices->size()));
1348
1349 TOperator primaryIndexOp = GetIndexOp(primaryIndex);
1350 TIntermTyped *primaryIndexAsTyped = primaryIndex->getAsTyped();
1351
1352 for (TIntermNode *secondaryIndex : *secondaryIndices)
1353 {
1354 TOperator secondaryIndexOp = GetIndexOp(secondaryIndex);
1355 TIntermTyped *secondaryIndexAsTyped = secondaryIndex->getAsTyped();
1356
1357 TIntermBinary *colIndexed = new TIntermBinary(
1358 secondaryIndexOp, baseExpression->deepCopy(), secondaryIndexAsTyped->deepCopy());
1359 TIntermBinary *colRowIndexed =
1360 new TIntermBinary(primaryIndexOp, colIndexed, primaryIndexAsTyped->deepCopy());
1361
1362 TIntermTyped *valueExpressionIndexed = valueExpression->deepCopy();
1363 if (!isValueExpressionScalar)
1364 {
1365 valueExpressionIndexed = new TIntermBinary(secondaryIndexOp, valueExpressionIndexed,
1366 secondaryIndexAsTyped->deepCopy());
1367 }
1368
1369 writeStatements->push_back(
1370 new TIntermBinary(assignmentOperator, colRowIndexed, valueExpressionIndexed));
1371 }
1372 }
1373
getCopyStructFieldFunction(const TType * fromFieldType,const TType * toFieldType,bool isCopyToOriginal)1374 const TFunction *getCopyStructFieldFunction(const TType *fromFieldType,
1375 const TType *toFieldType,
1376 bool isCopyToOriginal)
1377 {
1378 ASSERT(fromFieldType->getStruct());
1379 ASSERT(toFieldType->getStruct());
1380
1381 // If copying from or to the original struct, the "to" field struct could require
1382 // conversion to or from the "from" field struct. |isCopyToOriginal| tells us if we
1383 // should expect to find toField or fromField in mStructMapOut, if true or false
1384 // respectively.
1385 const TFunction *fieldCopyFunction = nullptr;
1386 if (isCopyToOriginal)
1387 {
1388 const TStructure *toFieldStruct = toFieldType->getStruct();
1389
1390 auto iter = mStructMapOut->find(toFieldStruct);
1391 if (iter != mStructMapOut->end())
1392 {
1393 declareStructCopyToOriginal(toFieldStruct);
1394 fieldCopyFunction = iter->second.copyToOriginal;
1395 }
1396 }
1397 else
1398 {
1399 const TStructure *fromFieldStruct = fromFieldType->getStruct();
1400
1401 auto iter = mStructMapOut->find(fromFieldStruct);
1402 if (iter != mStructMapOut->end())
1403 {
1404 declareStructCopyFromOriginal(fromFieldStruct);
1405 fieldCopyFunction = iter->second.copyFromOriginal;
1406 }
1407 }
1408
1409 return fieldCopyFunction;
1410 }
1411
addFieldCopy(TIntermBlock * body,TIntermTyped * to,TIntermTyped * from,bool isCopyToOriginal)1412 void addFieldCopy(TIntermBlock *body,
1413 TIntermTyped *to,
1414 TIntermTyped *from,
1415 bool isCopyToOriginal)
1416 {
1417 const TType &fromType = from->getType();
1418 const TType &toType = to->getType();
1419
1420 TIntermTyped *rhs = from;
1421
1422 if (fromType.getStruct())
1423 {
1424 const TFunction *fieldCopyFunction =
1425 getCopyStructFieldFunction(&fromType, &toType, isCopyToOriginal);
1426
1427 if (fieldCopyFunction)
1428 {
1429 rhs = CreateStructCopyCall(fieldCopyFunction, from);
1430 }
1431 }
1432 else if (fromType.isMatrix())
1433 {
1434 rhs = CreateTransposeCall(mSymbolTable, from);
1435 }
1436
1437 body->appendStatement(new TIntermBinary(EOpAssign, to, rhs));
1438 }
1439
declareStructCopy(const TStructure * from,const TStructure * to,bool isCopyToOriginal)1440 TFunction *declareStructCopy(const TStructure *from,
1441 const TStructure *to,
1442 bool isCopyToOriginal)
1443 {
1444 TType *fromType = new TType(from, true);
1445 TType *toType = new TType(to, true);
1446
1447 // Create the parameter and return value variables.
1448 TVariable *fromVar = new TVariable(mSymbolTable, ImmutableString("from"), fromType,
1449 SymbolType::AngleInternal);
1450 TVariable *toVar =
1451 new TVariable(mSymbolTable, ImmutableString("to"), toType, SymbolType::AngleInternal);
1452
1453 TIntermSymbol *fromSymbol = new TIntermSymbol(fromVar);
1454 TIntermSymbol *toSymbol = new TIntermSymbol(toVar);
1455
1456 // Create the function body as statements are generated.
1457 TIntermBlock *body = new TIntermBlock;
1458
1459 // Declare the result variable.
1460 TIntermDeclaration *toDecl = new TIntermDeclaration();
1461 toDecl->appendDeclarator(toSymbol);
1462 body->appendStatement(toDecl);
1463
1464 // Iterate over fields of the struct and copy one by one, transposing the matrices. If a
1465 // struct is encountered that requires a transformation, this function is recursively
1466 // called. As a result, it is important that the copy functions are placed in the code in
1467 // order.
1468 const TFieldList &fromFields = from->fields();
1469 const TFieldList &toFields = to->fields();
1470 ASSERT(fromFields.size() == toFields.size());
1471
1472 for (size_t fieldIndex = 0; fieldIndex < fromFields.size(); ++fieldIndex)
1473 {
1474 TIntermTyped *fieldIndexNode = CreateIndexNode(static_cast<int>(fieldIndex));
1475
1476 TIntermTyped *fromField =
1477 new TIntermBinary(EOpIndexDirectStruct, fromSymbol->deepCopy(), fieldIndexNode);
1478 TIntermTyped *toField = new TIntermBinary(EOpIndexDirectStruct, toSymbol->deepCopy(),
1479 fieldIndexNode->deepCopy());
1480
1481 const TType *fromFieldType = fromFields[fieldIndex]->type();
1482 bool isStructOrMatrix = fromFieldType->getStruct() || fromFieldType->isMatrix();
1483
1484 if (fromFieldType->isArray() && isStructOrMatrix)
1485 {
1486 // If struct or matrix array, we need to copy element by element.
1487 TransformArrayHelper transformHelper(toField);
1488
1489 TIntermTyped *toElement = nullptr;
1490 TIntermTyped *fromElement = nullptr;
1491 while ((toElement = transformHelper.getNextElement(fromField, &fromElement)) !=
1492 nullptr)
1493 {
1494 addFieldCopy(body, toElement, fromElement, isCopyToOriginal);
1495 }
1496 }
1497 else
1498 {
1499 addFieldCopy(body, toField, fromField, isCopyToOriginal);
1500 }
1501 }
1502
1503 // Add return statement.
1504 body->appendStatement(new TIntermBranch(EOpReturn, toSymbol->deepCopy()));
1505
1506 // Declare the function
1507 TFunction *copyFunction = new TFunction(mSymbolTable, kEmptyImmutableString,
1508 SymbolType::AngleInternal, toType, true);
1509 copyFunction->addParameter(fromVar);
1510
1511 TIntermFunctionDefinition *functionDef =
1512 CreateInternalFunctionDefinitionNode(*copyFunction, body);
1513 mCopyFunctionDefinitionsOut->push_back(functionDef);
1514
1515 return copyFunction;
1516 }
1517
declareStructCopyFromOriginal(const TStructure * structure)1518 void declareStructCopyFromOriginal(const TStructure *structure)
1519 {
1520 StructConversionData *structData = &(*mStructMapOut)[structure];
1521 if (structData->copyFromOriginal)
1522 {
1523 return;
1524 }
1525
1526 structData->copyFromOriginal =
1527 declareStructCopy(structure, structData->convertedStruct, false);
1528 }
1529
declareStructCopyToOriginal(const TStructure * structure)1530 void declareStructCopyToOriginal(const TStructure *structure)
1531 {
1532 StructConversionData *structData = &(*mStructMapOut)[structure];
1533 if (structData->copyToOriginal)
1534 {
1535 return;
1536 }
1537
1538 structData->copyToOriginal =
1539 declareStructCopy(structData->convertedStruct, structure, true);
1540 }
1541
1542 TCompiler *mCompiler;
1543
1544 // This traverser can call itself to transform a subexpression before moving on. However, it
1545 // needs to accumulate conversion functions in inner passes. The fields below marked with Out
1546 // or In are inherited from the outer pass (for inner passes), or point to storage fields in
1547 // mOuterPass (for the outer pass). The latter should not be used by the inner passes as they
1548 // would be empty, so they are placed inside a struct to make them explicit.
1549 struct
1550 {
1551 StructMap structMap;
1552 InterfaceBlockMap interfaceBlockMap;
1553 InterfaceBlockFieldConverted interfaceBlockFieldConverted;
1554 TIntermSequence copyFunctionDefinitions;
1555 } mOuterPass;
1556
1557 // A map from structures with matrices to their converted version.
1558 StructMap *mStructMapOut;
1559 // A map from interface block instances with row-major matrices to their converted variable. If
1560 // an interface block is nameless, its fields are placed in this map instead. When a variable
1561 // in this map is encountered, it signals the start of an expression that my need conversion,
1562 // which is either "interfaceBlock.field..." or "field..." if nameless.
1563 InterfaceBlockMap *mInterfaceBlockMap;
1564 // A map from interface block fields to whether they need to be converted. If a field was
1565 // already column-major, it shouldn't be transposed.
1566 const InterfaceBlockFieldConverted &mInterfaceBlockFieldConvertedIn;
1567
1568 TIntermSequence *mCopyFunctionDefinitionsOut;
1569
1570 // If set, it's an inner pass and this will point to the outer pass traverser. All statement
1571 // insertions are stored in the outer traverser and applied at once in the end. This prevents
1572 // the inner passes from adding statements which invalidates the outer traverser's statement
1573 // position tracking.
1574 RewriteRowMajorMatricesTraverser *mOuterTraverser;
1575
1576 // If set, it's an inner pass that should only process the right-hand side of this particular
1577 // node.
1578 TIntermBinary *mInnerPassRoot;
1579 bool mIsProcessingInnerPassSubtree;
1580 };
1581
1582 } // anonymous namespace
1583
RewriteRowMajorMatrices(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)1584 bool RewriteRowMajorMatrices(TCompiler *compiler, TIntermBlock *root, TSymbolTable *symbolTable)
1585 {
1586 RewriteRowMajorMatricesTraverser traverser(compiler, symbolTable);
1587 root->traverse(&traverser);
1588 if (!traverser.updateTree(compiler, root))
1589 {
1590 return false;
1591 }
1592
1593 size_t firstFunctionIndex = FindFirstFunctionDefinitionIndex(root);
1594 root->insertChildNodes(firstFunctionIndex, *traverser.getStructCopyFunctions());
1595
1596 return compiler->validateAST(root);
1597 }
1598 } // namespace sh
1599