1 //
2 // Copyright 2019 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RewriteRowMajorMatrices: Rewrite row-major matrices as column-major.
7 //
8
9 #include "compiler/translator/tree_ops/RewriteRowMajorMatrices.h"
10
11 #include "compiler/translator/Compiler.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 #include "compiler/translator/tree_util/ReplaceVariable.h"
18
19 namespace sh
20 {
21 namespace
22 {
23 // Only structs with matrices are tracked. If layout(row_major) is applied to a struct that doesn't
24 // have matrices, it's silently dropped. This is also used to avoid creating duplicates for inner
25 // structs that don't have matrices.
26 struct StructConversionData
27 {
28 // The converted struct with every matrix transposed.
29 TStructure *convertedStruct = nullptr;
30
31 // The copy-from and copy-to functions copying from a struct to its converted version and back.
32 TFunction *copyFromOriginal = nullptr;
33 TFunction *copyToOriginal = nullptr;
34 };
35
DoesFieldContainRowMajorMatrix(const TField * field,bool isBlockRowMajor)36 bool DoesFieldContainRowMajorMatrix(const TField *field, bool isBlockRowMajor)
37 {
38 TLayoutMatrixPacking matrixPacking = field->type()->getLayoutQualifier().matrixPacking;
39
40 // The field is row major if either explicitly specified as such, or if it inherits it from the
41 // block layout qualifier.
42 if (matrixPacking == EmpColumnMajor || (matrixPacking == EmpUnspecified && !isBlockRowMajor))
43 {
44 return false;
45 }
46
47 // The field is qualified with row_major, but if it's not a matrix or a struct containing
48 // matrices, that's a useless qualifier.
49 const TType *type = field->type();
50 return type->isMatrix() || type->isStructureContainingMatrices();
51 }
52
// Creates a new field with the same name, line and symbol type as |field|, and a fresh copy of
// its type.
TField *DuplicateField(const TField *field)
{
    TType *typeCopy = new TType(*field->type());
    return new TField(typeCopy, field->name(), field->line(), field->symbolType());
}
57
SetColumnMajor(TType * type)58 void SetColumnMajor(TType *type)
59 {
60 TLayoutQualifier layoutQualifier = type->getLayoutQualifier();
61 layoutQualifier.matrixPacking = EmpColumnMajor;
62 type->setLayoutQualifier(layoutQualifier);
63 }
64
// Returns a newly allocated copy of the matrix type |type| that is marked column-major and has
// its primary size set to the original's row count and its secondary size set to the original's
// column count.
TType *TransposeMatrixType(const TType *type)
{
    const auto originalRows = type->getRows();
    const auto originalCols = type->getCols();

    TType *transposed = new TType(*type);
    SetColumnMajor(transposed);
    transposed->setPrimarySize(static_cast<unsigned char>(originalRows));
    transposed->setSecondarySize(static_cast<unsigned char>(originalCols));

    return transposed;
}
76
CopyArraySizes(const TType * from,TType * to)77 void CopyArraySizes(const TType *from, TType *to)
78 {
79 if (from->isArray())
80 {
81 to->makeArrays(from->getArraySizes());
82 }
83 }
84
85 // Determine if the node is an index node (array index or struct field selection). For the purposes
86 // of this transformation, swizzle nodes are considered index nodes too.
IsIndexNode(TIntermNode * node,TIntermNode * child)87 bool IsIndexNode(TIntermNode *node, TIntermNode *child)
88 {
89 if (node->getAsSwizzleNode())
90 {
91 return true;
92 }
93
94 TIntermBinary *binaryNode = node->getAsBinaryNode();
95 if (binaryNode == nullptr || child != binaryNode->getLeft())
96 {
97 return false;
98 }
99
100 TOperator op = binaryNode->getOp();
101
102 return op == EOpIndexDirect || op == EOpIndexDirectInterfaceBlock ||
103 op == EOpIndexDirectStruct || op == EOpIndexIndirect;
104 }
105
CopyToTempVariable(TSymbolTable * symbolTable,TIntermTyped * node,TIntermSequence * prependStatements)106 TIntermSymbol *CopyToTempVariable(TSymbolTable *symbolTable,
107 TIntermTyped *node,
108 TIntermSequence *prependStatements)
109 {
110 TVariable *temp = CreateTempVariable(symbolTable, &node->getType());
111 TIntermDeclaration *tempDecl = CreateTempInitDeclarationNode(temp, node);
112 prependStatements->push_back(tempDecl);
113
114 return new TIntermSymbol(temp);
115 }
116
CreateStructCopyCall(const TFunction * copyFunc,TIntermTyped * expression)117 TIntermAggregate *CreateStructCopyCall(const TFunction *copyFunc, TIntermTyped *expression)
118 {
119 return TIntermAggregate::CreateFunctionCall(*copyFunc, new TIntermSequence({expression}));
120 }
121
CreateTransposeCall(TSymbolTable * symbolTable,TIntermTyped * expression)122 TIntermTyped *CreateTransposeCall(TSymbolTable *symbolTable, TIntermTyped *expression)
123 {
124 return CreateBuiltInFunctionCallNode("transpose", new TIntermSequence({expression}),
125 *symbolTable, 300);
126 }
127
GetIndex(TSymbolTable * symbolTable,TIntermNode * node,TIntermSequence * indices,TIntermSequence * prependStatements)128 TOperator GetIndex(TSymbolTable *symbolTable,
129 TIntermNode *node,
130 TIntermSequence *indices,
131 TIntermSequence *prependStatements)
132 {
133 // Swizzle nodes are converted EOpIndexDirect for simplicity, with one index per swizzle
134 // channel.
135 TIntermSwizzle *asSwizzle = node->getAsSwizzleNode();
136 if (asSwizzle)
137 {
138 for (int channel : asSwizzle->getSwizzleOffsets())
139 {
140 indices->push_back(CreateIndexNode(channel));
141 }
142 return EOpIndexDirect;
143 }
144
145 TIntermBinary *binaryNode = node->getAsBinaryNode();
146 ASSERT(binaryNode);
147
148 TOperator op = binaryNode->getOp();
149 ASSERT(op == EOpIndexDirect || op == EOpIndexDirectInterfaceBlock ||
150 op == EOpIndexDirectStruct || op == EOpIndexIndirect);
151
152 TIntermTyped *rhs = binaryNode->getRight()->deepCopy();
153 if (rhs->getAsConstantUnion() == nullptr)
154 {
155 rhs = CopyToTempVariable(symbolTable, rhs, prependStatements);
156 }
157
158 indices->push_back(rhs);
159 return op;
160 }
161
ReplicateIndexNode(TSymbolTable * symbolTable,TIntermNode * node,TIntermTyped * lhs,TIntermSequence * indices)162 TIntermTyped *ReplicateIndexNode(TSymbolTable *symbolTable,
163 TIntermNode *node,
164 TIntermTyped *lhs,
165 TIntermSequence *indices)
166 {
167 TIntermSwizzle *asSwizzle = node->getAsSwizzleNode();
168 if (asSwizzle)
169 {
170 return new TIntermSwizzle(lhs, asSwizzle->getSwizzleOffsets());
171 }
172
173 TIntermBinary *binaryNode = node->getAsBinaryNode();
174 ASSERT(binaryNode);
175
176 ASSERT(indices->size() == 1);
177 TIntermTyped *rhs = indices->front()->getAsTyped();
178
179 return new TIntermBinary(binaryNode->getOp(), lhs, rhs);
180 }
181
GetIndexOp(TIntermNode * node)182 TOperator GetIndexOp(TIntermNode *node)
183 {
184 return node->getAsConstantUnion() ? EOpIndexDirect : EOpIndexIndirect;
185 }
186
IsConvertedField(TIntermTyped * indexNode,const std::unordered_map<const TField *,bool> & convertedFields)187 bool IsConvertedField(TIntermTyped *indexNode,
188 const std::unordered_map<const TField *, bool> &convertedFields)
189 {
190 TIntermBinary *asBinary = indexNode->getAsBinaryNode();
191 if (asBinary == nullptr)
192 {
193 return false;
194 }
195
196 if (asBinary->getOp() != EOpIndexDirectInterfaceBlock)
197 {
198 return false;
199 }
200
201 const TInterfaceBlock *interfaceBlock = asBinary->getLeft()->getType().getInterfaceBlock();
202 ASSERT(interfaceBlock);
203
204 TIntermConstantUnion *fieldIndexNode = asBinary->getRight()->getAsConstantUnion();
205 ASSERT(fieldIndexNode);
206 ASSERT(fieldIndexNode->getConstantValue() != nullptr);
207
208 int fieldIndex = fieldIndexNode->getConstantValue()->getIConst();
209 const TField *field = interfaceBlock->fields()[fieldIndex];
210
211 return convertedFields.count(field) > 0 && convertedFields.at(field);
212 }
213
214 // A helper class to transform expressions of array type. Iterates over every element of the
215 // array.
216 class TransformArrayHelper
217 {
218 public:
TransformArrayHelper(TIntermTyped * baseExpression)219 TransformArrayHelper(TIntermTyped *baseExpression)
220 : mBaseExpression(baseExpression),
221 mBaseExpressionType(baseExpression->getType()),
222 mArrayIndices(mBaseExpressionType.getArraySizes().size(), 0)
223 {}
224
getNextElement(TIntermTyped * valueExpression,TIntermTyped ** valueElementOut)225 TIntermTyped *getNextElement(TIntermTyped *valueExpression, TIntermTyped **valueElementOut)
226 {
227 const TSpan<const unsigned int> &arraySizes = mBaseExpressionType.getArraySizes();
228
229 // If the last index overflows, element enumeration is done.
230 if (mArrayIndices.back() >= arraySizes.back())
231 {
232 return nullptr;
233 }
234
235 TIntermTyped *element = getCurrentElement(mBaseExpression);
236 if (valueExpression)
237 {
238 *valueElementOut = getCurrentElement(valueExpression);
239 }
240
241 incrementIndices(arraySizes);
242 return element;
243 }
244
accumulateForRead(TSymbolTable * symbolTable,TIntermTyped * transformedElement,TIntermSequence * prependStatements)245 void accumulateForRead(TSymbolTable *symbolTable,
246 TIntermTyped *transformedElement,
247 TIntermSequence *prependStatements)
248 {
249 TIntermTyped *temp = CopyToTempVariable(symbolTable, transformedElement, prependStatements);
250 mReadTransformConstructorArgs.push_back(temp);
251 }
252
constructReadTransformExpression()253 TIntermTyped *constructReadTransformExpression()
254 {
255 const TSpan<const unsigned int> &baseTypeArraySizes = mBaseExpressionType.getArraySizes();
256 TVector<unsigned int> arraySizes(baseTypeArraySizes.begin(), baseTypeArraySizes.end());
257 TIntermTyped *firstElement = mReadTransformConstructorArgs.front()->getAsTyped();
258 const TType &baseType = firstElement->getType();
259
260 // If N dimensions, acc[0] == size[0] and acc[i] == size[i] * acc[i-1].
261 // The last value is unused, and is not present.
262 TVector<unsigned int> accumulatedArraySizes(arraySizes.size() - 1);
263
264 accumulatedArraySizes[0] = arraySizes[0];
265 for (size_t index = 1; index + 1 < arraySizes.size(); ++index)
266 {
267 accumulatedArraySizes[index] = accumulatedArraySizes[index - 1] * arraySizes[index];
268 }
269
270 return constructReadTransformExpressionHelper(arraySizes, accumulatedArraySizes, baseType,
271 0);
272 }
273
274 private:
getCurrentElement(TIntermTyped * expression)275 TIntermTyped *getCurrentElement(TIntermTyped *expression)
276 {
277 TIntermTyped *element = expression->deepCopy();
278 for (auto it = mArrayIndices.rbegin(); it != mArrayIndices.rend(); ++it)
279 {
280 unsigned int index = *it;
281 element = new TIntermBinary(EOpIndexDirect, element, CreateIndexNode(index));
282 }
283 return element;
284 }
285
incrementIndices(const TSpan<const unsigned int> & arraySizes)286 void incrementIndices(const TSpan<const unsigned int> &arraySizes)
287 {
288 // Assume mArrayIndices is an N digit number, where digit i is in the range
289 // [0, arraySizes[i]). This function increments this number. Last digit is the most
290 // significant digit.
291 for (size_t digitIndex = 0; digitIndex < arraySizes.size(); ++digitIndex)
292 {
293 ++mArrayIndices[digitIndex];
294 if (mArrayIndices[digitIndex] < arraySizes[digitIndex])
295 {
296 break;
297 }
298 if (digitIndex + 1 != arraySizes.size())
299 {
300 // This digit has now overflown and is reset to 0, carry will be added to the next
301 // digit. The most significant digit will keep the overflow though, to make it
302 // clear we have exhausted the range.
303 mArrayIndices[digitIndex] = 0;
304 }
305 }
306 }
307
constructReadTransformExpressionHelper(const TVector<unsigned int> & arraySizes,const TVector<unsigned int> & accumulatedArraySizes,const TType & baseType,size_t elementsOffset)308 TIntermTyped *constructReadTransformExpressionHelper(
309 const TVector<unsigned int> &arraySizes,
310 const TVector<unsigned int> &accumulatedArraySizes,
311 const TType &baseType,
312 size_t elementsOffset)
313 {
314 ASSERT(!arraySizes.empty());
315
316 TType *transformedType = new TType(baseType);
317 transformedType->makeArrays(arraySizes);
318
319 // If one dimensional, create the constructor with the given elements.
320 if (arraySizes.size() == 1)
321 {
322 ASSERT(accumulatedArraySizes.size() == 0);
323
324 auto sliceStart = mReadTransformConstructorArgs.begin() + elementsOffset;
325 TIntermSequence slice(sliceStart, sliceStart + arraySizes[0]);
326
327 return TIntermAggregate::CreateConstructor(*transformedType, &slice);
328 }
329
330 // If not, create constructors for every column recursively.
331 TVector<unsigned int> subArraySizes(arraySizes.begin(), arraySizes.end() - 1);
332 TVector<unsigned int> subArrayAccumulatedSizes(accumulatedArraySizes.begin(),
333 accumulatedArraySizes.end() - 1);
334
335 TIntermSequence constructorArgs;
336 unsigned int colStride = accumulatedArraySizes.back();
337 for (size_t col = 0; col < arraySizes.back(); ++col)
338 {
339 size_t colElementsOffset = elementsOffset + col * colStride;
340
341 constructorArgs.push_back(constructReadTransformExpressionHelper(
342 subArraySizes, subArrayAccumulatedSizes, baseType, colElementsOffset));
343 }
344
345 return TIntermAggregate::CreateConstructor(*transformedType, &constructorArgs);
346 }
347
348 TIntermTyped *mBaseExpression;
349 const TType &mBaseExpressionType;
350 TVector<unsigned int> mArrayIndices;
351
352 TIntermSequence mReadTransformConstructorArgs;
353 };
354
// Traverser that:
//
// 1. Converts |layout(row_major) matCxR M| to |layout(column_major) matRxC Mt|.
// 2. Converts |layout(row_major) S s| to |layout(column_major) St st|, where S is a struct that
//    contains matrices, and St is a new struct with the transformation in 1 applied to matrix
//    members (recursively).
// 3. When read from, the following transformations are applied:
//
//            M       -> transpose(Mt)
//            M[c]    -> gvecN(Mt[0][c], Mt[1][c], ..., Mt[N-1][c])
//            M[c][r] -> Mt[r][c]
//            M[c].yz -> gvec2(Mt[1][c], Mt[2][c])
//            MArr    -> MType[D1]..[DN](transpose(MtArr[0]...[0]), ...)
//            s       -> copy_St_to_S(st)
//            sArr    -> SType[D1]...[DN](copy_St_to_S(stArr[0]..[0]), ...)
//            (matrix reads through struct are transformed similarly to M)
//
// 4. When written to, the following transformations are applied:
//
//      M = exp       -> Mt = transpose(exp)
//      M[c] = exp    -> temp = exp
//                       Mt[0][c] = temp[0]
//                       Mt[1][c] = temp[1]
//                       ...
//                       Mt[N-1][c] = temp[N-1]
//      M[c][r] = exp -> Mt[r][c] = exp
//      M[c].yz = exp -> temp = exp
//                       Mt[1][c] = temp[0]
//                       Mt[2][c] = temp[1]
//      MArr = exp    -> temp = exp
//                       MtArr = MtType[D1]..[DN](transpose(temp[0]...[0]), ...)
//      s = exp       -> st = copy_S_to_St(exp)
//      sArr = exp    -> temp = exp
//                       stArr = StType[D1]...[DN](copy_S_to_St(temp[0]..[0]), ...)
//      (matrix writes through struct are transformed similarly to M)
//
// 5. If any of the above is passed to an `inout` parameter, both transformations are applied:
//
//            f(M[c]) -> temp = gvecN(Mt[0][c], Mt[1][c], ..., Mt[N-1][c])
//                       f(temp)
//                       Mt[0][c] = temp[0]
//                       Mt[1][c] = temp[1]
//                       ...
//                       Mt[N-1][c] = temp[N-1]
//
//               f(s) -> temp = copy_St_to_S(st)
//                       f(temp)
//                       st = copy_S_to_St(temp)
//
//    If passed to an `out` parameter, the `temp` variable is simply not initialized.
//
// 6. If the expression leading to the matrix or struct has array subscripts, temp values are
//    created for them to avoid duplicating side effects.
//
409 class RewriteRowMajorMatricesTraverser : public TIntermTraverser
410 {
411 public:
// Constructs the outer-pass traverser.  The shared-state pointers and the converted-field
// reference are all bound to this object's own |mOuterPass| storage; there is no enclosing
// traverser and no inner pass root.
RewriteRowMajorMatricesTraverser(TCompiler *compiler, TSymbolTable *symbolTable)
    // NOTE(review): the three |true| flags presumably enable pre-, in- and post-visit callbacks;
    // confirm against TIntermTraverser.
    : TIntermTraverser(true, true, true, symbolTable),
      mCompiler(compiler),
      // State produced by the outer pass lives in |mOuterPass|.
      mStructMapOut(&mOuterPass.structMap),
      mInterfaceBlockMap(&mOuterPass.interfaceBlockMap),
      mInterfaceBlockFieldConvertedIn(mOuterPass.interfaceBlockFieldConverted),
      mCopyFunctionDefinitionsOut(&mOuterPass.copyFunctionDefinitions),
      // Not an inner pass.
      mOuterTraverser(nullptr),
      mInnerPassRoot(nullptr),
      mIsProcessingInnerPassSubtree(false)
{}
423
visitDeclaration(Visit visit,TIntermDeclaration * node)424 bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
425 {
426 // No need to process declarations in inner passes.
427 if (mInnerPassRoot != nullptr)
428 {
429 return true;
430 }
431
432 if (visit != PreVisit)
433 {
434 return true;
435 }
436
437 const TIntermSequence &sequence = *(node->getSequence());
438
439 TIntermTyped *variable = sequence.front()->getAsTyped();
440 const TType &type = variable->getType();
441
442 // If it's a struct declaration that has matrices, remember it. If a row-major instance
443 // of it is created, it will have to be converted.
444 if (type.isStructSpecifier() && type.isStructureContainingMatrices())
445 {
446 const TStructure *structure = type.getStruct();
447 ASSERT(structure);
448
449 ASSERT(mOuterPass.structMap.count(structure) == 0);
450
451 StructConversionData structData;
452 mOuterPass.structMap[structure] = structData;
453
454 return false;
455 }
456
457 // If it's an interface block, it may have to be converted if it contains any row-major
458 // fields.
459 if (type.isInterfaceBlock() && type.getInterfaceBlock()->containsMatrices())
460 {
461 const TInterfaceBlock *block = type.getInterfaceBlock();
462 ASSERT(block);
463 bool isBlockRowMajor = type.getLayoutQualifier().matrixPacking == EmpRowMajor;
464
465 const TFieldList &fields = block->fields();
466 bool anyRowMajor = isBlockRowMajor;
467
468 for (const TField *field : fields)
469 {
470 if (DoesFieldContainRowMajorMatrix(field, isBlockRowMajor))
471 {
472 anyRowMajor = true;
473 break;
474 }
475 }
476
477 if (anyRowMajor)
478 {
479 convertInterfaceBlock(node);
480 }
481
482 return false;
483 }
484
485 return true;
486 }
487
visitSymbol(TIntermSymbol * symbol)488 void visitSymbol(TIntermSymbol *symbol) override
489 {
490 // If in inner pass, only process if the symbol is under that root.
491 if (mInnerPassRoot != nullptr && !mIsProcessingInnerPassSubtree)
492 {
493 return;
494 }
495
496 const TVariable *variable = &symbol->variable();
497 bool needsRewrite = mInterfaceBlockMap->count(variable) != 0;
498
499 // If it's a field of a nameless interface block, it may still need conversion.
500 if (!needsRewrite)
501 {
502 // Nameless interface block field symbols have the interface block pointer set, but are
503 // not interface blocks.
504 if (symbol->getType().getInterfaceBlock() && !variable->getType().isInterfaceBlock())
505 {
506 needsRewrite = convertNamelessInterfaceBlockField(symbol);
507 }
508 }
509
510 if (needsRewrite)
511 {
512 transformExpression(symbol);
513 }
514 }
515
visitBinary(Visit visit,TIntermBinary * node)516 bool visitBinary(Visit visit, TIntermBinary *node) override
517 {
518 if (node == mInnerPassRoot)
519 {
520 // We only want to process the right-hand side of an assignment in inner passes. When
521 // visit is InVisit, the left-hand side is already processed, and the right-hand side is
522 // next. Set a flag to mark this duration.
523 mIsProcessingInnerPassSubtree = visit == InVisit;
524 }
525
526 return true;
527 }
528
getStructCopyFunctions()529 TIntermSequence *getStructCopyFunctions() { return &mOuterPass.copyFunctionDefinitions; }
530
531 private:
532 typedef std::unordered_map<const TStructure *, StructConversionData> StructMap;
533 typedef std::unordered_map<const TVariable *, TVariable *> InterfaceBlockMap;
534 typedef std::unordered_map<const TField *, bool> InterfaceBlockFieldConverted;
535
// Constructs an inner-pass traverser rooted at |innerPassRoot|.  Instead of using its own state,
// it works on the maps and sequences handed in by the caller, and remembers |outerTraverser|.
RewriteRowMajorMatricesTraverser(
    TSymbolTable *symbolTable,
    RewriteRowMajorMatricesTraverser *outerTraverser,
    InterfaceBlockMap *interfaceBlockMap,
    const InterfaceBlockFieldConverted &interfaceBlockFieldConverted,
    StructMap *structMap,
    TIntermSequence *copyFunctionDefinitions,
    TIntermBinary *innerPassRoot)
    : TIntermTraverser(true, true, true, symbolTable),
      // State shared with the caller.
      mStructMapOut(structMap),
      mInterfaceBlockMap(interfaceBlockMap),
      mInterfaceBlockFieldConvertedIn(interfaceBlockFieldConverted),
      mCopyFunctionDefinitionsOut(copyFunctionDefinitions),
      // Inner pass bookkeeping; the subtree flag is raised later by visitBinary().
      mOuterTraverser(outerTraverser),
      mInnerPassRoot(innerPassRoot),
      mIsProcessingInnerPassSubtree(false)
{}
553
convertInterfaceBlock(TIntermDeclaration * node)554 void convertInterfaceBlock(TIntermDeclaration *node)
555 {
556 ASSERT(mInnerPassRoot == nullptr);
557
558 const TIntermSequence &sequence = *(node->getSequence());
559
560 TIntermTyped *variableNode = sequence.front()->getAsTyped();
561 const TType &type = variableNode->getType();
562 const TInterfaceBlock *block = type.getInterfaceBlock();
563 ASSERT(block);
564
565 bool isBlockRowMajor = type.getLayoutQualifier().matrixPacking == EmpRowMajor;
566
567 // Recreate the struct with its row-major fields converted to column-major equivalents.
568 TIntermSequence newDeclarations;
569
570 TFieldList *newFields = new TFieldList;
571 for (const TField *field : block->fields())
572 {
573 TField *newField = nullptr;
574
575 if (DoesFieldContainRowMajorMatrix(field, isBlockRowMajor))
576 {
577 newField = convertField(field, &newDeclarations);
578
579 // Remember that this field was converted.
580 mOuterPass.interfaceBlockFieldConverted[field] = true;
581 }
582 else
583 {
584 newField = DuplicateField(field);
585 }
586
587 newFields->push_back(newField);
588 }
589
590 // Create a new interface block with these fields.
591 TLayoutQualifier blockLayoutQualifier = type.getLayoutQualifier();
592 blockLayoutQualifier.matrixPacking = EmpColumnMajor;
593
594 TInterfaceBlock *newInterfaceBlock =
595 new TInterfaceBlock(mSymbolTable, block->name(), newFields, blockLayoutQualifier,
596 block->symbolType(), block->extension());
597
598 // Create a new declaration with the new type. Declarations are separated at this point,
599 // so there should be only one variable here.
600 ASSERT(sequence.size() == 1);
601
602 TType *newInterfaceBlockType =
603 new TType(newInterfaceBlock, type.getQualifier(), blockLayoutQualifier);
604
605 TIntermDeclaration *newDeclaration = new TIntermDeclaration;
606 const TVariable *variable = &variableNode->getAsSymbolNode()->variable();
607
608 const TType *newType = newInterfaceBlockType;
609 if (type.isArray())
610 {
611 TType *newArrayType = new TType(*newType);
612 CopyArraySizes(&type, newArrayType);
613 newType = newArrayType;
614 }
615
616 // If the interface block variable itself is temp, use an empty name.
617 bool variableIsTemp = variable->symbolType() == SymbolType::Empty;
618 const ImmutableString &variableName =
619 variableIsTemp ? kEmptyImmutableString : variable->name();
620
621 TVariable *newVariable = new TVariable(mSymbolTable, variableName, newType,
622 variable->symbolType(), variable->extension());
623
624 newDeclaration->appendDeclarator(new TIntermSymbol(newVariable));
625
626 mOuterPass.interfaceBlockMap[variable] = newVariable;
627
628 newDeclarations.push_back(newDeclaration);
629
630 // Replace the interface block definition with the new one, prepending any new struct
631 // definitions.
632 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node, newDeclarations);
633 }
634
convertNamelessInterfaceBlockField(TIntermSymbol * symbol)635 bool convertNamelessInterfaceBlockField(TIntermSymbol *symbol)
636 {
637 const TVariable *variable = &symbol->variable();
638 const TInterfaceBlock *interfaceBlock = symbol->getType().getInterfaceBlock();
639
640 // Find the variable corresponding to this interface block. If the interface block
641 // is not rewritten, or this refers to a field that is not rewritten, there's
642 // nothing to do.
643 for (auto iter : *mInterfaceBlockMap)
644 {
645 // Skip other rewritten nameless interface block fields.
646 if (!iter.first->getType().isInterfaceBlock())
647 {
648 continue;
649 }
650
651 // Skip if this is not a field of this rewritten interface block.
652 if (iter.first->getType().getInterfaceBlock() != interfaceBlock)
653 {
654 continue;
655 }
656
657 const ImmutableString symbolName = symbol->getName();
658
659 // Find which field it is
660 const TVector<TField *> fields = interfaceBlock->fields();
661 for (size_t fieldIndex = 0; fieldIndex < fields.size(); ++fieldIndex)
662 {
663 const TField *field = fields[fieldIndex];
664 if (field->name() != symbolName)
665 {
666 continue;
667 }
668
669 // If this field doesn't need a rewrite, there's nothing to do.
670 if (mInterfaceBlockFieldConvertedIn.count(field) == 0 ||
671 !mInterfaceBlockFieldConvertedIn.at(field))
672 {
673 break;
674 }
675
676 // Create a new variable that references the replaced interface block.
677 TType *newType = new TType(variable->getType());
678 newType->setInterfaceBlock(iter.second->getType().getInterfaceBlock());
679
680 TVariable *newVariable =
681 new TVariable(mSymbolTable, variable->name(), newType, variable->symbolType(),
682 variable->extension());
683
684 (*mInterfaceBlockMap)[variable] = newVariable;
685
686 return true;
687 }
688
689 break;
690 }
691
692 return false;
693 }
694
convertStruct(const TStructure * structure,TIntermSequence * newDeclarations)695 void convertStruct(const TStructure *structure, TIntermSequence *newDeclarations)
696 {
697 ASSERT(mInnerPassRoot == nullptr);
698
699 ASSERT(mOuterPass.structMap.count(structure) != 0);
700 StructConversionData *structData = &mOuterPass.structMap[structure];
701
702 if (structData->convertedStruct)
703 {
704 return;
705 }
706
707 TFieldList *newFields = new TFieldList;
708 for (const TField *field : structure->fields())
709 {
710 newFields->push_back(convertField(field, newDeclarations));
711 }
712
713 // Create unique names for the converted structs. We can't leave them nameless and have
714 // a name autogenerated similar to temp variables, as nameless structs exist. A fake
715 // variable is created for the sole purpose of generating a temp name.
716 TVariable *newStructTypeName =
717 new TVariable(mSymbolTable, kEmptyImmutableString, StaticType::GetBasic<EbtUInt>(),
718 SymbolType::Empty);
719
720 TStructure *newStruct = new TStructure(mSymbolTable, newStructTypeName->name(), newFields,
721 SymbolType::AngleInternal);
722 TType *newType = new TType(newStruct, true);
723 TVariable *newStructVar =
724 new TVariable(mSymbolTable, kEmptyImmutableString, newType, SymbolType::Empty);
725
726 TIntermDeclaration *structDecl = new TIntermDeclaration;
727 structDecl->appendDeclarator(new TIntermSymbol(newStructVar));
728
729 newDeclarations->push_back(structDecl);
730
731 structData->convertedStruct = newStruct;
732 }
733
// Returns a new field replacing |field|:
//
// - a struct containing matrices becomes an instance of the converted struct (converting the
//   struct first if needed, which may add declarations to |newDeclarations|),
// - a matrix becomes its transposed, column-major type,
// - anything else is duplicated as is.
//
// Outer pass only.
TField *convertField(const TField *field, TIntermSequence *newDeclarations)
{
    ASSERT(mInnerPassRoot == nullptr);

    const TType *originalType = field->type();
    TType *convertedType      = nullptr;

    if (originalType->isStructureContainingMatrices())
    {
        // Convert the struct, and make the field an instance of the new struct.
        const TStructure *originalStruct = originalType->getStruct();
        convertStruct(originalStruct, newDeclarations);

        TStructure *convertedStruct = mOuterPass.structMap[originalStruct].convertedStruct;

        convertedType = new TType(convertedStruct, false);
        SetColumnMajor(convertedType);
        CopyArraySizes(originalType, convertedType);
    }
    else if (originalType->isMatrix())
    {
        // Transpose the matrix; the resulting type is qualified column-major.
        convertedType = TransposeMatrixType(originalType);
    }
    else
    {
        // Nothing to convert.
        return DuplicateField(field);
    }

    return new TField(convertedType, field->name(), field->line(), field->symbolType());
}
773
determineAccess(TIntermNode * expression,TIntermNode * accessor,bool * isReadOut,bool * isWriteOut)774 void determineAccess(TIntermNode *expression,
775 TIntermNode *accessor,
776 bool *isReadOut,
777 bool *isWriteOut)
778 {
779 // If passing to a function, look at whether the parameter is in, out or inout.
780 TIntermAggregate *functionCall = accessor->getAsAggregate();
781
782 if (functionCall)
783 {
784 TIntermSequence *arguments = functionCall->getSequence();
785 for (size_t argIndex = 0; argIndex < arguments->size(); ++argIndex)
786 {
787 if ((*arguments)[argIndex] == expression)
788 {
789 TQualifier qualifier = EvqIn;
790
791 // If the aggregate is not a function call, it's a constructor, and so every
792 // argument is an input.
793 const TFunction *function = functionCall->getFunction();
794 if (function)
795 {
796 const TVariable *param = function->getParam(argIndex);
797 qualifier = param->getType().getQualifier();
798 }
799
800 *isReadOut = qualifier != EvqOut;
801 *isWriteOut = qualifier == EvqOut || qualifier == EvqInOut;
802 break;
803 }
804 }
805 return;
806 }
807
808 TIntermBinary *assignment = accessor->getAsBinaryNode();
809 if (assignment && IsAssignment(assignment->getOp()))
810 {
811 // If expression is on the right of assignment, it's being read from.
812 *isReadOut = assignment->getRight() == expression;
813 // If it's on the left of assignment, it's being written to.
814 *isWriteOut = assignment->getLeft() == expression;
815 return;
816 }
817
818 // Any other usage is a read.
819 *isReadOut = true;
820 *isWriteOut = false;
821 }
822
transformExpression(TIntermSymbol * symbol)823 void transformExpression(TIntermSymbol *symbol)
824 {
825 // Walk up the parent chain while the nodes are EOpIndex* (whether array indexing or struct
826 // field selection) or swizzle and construct the replacement expression. This traversal can
827 // lead to one of the following possibilities:
828 //
829 // - a.b[N].etc.s (struct, or struct array): copy function should be declared and used,
830 // - a.b[N].etc.M (matrix or matrix array): transpose() should be used,
831 // - a.b[N].etc.M[c] (a column): each element in column needs to be handled separately,
832 // - a.b[N].etc.M[c].yz (multiple elements): similar to whole column, but a subset of
833 // elements,
834 // - a.b[N].etc.M[c][r] (an element): single element to handle.
835 // - a.b[N].etc.x (not struct or matrix): not modified
836 //
837 // primaryIndex will contain c, if any. secondaryIndices will contain {0, ..., R-1}
838 // (if no [r] or swizzle), {r} (if [r]), or {1, 2} (corresponding to .yz) if any.
839 //
840 // In all cases, the base symbol is replaced. |baseExpression| will contain everything up
841 // to (and not including) the last index/swizzle operations, i.e. a.b[N].etc.s/M/x. Any
842 // non constant array subscript is assigned to a temp variable to avoid duplicating side
843 // effects.
844 //
845 // ---
846 //
847 // NOTE that due to the use of insertStatementsInParentBlock, cases like this will be
848 // mistranslated, and this bug is likely present in most transformations that use this
849 // feature:
850 //
851 // if (x == 1 && a.b[x = 2].etc.M = value)
852 //
853 // which will translate to:
854 //
855 // temp = (x = 2)
856 // if (x == 1 && a.b[temp].etc.M = transpose(value))
857 //
858 // See http://anglebug.com/3829.
859 //
860 TIntermTyped *baseExpression =
861 new TIntermSymbol(mInterfaceBlockMap->at(&symbol->variable()));
862 const TStructure *structure = nullptr;
863
864 TIntermNode *primaryIndex = nullptr;
865 TIntermSequence secondaryIndices;
866
867 // In some cases, it is necessary to prepend or append statements. Those are captured in
868 // |prependStatements| and |appendStatements|.
869 TIntermSequence prependStatements;
870 TIntermSequence appendStatements;
871
872 // If the expression is neither a struct or matrix, no modification is necessary.
873 // If it's a struct that doesn't have matrices, again there's no transformation necessary.
874 // If it's an interface block matrix field that didn't need to be transposed, no
875 // transformation is necessary.
876 //
877 // In all these cases, |baseExpression| contains all of the original expression.
878 //
879 // If the starting symbol itself is a field of a nameless interface block, it needs
880 // conversion if we reach here.
881 bool requiresTransformation = !symbol->getType().isInterfaceBlock();
882
883 uint32_t accessorIndex = 0;
884 TIntermTyped *previousAncestor = symbol;
885 while (IsIndexNode(getAncestorNode(accessorIndex), previousAncestor))
886 {
887 TIntermTyped *ancestor = getAncestorNode(accessorIndex)->getAsTyped();
888 ASSERT(ancestor);
889
890 const TType &previousAncestorType = previousAncestor->getType();
891
892 TIntermSequence indices;
893 TOperator op = GetIndex(mSymbolTable, ancestor, &indices, &prependStatements);
894
895 bool opIsIndex = op == EOpIndexDirect || op == EOpIndexIndirect;
896 bool isArrayIndex = opIsIndex && previousAncestorType.isArray();
897 bool isMatrixIndex = opIsIndex && previousAncestorType.isMatrix();
898
899 // If it's a direct index in a matrix, it's the primary index.
900 bool isMatrixPrimarySubscript = isMatrixIndex && !isArrayIndex;
901 ASSERT(!isMatrixPrimarySubscript ||
902 (primaryIndex == nullptr && secondaryIndices.empty()));
903 // If primary index is seen and the ancestor is still an index, it must be a direct
904 // index as the secondary one. Note that if primaryIndex is set, there can only ever be
905 // one more parent of interest, and that's subscripting the second dimension.
906 bool isMatrixSecondarySubscript = primaryIndex != nullptr;
907 ASSERT(!isMatrixSecondarySubscript || (opIsIndex && !isArrayIndex));
908
909 if (requiresTransformation && isMatrixPrimarySubscript)
910 {
911 ASSERT(indices.size() == 1);
912 primaryIndex = indices.front();
913
914 // Default the secondary indices to include every row. If there's a secondary
915 // subscript provided, it will override this.
916 int rows = previousAncestorType.getRows();
917 for (int r = 0; r < rows; ++r)
918 {
919 secondaryIndices.push_back(CreateIndexNode(r));
920 }
921 }
922 else if (isMatrixSecondarySubscript)
923 {
924 ASSERT(requiresTransformation);
925
926 secondaryIndices = indices;
927
928 // Indices after this point are not interesting. There can't actually be any other
929 // index nodes other than desktop GLSL's swizzles on scalars, like M[1][2].yyy.
930 ++accessorIndex;
931 break;
932 }
933 else
934 {
935 // Replicate the expression otherwise.
936 baseExpression =
937 ReplicateIndexNode(mSymbolTable, ancestor, baseExpression, &indices);
938
939 const TType &ancestorType = ancestor->getType();
940 structure = ancestorType.getStruct();
941
942 requiresTransformation =
943 requiresTransformation ||
944 IsConvertedField(ancestor, mInterfaceBlockFieldConvertedIn);
945
946 // If we reach a point where the expression is neither a matrix-containing struct
947 // nor a matrix, there's no transformation required. This can happen if we descend
948 // through a struct marked with row-major but arrive at a member that doesn't
949 // include a matrix.
950 if (!ancestorType.isMatrix() && !ancestorType.isStructureContainingMatrices())
951 {
952 requiresTransformation = false;
953 }
954 }
955
956 previousAncestor = ancestor;
957 ++accessorIndex;
958 }
959
960 TIntermNode *originalExpression =
961 accessorIndex == 0 ? symbol : getAncestorNode(accessorIndex - 1);
962 TIntermNode *accessor = getAncestorNode(accessorIndex);
963
964 // if accessor is EOpArrayLength, we don't need to perform any transformations either.
965 // Note that this only applies to unsized arrays, as the RemoveArrayLengthMethod()
966 // transformation would have removed this operation otherwise.
967 TIntermUnary *accessorAsUnary = accessor->getAsUnaryNode();
968 if (requiresTransformation && accessorAsUnary && accessorAsUnary->getOp() == EOpArrayLength)
969 {
970 ASSERT(accessorAsUnary->getOperand() == originalExpression);
971 ASSERT(accessorAsUnary->getOperand()->getType().isUnsizedArray());
972
973 requiresTransformation = false;
974
975 // We need to replace the whole expression including the EOpArrayLength, to avoid
976 // confusing the replacement code as the original and new expressions don't have the
977 // same type (one is the transpose of the other). This doesn't affect the .length()
978 // operation, so this replacement is ok, though it's not worth special-casing this in
979 // the node replacement algorithm.
980 //
981 // Note: the |if (!requiresTransformation)| immediately below will be entered after
982 // this.
983 originalExpression = accessor;
984 accessor = getAncestorNode(accessorIndex + 1);
985 baseExpression = new TIntermUnary(EOpArrayLength, baseExpression, nullptr);
986 }
987
988 if (!requiresTransformation)
989 {
990 ASSERT(primaryIndex == nullptr);
991 queueReplacementWithParent(accessor, originalExpression, baseExpression,
992 OriginalNode::IS_DROPPED);
993
994 RewriteRowMajorMatricesTraverser *traverser = mOuterTraverser ? mOuterTraverser : this;
995 traverser->insertStatementsInParentBlock(prependStatements, appendStatements);
996 return;
997 }
998
999 ASSERT(structure == nullptr || primaryIndex == nullptr);
1000 ASSERT(structure != nullptr || baseExpression->getType().isMatrix());
1001
1002 // At the end, we can determine if the expression is being read from or written to (or both,
1003 // if sent as an inout parameter to a function). For the sake of the transformation, the
1004 // left-hand side of operations like += can be treated as "written to", without necessarily
1005 // "read from".
1006 bool isRead = false;
1007 bool isWrite = false;
1008
1009 determineAccess(originalExpression, accessor, &isRead, &isWrite);
1010
1011 ASSERT(isRead || isWrite);
1012
1013 TIntermTyped *readExpression = nullptr;
1014 if (isRead)
1015 {
1016 readExpression = transformReadExpression(
1017 baseExpression, primaryIndex, &secondaryIndices, structure, &prependStatements);
1018
1019 // If both read from and written to (i.e. passed to inout parameter), store the
1020 // expression in a temp variable and pass that to the function.
1021 if (isWrite)
1022 {
1023 readExpression =
1024 CopyToTempVariable(mSymbolTable, readExpression, &prependStatements);
1025 }
1026
1027 // Replace the original expression with the transformed one. Read transformations
1028 // always generate a single expression that can be used in place of the original (as
1029 // opposed to write transformations that can generate multiple statements).
1030 queueReplacementWithParent(accessor, originalExpression, readExpression,
1031 OriginalNode::IS_DROPPED);
1032 }
1033
1034 TIntermSequence postTransformPrependStatements;
1035 TIntermSequence *writeStatements = &appendStatements;
1036 TOperator assignmentOperator = EOpAssign;
1037
1038 if (isWrite)
1039 {
1040 TIntermTyped *valueExpression = readExpression;
1041
1042 if (!valueExpression)
1043 {
1044 // If there's already a read expression, this was an inout parameter and
1045 // |valueExpression| will contain the temp variable that was passed to the function
1046 // instead.
1047 //
1048 // If not, then the modification is either through being passed as an out parameter
1049 // to a function, or an assignment. In the former case, create a temp variable to
1050 // be passed to the function. In the latter case, create a temp variable that holds
1051 // the right hand side expression.
1052 //
1053 // In either case, use that temp value as the value to assign to |baseExpression|.
1054
1055 TVariable *temp =
1056 CreateTempVariable(mSymbolTable, &originalExpression->getAsTyped()->getType());
1057 TIntermDeclaration *tempDecl = nullptr;
1058
1059 valueExpression = new TIntermSymbol(temp);
1060
1061 TIntermBinary *assignment = accessor->getAsBinaryNode();
1062 if (assignment)
1063 {
1064 assignmentOperator = assignment->getOp();
1065 ASSERT(IsAssignment(assignmentOperator));
1066
1067 // We are converting the assignment to the left-hand side of an expression in
1068 // the form M=exp. A subexpression of exp itself could require a
1069 // transformation. This complicates things as there would be two replacements:
1070 //
1071 // - Replace M=exp with temp (because the return value of the assignment could
1072 // be used)
1073 // - Replace exp with exp2, where parent is M=exp
1074 //
1075 // The second replacement however is ineffective as the whole of M=exp is
1076 // already transformed. What's worse, M=exp is transformed without taking exp's
1077 // transformations into account. To address this issue, this same traverser is
1078 // called on the right-hand side expression, with a special flag such that it
1079 // only processes that expression.
1080 //
1081 RewriteRowMajorMatricesTraverser *outerTraverser =
1082 mOuterTraverser ? mOuterTraverser : this;
1083 RewriteRowMajorMatricesTraverser rhsTraverser(
1084 mSymbolTable, outerTraverser, mInterfaceBlockMap,
1085 mInterfaceBlockFieldConvertedIn, mStructMapOut, mCopyFunctionDefinitionsOut,
1086 assignment);
1087 getRootNode()->traverse(&rhsTraverser);
1088 bool valid = rhsTraverser.updateTree(mCompiler, getRootNode());
1089 ASSERT(valid);
1090
1091 tempDecl = CreateTempInitDeclarationNode(temp, assignment->getRight());
1092
1093 // Replace the whole assignment expression with the right-hand side as a read
1094 // expression, in case the result of the assignment is used. For example, this
1095 // transforms:
1096 //
1097 // if ((M += exp) == X)
1098 // {
1099 // // use M
1100 // }
1101 //
1102 // to:
1103 //
1104 // temp = exp;
1105 // M += transform(temp);
1106 // if (transform(M) == X)
1107 // {
1108 // // use M
1109 // }
1110 //
1111 // Note that in this case the assignment to M must be prepended in the parent
1112 // block. In contrast, when sent to a function, the assignment to M should be
1113 // done after the current function call is done.
1114 //
1115 // If the read from M itself (to replace assignment) needs to generate extra
1116 // statements, they should be appended after the statements that write to M.
1117 // These statements are stored in postTransformPrependStatements and appended to
1118 // prependStatements in the end.
1119 //
1120 writeStatements = &prependStatements;
1121
1122 TIntermTyped *assignmentResultExpression = transformReadExpression(
1123 baseExpression->deepCopy(), primaryIndex, &secondaryIndices, structure,
1124 &postTransformPrependStatements);
1125
1126 // Replace the whole assignment, instead of just the right hand side.
1127 TIntermNode *accessorParent = getAncestorNode(accessorIndex + 1);
1128 queueReplacementWithParent(accessorParent, accessor, assignmentResultExpression,
1129 OriginalNode::IS_DROPPED);
1130 }
1131 else
1132 {
1133 tempDecl = CreateTempDeclarationNode(temp);
1134
1135 // Replace the write expression (a function call argument) with the temp
1136 // variable.
1137 queueReplacementWithParent(accessor, originalExpression, valueExpression,
1138 OriginalNode::IS_DROPPED);
1139 }
1140 prependStatements.push_back(tempDecl);
1141 }
1142
1143 if (isRead)
1144 {
1145 baseExpression = baseExpression->deepCopy();
1146 }
1147 transformWriteExpression(baseExpression, primaryIndex, &secondaryIndices, structure,
1148 valueExpression, assignmentOperator, writeStatements);
1149 }
1150
1151 prependStatements.insert(prependStatements.end(), postTransformPrependStatements.begin(),
1152 postTransformPrependStatements.end());
1153
1154 RewriteRowMajorMatricesTraverser *traverser = mOuterTraverser ? mOuterTraverser : this;
1155 traverser->insertStatementsInParentBlock(prependStatements, appendStatements);
1156 }
1157
transformReadExpression(TIntermTyped * baseExpression,TIntermNode * primaryIndex,TIntermSequence * secondaryIndices,const TStructure * structure,TIntermSequence * prependStatements)1158 TIntermTyped *transformReadExpression(TIntermTyped *baseExpression,
1159 TIntermNode *primaryIndex,
1160 TIntermSequence *secondaryIndices,
1161 const TStructure *structure,
1162 TIntermSequence *prependStatements)
1163 {
1164 const TType &baseExpressionType = baseExpression->getType();
1165
1166 if (structure)
1167 {
1168 ASSERT(primaryIndex == nullptr && secondaryIndices->empty());
1169 ASSERT(mStructMapOut->count(structure) != 0);
1170 ASSERT((*mStructMapOut)[structure].convertedStruct != nullptr);
1171
1172 // Declare copy-from-converted-to-original-struct function (if not already).
1173 declareStructCopyToOriginal(structure);
1174
1175 const TFunction *copyToOriginal = (*mStructMapOut)[structure].copyToOriginal;
1176
1177 if (baseExpressionType.isArray())
1178 {
1179 // If base expression is an array, transform every element.
1180 TransformArrayHelper transformHelper(baseExpression);
1181
1182 TIntermTyped *element = nullptr;
1183 while ((element = transformHelper.getNextElement(nullptr, nullptr)) != nullptr)
1184 {
1185 TIntermTyped *transformedElement =
1186 CreateStructCopyCall(copyToOriginal, element);
1187 transformHelper.accumulateForRead(mSymbolTable, transformedElement,
1188 prependStatements);
1189 }
1190 return transformHelper.constructReadTransformExpression();
1191 }
1192 else
1193 {
1194 // If not reading an array, the result is simply a call to this function with the
1195 // base expression.
1196 return CreateStructCopyCall(copyToOriginal, baseExpression);
1197 }
1198 }
1199
1200 // If not indexed, the result is transpose(exp)
1201 if (primaryIndex == nullptr)
1202 {
1203 ASSERT(secondaryIndices->empty());
1204
1205 if (baseExpressionType.isArray())
1206 {
1207 // If array, transpose every element.
1208 TransformArrayHelper transformHelper(baseExpression);
1209
1210 TIntermTyped *element = nullptr;
1211 while ((element = transformHelper.getNextElement(nullptr, nullptr)) != nullptr)
1212 {
1213 TIntermTyped *transformedElement = CreateTransposeCall(mSymbolTable, element);
1214 transformHelper.accumulateForRead(mSymbolTable, transformedElement,
1215 prependStatements);
1216 }
1217 return transformHelper.constructReadTransformExpression();
1218 }
1219 else
1220 {
1221 return CreateTransposeCall(mSymbolTable, baseExpression);
1222 }
1223 }
1224
1225 // If indexed the result is a vector (or just one element) where the primary and secondary
1226 // indices are swapped.
1227 ASSERT(!secondaryIndices->empty());
1228
1229 TOperator primaryIndexOp = GetIndexOp(primaryIndex);
1230 TIntermTyped *primaryIndexAsTyped = primaryIndex->getAsTyped();
1231
1232 TIntermSequence transposedColumn;
1233 for (TIntermNode *secondaryIndex : *secondaryIndices)
1234 {
1235 TOperator secondaryIndexOp = GetIndexOp(secondaryIndex);
1236 TIntermTyped *secondaryIndexAsTyped = secondaryIndex->getAsTyped();
1237
1238 TIntermBinary *colIndexed = new TIntermBinary(
1239 secondaryIndexOp, baseExpression->deepCopy(), secondaryIndexAsTyped->deepCopy());
1240 TIntermBinary *colRowIndexed =
1241 new TIntermBinary(primaryIndexOp, colIndexed, primaryIndexAsTyped->deepCopy());
1242
1243 transposedColumn.push_back(colRowIndexed);
1244 }
1245
1246 if (secondaryIndices->size() == 1)
1247 {
1248 // If only one element, return that directly.
1249 return transposedColumn.front()->getAsTyped();
1250 }
1251
1252 // Otherwise create a constructor with the appropriate dimension.
1253 TType *vecType = new TType(baseExpressionType.getBasicType(), secondaryIndices->size());
1254 return TIntermAggregate::CreateConstructor(*vecType, &transposedColumn);
1255 }
1256
transformWriteExpression(TIntermTyped * baseExpression,TIntermNode * primaryIndex,TIntermSequence * secondaryIndices,const TStructure * structure,TIntermTyped * valueExpression,TOperator assignmentOperator,TIntermSequence * writeStatements)1257 void transformWriteExpression(TIntermTyped *baseExpression,
1258 TIntermNode *primaryIndex,
1259 TIntermSequence *secondaryIndices,
1260 const TStructure *structure,
1261 TIntermTyped *valueExpression,
1262 TOperator assignmentOperator,
1263 TIntermSequence *writeStatements)
1264 {
1265 const TType &baseExpressionType = baseExpression->getType();
1266
1267 if (structure)
1268 {
1269 ASSERT(primaryIndex == nullptr && secondaryIndices->empty());
1270 ASSERT(mStructMapOut->count(structure) != 0);
1271 ASSERT((*mStructMapOut)[structure].convertedStruct != nullptr);
1272
1273 // Declare copy-to-converted-from-original-struct function (if not already).
1274 declareStructCopyFromOriginal(structure);
1275
1276 // The result is call to this function with the value expression assigned to base
1277 // expression.
1278 const TFunction *copyFromOriginal = (*mStructMapOut)[structure].copyFromOriginal;
1279
1280 if (baseExpressionType.isArray())
1281 {
1282 // If array, assign every element.
1283 TransformArrayHelper transformHelper(baseExpression);
1284
1285 TIntermTyped *element = nullptr;
1286 TIntermTyped *valueElement = nullptr;
1287 while ((element = transformHelper.getNextElement(valueExpression, &valueElement)) !=
1288 nullptr)
1289 {
1290 TIntermTyped *functionCall =
1291 CreateStructCopyCall(copyFromOriginal, valueElement);
1292 writeStatements->push_back(new TIntermBinary(EOpAssign, element, functionCall));
1293 }
1294 }
1295 else
1296 {
1297 TIntermTyped *functionCall =
1298 CreateStructCopyCall(copyFromOriginal, valueExpression->deepCopy());
1299 writeStatements->push_back(
1300 new TIntermBinary(EOpAssign, baseExpression, functionCall));
1301 }
1302
1303 return;
1304 }
1305
1306 // If not indexed, the result is transpose(exp)
1307 if (primaryIndex == nullptr)
1308 {
1309 ASSERT(secondaryIndices->empty());
1310
1311 if (baseExpressionType.isArray())
1312 {
1313 // If array, assign every element.
1314 TransformArrayHelper transformHelper(baseExpression);
1315
1316 TIntermTyped *element = nullptr;
1317 TIntermTyped *valueElement = nullptr;
1318 while ((element = transformHelper.getNextElement(valueExpression, &valueElement)) !=
1319 nullptr)
1320 {
1321 TIntermTyped *valueTransposed = CreateTransposeCall(mSymbolTable, valueElement);
1322 writeStatements->push_back(
1323 new TIntermBinary(EOpAssign, element, valueTransposed));
1324 }
1325 }
1326 else
1327 {
1328 TIntermTyped *valueTransposed =
1329 CreateTransposeCall(mSymbolTable, valueExpression->deepCopy());
1330 writeStatements->push_back(
1331 new TIntermBinary(assignmentOperator, baseExpression, valueTransposed));
1332 }
1333
1334 return;
1335 }
1336
1337 // If indexed, create one assignment per secondary index. If the right-hand side is a
1338 // scalar, it's used with every assignment. If it's a vector, the assignment is
1339 // per-component. The right-hand side cannot be a matrix as that would imply left-hand
1340 // side being a matrix too, which is covered above where |primaryIndex == nullptr|.
1341 ASSERT(!secondaryIndices->empty());
1342
1343 bool isValueExpressionScalar = valueExpression->getType().getNominalSize() == 1;
1344 ASSERT(isValueExpressionScalar || valueExpression->getType().getNominalSize() ==
1345 static_cast<int>(secondaryIndices->size()));
1346
1347 TOperator primaryIndexOp = GetIndexOp(primaryIndex);
1348 TIntermTyped *primaryIndexAsTyped = primaryIndex->getAsTyped();
1349
1350 for (TIntermNode *secondaryIndex : *secondaryIndices)
1351 {
1352 TOperator secondaryIndexOp = GetIndexOp(secondaryIndex);
1353 TIntermTyped *secondaryIndexAsTyped = secondaryIndex->getAsTyped();
1354
1355 TIntermBinary *colIndexed = new TIntermBinary(
1356 secondaryIndexOp, baseExpression->deepCopy(), secondaryIndexAsTyped->deepCopy());
1357 TIntermBinary *colRowIndexed =
1358 new TIntermBinary(primaryIndexOp, colIndexed, primaryIndexAsTyped->deepCopy());
1359
1360 TIntermTyped *valueExpressionIndexed = valueExpression->deepCopy();
1361 if (!isValueExpressionScalar)
1362 {
1363 valueExpressionIndexed = new TIntermBinary(secondaryIndexOp, valueExpressionIndexed,
1364 secondaryIndexAsTyped->deepCopy());
1365 }
1366
1367 writeStatements->push_back(
1368 new TIntermBinary(assignmentOperator, colRowIndexed, valueExpressionIndexed));
1369 }
1370 }
1371
getCopyStructFieldFunction(const TType * fromFieldType,const TType * toFieldType,bool isCopyToOriginal)1372 const TFunction *getCopyStructFieldFunction(const TType *fromFieldType,
1373 const TType *toFieldType,
1374 bool isCopyToOriginal)
1375 {
1376 ASSERT(fromFieldType->getStruct());
1377 ASSERT(toFieldType->getStruct());
1378
1379 // If copying from or to the original struct, the "to" field struct could require
1380 // conversion to or from the "from" field struct. |isCopyToOriginal| tells us if we
1381 // should expect to find toField or fromField in mStructMapOut, if true or false
1382 // respectively.
1383 const TFunction *fieldCopyFunction = nullptr;
1384 if (isCopyToOriginal)
1385 {
1386 const TStructure *toFieldStruct = toFieldType->getStruct();
1387
1388 auto iter = mStructMapOut->find(toFieldStruct);
1389 if (iter != mStructMapOut->end())
1390 {
1391 declareStructCopyToOriginal(toFieldStruct);
1392 fieldCopyFunction = iter->second.copyToOriginal;
1393 }
1394 }
1395 else
1396 {
1397 const TStructure *fromFieldStruct = fromFieldType->getStruct();
1398
1399 auto iter = mStructMapOut->find(fromFieldStruct);
1400 if (iter != mStructMapOut->end())
1401 {
1402 declareStructCopyFromOriginal(fromFieldStruct);
1403 fieldCopyFunction = iter->second.copyFromOriginal;
1404 }
1405 }
1406
1407 return fieldCopyFunction;
1408 }
1409
addFieldCopy(TIntermBlock * body,TIntermTyped * to,TIntermTyped * from,bool isCopyToOriginal)1410 void addFieldCopy(TIntermBlock *body,
1411 TIntermTyped *to,
1412 TIntermTyped *from,
1413 bool isCopyToOriginal)
1414 {
1415 const TType &fromType = from->getType();
1416 const TType &toType = to->getType();
1417
1418 TIntermTyped *rhs = from;
1419
1420 if (fromType.getStruct())
1421 {
1422 const TFunction *fieldCopyFunction =
1423 getCopyStructFieldFunction(&fromType, &toType, isCopyToOriginal);
1424
1425 if (fieldCopyFunction)
1426 {
1427 rhs = CreateStructCopyCall(fieldCopyFunction, from);
1428 }
1429 }
1430 else if (fromType.isMatrix())
1431 {
1432 rhs = CreateTransposeCall(mSymbolTable, from);
1433 }
1434
1435 body->appendStatement(new TIntermBinary(EOpAssign, to, rhs));
1436 }
1437
declareStructCopy(const TStructure * from,const TStructure * to,bool isCopyToOriginal)1438 TFunction *declareStructCopy(const TStructure *from,
1439 const TStructure *to,
1440 bool isCopyToOriginal)
1441 {
1442 TType *fromType = new TType(from, true);
1443 TType *toType = new TType(to, true);
1444
1445 // Create the parameter and return value variables.
1446 TVariable *fromVar = new TVariable(mSymbolTable, ImmutableString("from"), fromType,
1447 SymbolType::AngleInternal);
1448 TVariable *toVar =
1449 new TVariable(mSymbolTable, ImmutableString("to"), toType, SymbolType::AngleInternal);
1450
1451 TIntermSymbol *fromSymbol = new TIntermSymbol(fromVar);
1452 TIntermSymbol *toSymbol = new TIntermSymbol(toVar);
1453
1454 // Create the function body as statements are generated.
1455 TIntermBlock *body = new TIntermBlock;
1456
1457 // Declare the result variable.
1458 TIntermDeclaration *toDecl = new TIntermDeclaration();
1459 toDecl->appendDeclarator(toSymbol);
1460 body->appendStatement(toDecl);
1461
1462 // Iterate over fields of the struct and copy one by one, transposing the matrices. If a
1463 // struct is encountered that requires a transformation, this function is recursively
1464 // called. As a result, it is important that the copy functions are placed in the code in
1465 // order.
1466 const TFieldList &fromFields = from->fields();
1467 const TFieldList &toFields = to->fields();
1468 ASSERT(fromFields.size() == toFields.size());
1469
1470 for (size_t fieldIndex = 0; fieldIndex < fromFields.size(); ++fieldIndex)
1471 {
1472 TIntermTyped *fieldIndexNode = CreateIndexNode(static_cast<int>(fieldIndex));
1473
1474 TIntermTyped *fromField =
1475 new TIntermBinary(EOpIndexDirectStruct, fromSymbol->deepCopy(), fieldIndexNode);
1476 TIntermTyped *toField = new TIntermBinary(EOpIndexDirectStruct, toSymbol->deepCopy(),
1477 fieldIndexNode->deepCopy());
1478
1479 const TType *fromFieldType = fromFields[fieldIndex]->type();
1480 bool isStructOrMatrix = fromFieldType->getStruct() || fromFieldType->isMatrix();
1481
1482 if (fromFieldType->isArray() && isStructOrMatrix)
1483 {
1484 // If struct or matrix array, we need to copy element by element.
1485 TransformArrayHelper transformHelper(toField);
1486
1487 TIntermTyped *toElement = nullptr;
1488 TIntermTyped *fromElement = nullptr;
1489 while ((toElement = transformHelper.getNextElement(fromField, &fromElement)) !=
1490 nullptr)
1491 {
1492 addFieldCopy(body, toElement, fromElement, isCopyToOriginal);
1493 }
1494 }
1495 else
1496 {
1497 addFieldCopy(body, toField, fromField, isCopyToOriginal);
1498 }
1499 }
1500
1501 // Add return statement.
1502 body->appendStatement(new TIntermBranch(EOpReturn, toSymbol->deepCopy()));
1503
1504 // Declare the function
1505 TFunction *copyFunction = new TFunction(mSymbolTable, kEmptyImmutableString,
1506 SymbolType::AngleInternal, toType, true);
1507 copyFunction->addParameter(fromVar);
1508
1509 TIntermFunctionDefinition *functionDef =
1510 CreateInternalFunctionDefinitionNode(*copyFunction, body);
1511 mCopyFunctionDefinitionsOut->push_back(functionDef);
1512
1513 return copyFunction;
1514 }
1515
declareStructCopyFromOriginal(const TStructure * structure)1516 void declareStructCopyFromOriginal(const TStructure *structure)
1517 {
1518 StructConversionData *structData = &(*mStructMapOut)[structure];
1519 if (structData->copyFromOriginal)
1520 {
1521 return;
1522 }
1523
1524 structData->copyFromOriginal =
1525 declareStructCopy(structure, structData->convertedStruct, false);
1526 }
1527
declareStructCopyToOriginal(const TStructure * structure)1528 void declareStructCopyToOriginal(const TStructure *structure)
1529 {
1530 StructConversionData *structData = &(*mStructMapOut)[structure];
1531 if (structData->copyToOriginal)
1532 {
1533 return;
1534 }
1535
1536 structData->copyToOriginal =
1537 declareStructCopy(structData->convertedStruct, structure, true);
1538 }
1539
1540 TCompiler *mCompiler;
1541
1542 // This traverser can call itself to transform a subexpression before moving on. However, it
1543 // needs to accumulate conversion functions in inner passes. The fields below marked with Out
1544 // or In are inherited from the outer pass (for inner passes), or point to storage fields in
1545 // mOuterPass (for the outer pass). The latter should not be used by the inner passes as they
1546 // would be empty, so they are placed inside a struct to make them explicit.
1547 struct
1548 {
1549 StructMap structMap;
1550 InterfaceBlockMap interfaceBlockMap;
1551 InterfaceBlockFieldConverted interfaceBlockFieldConverted;
1552 TIntermSequence copyFunctionDefinitions;
1553 } mOuterPass;
1554
1555 // A map from structures with matrices to their converted version.
1556 StructMap *mStructMapOut;
1557 // A map from interface block instances with row-major matrices to their converted variable. If
1558 // an interface block is nameless, its fields are placed in this map instead. When a variable
1559 // in this map is encountered, it signals the start of an expression that may need conversion,
1560 // which is either "interfaceBlock.field..." or "field..." if nameless.
1561 InterfaceBlockMap *mInterfaceBlockMap;
1562 // A map from interface block fields to whether they need to be converted. If a field was
1563 // already column-major, it shouldn't be transposed.
1564 const InterfaceBlockFieldConverted &mInterfaceBlockFieldConvertedIn;
1565
1566 TIntermSequence *mCopyFunctionDefinitionsOut;
1567
1568 // If set, it's an inner pass and this will point to the outer pass traverser. All statement
1569 // insertions are stored in the outer traverser and applied at once in the end. This prevents
1570 // the inner passes from adding statements which invalidates the outer traverser's statement
1571 // position tracking.
1572 RewriteRowMajorMatricesTraverser *mOuterTraverser;
1573
1574 // If set, it's an inner pass that should only process the right-hand side of this particular
1575 // node.
1576 TIntermBinary *mInnerPassRoot;
1577 bool mIsProcessingInnerPassSubtree;
1578 };
1579
1580 } // anonymous namespace
1581
RewriteRowMajorMatrices(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)1582 bool RewriteRowMajorMatrices(TCompiler *compiler, TIntermBlock *root, TSymbolTable *symbolTable)
1583 {
1584 RewriteRowMajorMatricesTraverser traverser(compiler, symbolTable);
1585 root->traverse(&traverser);
1586 if (!traverser.updateTree(compiler, root))
1587 {
1588 return false;
1589 }
1590
1591 size_t firstFunctionIndex = FindFirstFunctionDefinitionIndex(root);
1592 root->insertChildNodes(firstFunctionIndex, *traverser.getStructCopyFunctions());
1593
1594 return compiler->validateAST(root);
1595 }
1596 } // namespace sh
1597