1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved. Use of this
3 // source code is governed by a BSD-style license that can be found in the
4 // LICENSE file.
5 //
6 // ReplaceArrayOfMatrixVarying: Find any references to array of matrices varying
7 // and replace it with array of vectors.
8 //
9
10 #include "compiler/translator/tree_util/ReplaceArrayOfMatrixVarying.h"
11
12 #include <vector>
13
14 #include "common/bitset_utils.h"
15 #include "common/debug.h"
16 #include "common/utilities.h"
17 #include "compiler/translator/Compiler.h"
18 #include "compiler/translator/SymbolTable.h"
19 #include "compiler/translator/tree_util/BuiltIn.h"
20 #include "compiler/translator/tree_util/FindMain.h"
21 #include "compiler/translator/tree_util/IntermNode_util.h"
22 #include "compiler/translator/tree_util/IntermTraverse.h"
23 #include "compiler/translator/tree_util/ReplaceVariable.h"
24 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
25 #include "compiler/translator/util.h"
26
27 namespace sh
28 {
29
// We create two variables to replace the given varying:
// - A new varying, an array of vectors, used only for input/output.
// - A new global variable that has the same type as the given varying, used temporarily in its
//   place for assignments, arithmetic ops and so on. During the input/output phase, this temp
//   variable is copied from/to the array-of-vectors variable above.
// NOTE(hqle): Consider eliminating the need for using temp variable.
36
37 namespace
38 {
39 class CollectVaryingTraverser : public TIntermTraverser
40 {
41 public:
CollectVaryingTraverser(std::vector<const TVariable * > * varyingsOut)42 CollectVaryingTraverser(std::vector<const TVariable *> *varyingsOut)
43 : TIntermTraverser(true, false, false), mVaryingsOut(varyingsOut)
44 {}
45
visitDeclaration(Visit visit,TIntermDeclaration * node)46 bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
47 {
48 const TIntermSequence &sequence = *(node->getSequence());
49
50 if (sequence.size() != 1)
51 {
52 return false;
53 }
54
55 TIntermTyped *variableType = sequence.front()->getAsTyped();
56 if (!variableType || !IsVarying(variableType->getQualifier()) ||
57 !variableType->isMatrix() || !variableType->isArray())
58 {
59 return false;
60 }
61
62 TIntermSymbol *variableSymbol = variableType->getAsSymbolNode();
63 if (!variableSymbol)
64 {
65 return false;
66 }
67
68 mVaryingsOut->push_back(&variableSymbol->variable());
69
70 return false;
71 }
72
73 private:
74 std::vector<const TVariable *> *mVaryingsOut;
75 };
76 } // namespace
77
ReplaceArrayOfMatrixVarying(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const TVariable * varying)78 ANGLE_NO_DISCARD bool ReplaceArrayOfMatrixVarying(TCompiler *compiler,
79 TIntermBlock *root,
80 TSymbolTable *symbolTable,
81 const TVariable *varying)
82 {
83 const TType &type = varying->getType();
84
85 // Create global variable to temporarily acts as the given variable in places such as
86 // arithmetic, assignments an so on.
87 TType *tmpReplacementType = new TType(type);
88 tmpReplacementType->setQualifier(EvqGlobal);
89 tmpReplacementType->realize();
90
91 TVariable *tempReplaceVar = new TVariable(
92 symbolTable, ImmutableString(std::string("ANGLE_AOM_Temp_") + varying->name().data()),
93 tmpReplacementType, SymbolType::AngleInternal);
94
95 if (!ReplaceVariable(compiler, root, varying, tempReplaceVar))
96 {
97 return false;
98 }
99
100 // Create array of vectors type
101 TType *varyingReplaceType =
102 new TType(type.getBasicType(), type.getPrecision(), type.getQualifier(),
103 static_cast<unsigned char>(type.getRows()), 1);
104 varyingReplaceType->setInvariant(type.isInvariant());
105 varyingReplaceType->setMemoryQualifier(type.getMemoryQualifier());
106 varyingReplaceType->setLayoutQualifier(type.getLayoutQualifier());
107 varyingReplaceType->makeArray(type.getCols() * type.getOutermostArraySize());
108 varyingReplaceType->realize();
109
110 TVariable *varyingReplaceVar =
111 new TVariable(symbolTable, varying->name(), varyingReplaceType, SymbolType::UserDefined);
112
113 TIntermSymbol *varyingReplaceDeclarator = new TIntermSymbol(varyingReplaceVar);
114 TIntermDeclaration *varyingReplaceDecl = new TIntermDeclaration;
115 varyingReplaceDecl->appendDeclarator(varyingReplaceDeclarator);
116 root->insertStatement(0, varyingReplaceDecl);
117
118 // Copy from/to the temp variable
119 TIntermBlock *reassignBlock = new TIntermBlock;
120 TIntermSymbol *tempReplaceSymbol = new TIntermSymbol(tempReplaceVar);
121 TIntermSymbol *varyingReplaceSymbol = new TIntermSymbol(varyingReplaceVar);
122 bool isInput = IsVaryingIn(type.getQualifier());
123
124 for (unsigned int i = 0; i < type.getOutermostArraySize(); ++i)
125 {
126 TIntermBinary *tempMatrixIndexed =
127 new TIntermBinary(EOpIndexDirect, tempReplaceSymbol->deepCopy(), CreateIndexNode(i));
128 for (int col = 0; col < type.getCols(); ++col)
129 {
130
131 TIntermBinary *tempMatrixColIndexed = new TIntermBinary(
132 EOpIndexDirect, tempMatrixIndexed->deepCopy(), CreateIndexNode(col));
133 TIntermBinary *vectorIndexed =
134 new TIntermBinary(EOpIndexDirect, varyingReplaceSymbol->deepCopy(),
135 CreateIndexNode(i * type.getCols() + col));
136 TIntermBinary *assignment;
137 if (isInput)
138 {
139 assignment = new TIntermBinary(EOpAssign, tempMatrixColIndexed, vectorIndexed);
140 }
141 else
142 {
143 assignment = new TIntermBinary(EOpAssign, vectorIndexed, tempMatrixColIndexed);
144 }
145 reassignBlock->appendStatement(assignment);
146 }
147 }
148
149 if (isInput)
150 {
151 TIntermFunctionDefinition *main = FindMain(root);
152 main->getBody()->insertStatement(0, reassignBlock);
153 return compiler->validateAST(root);
154 }
155 else
156 {
157 return RunAtTheEndOfShader(compiler, root, reassignBlock, symbolTable);
158 }
159 }
160
ReplaceArrayOfMatrixVaryings(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)161 ANGLE_NO_DISCARD bool ReplaceArrayOfMatrixVaryings(TCompiler *compiler,
162 TIntermBlock *root,
163 TSymbolTable *symbolTable)
164 {
165 std::vector<const TVariable *> arrayOfMatrixVars;
166 CollectVaryingTraverser varCollector(&arrayOfMatrixVars);
167 root->traverse(&varCollector);
168
169 for (const TVariable *var : arrayOfMatrixVars)
170 {
171 if (!ReplaceArrayOfMatrixVarying(compiler, root, symbolTable, var))
172 {
173 return false;
174 }
175 }
176
177 return true;
178 }
179
180 } // namespace sh
181