1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Scalarize vector and matrix constructor args, so that vectors built from components don't have
7 // matrix arguments, and matrices built from components don't have vector arguments. This avoids
8 // driver bugs around vector and matrix constructors.
9 //
10
#include "compiler/translator/tree_ops/ScalarizeVecAndMatConstructorArgs.h"
#include "common/debug.h"

#include <algorithm>
#include <utility>

#include "angle_gl.h"
#include "common/angleutils.h"
#include "compiler/translator/Compiler.h"
#include "compiler/translator/tree_util/IntermNodePatternMatcher.h"
#include "compiler/translator/tree_util/IntermNode_util.h"
#include "compiler/translator/tree_util/IntermTraverse.h"
#include "compiler/translator/util.h"
23
24 namespace sh
25 {
26
27 namespace
28 {
29
ConstructVectorIndexBinaryNode(TIntermTyped * symbolNode,int index)30 TIntermBinary *ConstructVectorIndexBinaryNode(TIntermTyped *symbolNode, int index)
31 {
32 return new TIntermBinary(EOpIndexDirect, symbolNode, CreateIndexNode(index));
33 }
34
ConstructMatrixIndexBinaryNode(TIntermTyped * symbolNode,int colIndex,int rowIndex)35 TIntermBinary *ConstructMatrixIndexBinaryNode(TIntermTyped *symbolNode, int colIndex, int rowIndex)
36 {
37 TIntermBinary *colVectorNode = ConstructVectorIndexBinaryNode(symbolNode, colIndex);
38
39 return new TIntermBinary(EOpIndexDirect, colVectorNode, CreateIndexNode(rowIndex));
40 }
41
42 class ScalarizeArgsTraverser : public TIntermTraverser
43 {
44 public:
ScalarizeArgsTraverser(TSymbolTable * symbolTable)45 ScalarizeArgsTraverser(TSymbolTable *symbolTable)
46 : TIntermTraverser(true, false, false, symbolTable),
47 mNodesToScalarize(IntermNodePatternMatcher::kScalarizedVecOrMatConstructor)
48 {}
49
50 protected:
51 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
52 bool visitBlock(Visit visit, TIntermBlock *node) override;
53
54 private:
55 void scalarizeArgs(TIntermAggregate *aggregate, bool scalarizeVector, bool scalarizeMatrix);
56
57 // If we have the following code:
58 // mat4 m(0);
59 // vec4 v(1, m);
60 // We will rewrite to:
61 // mat4 m(0);
62 // mat4 s0 = m;
63 // vec4 v(1, s0[0][0], s0[0][1], s0[0][2]);
64 // This function is to create nodes for "mat4 s0 = m;" and insert it to the code sequence. This
65 // way the possible side effects of the constructor argument will only be evaluated once.
66 TIntermTyped *createTempVariable(TIntermTyped *original);
67
68 std::vector<TIntermSequence> mBlockStack;
69
70 IntermNodePatternMatcher mNodesToScalarize;
71 };
72
visitAggregate(Visit visit,TIntermAggregate * node)73 bool ScalarizeArgsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
74 {
75 ASSERT(visit == PreVisit);
76 if (mNodesToScalarize.match(node, getParentNode()))
77 {
78 if (node->getType().isVector())
79 {
80 scalarizeArgs(node, false, true);
81 }
82 else
83 {
84 ASSERT(node->getType().isMatrix());
85 scalarizeArgs(node, true, false);
86 }
87 }
88 return true;
89 }
90
visitBlock(Visit visit,TIntermBlock * node)91 bool ScalarizeArgsTraverser::visitBlock(Visit visit, TIntermBlock *node)
92 {
93 mBlockStack.push_back(TIntermSequence());
94 {
95 for (TIntermNode *child : *node->getSequence())
96 {
97 ASSERT(child != nullptr);
98 child->traverse(this);
99 mBlockStack.back().push_back(child);
100 }
101 }
102 if (mBlockStack.back().size() > node->getSequence()->size())
103 {
104 node->getSequence()->clear();
105 *(node->getSequence()) = mBlockStack.back();
106 }
107 mBlockStack.pop_back();
108 return false;
109 }
110
scalarizeArgs(TIntermAggregate * aggregate,bool scalarizeVector,bool scalarizeMatrix)111 void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate,
112 bool scalarizeVector,
113 bool scalarizeMatrix)
114 {
115 ASSERT(aggregate);
116 ASSERT(!aggregate->isArray());
117 int size = static_cast<int>(aggregate->getType().getObjectSize());
118 TIntermSequence *sequence = aggregate->getSequence();
119 TIntermSequence originalArgs(*sequence);
120 sequence->clear();
121 for (TIntermNode *originalArgNode : originalArgs)
122 {
123 ASSERT(size > 0);
124 TIntermTyped *originalArg = originalArgNode->getAsTyped();
125 ASSERT(originalArg);
126 TIntermTyped *argVariable = createTempVariable(originalArg);
127 if (originalArg->isScalar())
128 {
129 sequence->push_back(argVariable);
130 size--;
131 }
132 else if (originalArg->isVector())
133 {
134 if (scalarizeVector)
135 {
136 int repeat = std::min<int>(size, originalArg->getNominalSize());
137 size -= repeat;
138 for (int index = 0; index < repeat; ++index)
139 {
140 TIntermBinary *newNode =
141 ConstructVectorIndexBinaryNode(argVariable->deepCopy(), index);
142 sequence->push_back(newNode);
143 }
144 }
145 else
146 {
147 sequence->push_back(argVariable);
148 size -= originalArg->getNominalSize();
149 }
150 }
151 else
152 {
153 ASSERT(originalArg->isMatrix());
154 if (scalarizeMatrix)
155 {
156 int colIndex = 0, rowIndex = 0;
157 int repeat = std::min<int>(size, originalArg->getCols() * originalArg->getRows());
158 size -= repeat;
159 while (repeat > 0)
160 {
161 TIntermBinary *newNode =
162 ConstructMatrixIndexBinaryNode(argVariable->deepCopy(), colIndex, rowIndex);
163 sequence->push_back(newNode);
164 rowIndex++;
165 if (rowIndex >= originalArg->getRows())
166 {
167 rowIndex = 0;
168 colIndex++;
169 }
170 repeat--;
171 }
172 }
173 else
174 {
175 sequence->push_back(argVariable);
176 size -= originalArg->getCols() * originalArg->getRows();
177 }
178 }
179 }
180 }
181
createTempVariable(TIntermTyped * original)182 TIntermTyped *ScalarizeArgsTraverser::createTempVariable(TIntermTyped *original)
183 {
184 ASSERT(original);
185
186 TType *type = new TType(original->getType());
187 type->setQualifier(EvqTemporary);
188
189 // The precision of the constant must have been retained (or derived), which will now apply to
190 // the temp variable. In some cases, the precision cannot be derived, so use the constant as
191 // is. For example, in the following standalone statement, the precision of the constant 0
192 // cannot be determined:
193 //
194 // mat2(0, bvec3(m));
195 //
196 if (IsPrecisionApplicableToType(type->getBasicType()) && type->getPrecision() == EbpUndefined)
197 {
198 return original;
199 }
200
201 TVariable *variable = CreateTempVariable(mSymbolTable, type);
202
203 ASSERT(mBlockStack.size() > 0);
204 TIntermSequence &sequence = mBlockStack.back();
205 TIntermDeclaration *declaration = CreateTempInitDeclarationNode(variable, original);
206 sequence.push_back(declaration);
207
208 return CreateTempSymbolNode(variable);
209 }
210
211 } // namespace
212
ScalarizeVecAndMatConstructorArgs(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)213 bool ScalarizeVecAndMatConstructorArgs(TCompiler *compiler,
214 TIntermBlock *root,
215 TSymbolTable *symbolTable)
216 {
217 ScalarizeArgsTraverser scalarizer(symbolTable);
218 root->traverse(&scalarizer);
219
220 return compiler->validateAST(root);
221 }
222
223 } // namespace sh
224