1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Scalarize vector and matrix constructor args, so that vectors built from components don't have
7 // matrix arguments, and matrices built from components don't have vector arguments. This avoids
8 // driver bugs around vector and matrix constructors.
9 //
10
11 #include "compiler/translator/tree_ops/glsl/ScalarizeVecAndMatConstructorArgs.h"
12
13 #include "angle_gl.h"
14 #include "common/angleutils.h"
15 #include "compiler/translator/Compiler.h"
16 #include "compiler/translator/tree_util/IntermNode_util.h"
17 #include "compiler/translator/tree_util/IntermTraverse.h"
18
19 namespace sh
20 {
21
22 namespace
23 {
24 // Traverser that converts a vector or matrix constructor to one that only uses scalars. To support
25 // all the various places such a constructor could be found, a helper function is created for each
26 // such constructor. The helper function takes the constructor arguments and creates the object.
27 //
28 // Constructors that are transformed are:
29 //
30 // - vecN(scalar): translates to vecN(scalar, ..., scalar)
31 // - vecN(vec1, vec2, ...): translates to vecN(vec1.x, vec1.y, vec2.x, ...)
32 // - vecN(matrix): translates to vecN(matrix[0][0], matrix[0][1], ...)
33 // - matNxM(scalar): translates to matNxM(scalar, 0, ..., 0
34 // 0, scalar, ..., 0
35 // ...
36 // 0, 0, ..., scalar)
37 // - matNxM(vec1, vec2, ...): translates to matNxM(vec1.x, vec1.y, vec2.x, ...)
38 // - matNxM(matrixAxB): translates to matNxM(matrix[0][0], matrix[0][1], ..., 0
39 // matrix[1][0], matrix[1][1], ..., 0
40 // ...
41 // 0, 0, ..., 1)
42 //
43 class ScalarizeTraverser : public TIntermTraverser
44 {
45 public:
ScalarizeTraverser(TSymbolTable * symbolTable)46 ScalarizeTraverser(TSymbolTable *symbolTable)
47 : TIntermTraverser(true, false, false, symbolTable)
48 {}
49
50 bool update(TCompiler *compiler, TIntermBlock *root);
51
52 protected:
53 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
54
55 private:
56 bool shouldScalarize(TIntermTyped *node);
57
58 // Create a helper function that takes the same arguments as the constructor it replaces.
59 const TFunction *createHelper(TIntermAggregate *node);
60 TIntermTyped *createHelperCall(TIntermAggregate *node, const TFunction *helper);
61 void addHelperDefinition(const TFunction *helper, TIntermBlock *body);
62
63 // If given a constructor, convert it to a function call. Recursively processes constructor
64 // arguments. Otherwise, recursively visit the node.
65 TIntermTyped *createConstructor(TIntermTyped *node);
66
67 void extractComponents(const TFunction *helper,
68 size_t componentCount,
69 TIntermSequence *componentsOut);
70
71 void createConstructorVectorFromScalar(TIntermAggregate *node,
72 const TFunction *helper,
73 TIntermSequence *constructorArgsOut);
74 void createConstructorVectorFromMultiple(TIntermAggregate *node,
75 const TFunction *helper,
76 TIntermSequence *constructorArgsOut);
77 void createConstructorMatrixFromScalar(TIntermAggregate *node,
78 const TFunction *helper,
79 TIntermSequence *constructorArgsOut);
80 void createConstructorMatrixFromVectors(TIntermAggregate *node,
81 const TFunction *helper,
82 TIntermSequence *constructorArgsOut);
83 void createConstructorMatrixFromMatrix(TIntermAggregate *node,
84 const TFunction *helper,
85 TIntermSequence *constructorArgsOut);
86
87 TIntermSequence mFunctionsToAdd;
88 };
89
visitAggregate(Visit visit,TIntermAggregate * node)90 bool ScalarizeTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
91 {
92 if (!shouldScalarize(node))
93 {
94 return true;
95 }
96
97 TIntermTyped *replacement = createConstructor(node);
98 if (replacement != node)
99 {
100 queueReplacement(replacement, OriginalNode::IS_DROPPED);
101 }
102 // createConstructor already visits children
103 return false;
104 }
105
shouldScalarize(TIntermTyped * typed)106 bool ScalarizeTraverser::shouldScalarize(TIntermTyped *typed)
107 {
108 TIntermAggregate *node = typed->getAsAggregate();
109 if (node == nullptr || node->getOp() != EOpConstruct)
110 {
111 return false;
112 }
113
114 const TType &type = node->getType();
115 const TIntermSequence &arguments = *node->getSequence();
116 const TType &arg0Type = arguments[0]->getAsTyped()->getType();
117
118 const bool isSingleVectorCast = arguments.size() == 1 && type.isVector() &&
119 arg0Type.isVector() &&
120 type.getNominalSize() == arg0Type.getNominalSize();
121 const bool isSingleMatrixCast = arguments.size() == 1 && type.isMatrix() &&
122 arg0Type.isMatrix() && type.getCols() == arg0Type.getCols() &&
123 type.getRows() == arg0Type.getRows();
124
125 // Skip non-vector non-matrix constructors, as well as trivial constructors.
126 if (type.isArray() || type.getStruct() != nullptr || type.isScalar() || isSingleVectorCast ||
127 isSingleMatrixCast)
128 {
129 return false;
130 }
131
132 return true;
133 }
134
createHelper(TIntermAggregate * node)135 const TFunction *ScalarizeTraverser::createHelper(TIntermAggregate *node)
136 {
137 TFunction *helper = new TFunction(mSymbolTable, kEmptyImmutableString,
138 SymbolType::AngleInternal, &node->getType(), true);
139
140 const TIntermSequence &arguments = *node->getSequence();
141 for (TIntermNode *arg : arguments)
142 {
143 TType *argType = new TType(arg->getAsTyped()->getType());
144 argType->setQualifier(EvqParamIn);
145
146 TVariable *argVar =
147 new TVariable(mSymbolTable, kEmptyImmutableString, argType, SymbolType::AngleInternal);
148 helper->addParameter(argVar);
149 }
150
151 return helper;
152 }
153
createHelperCall(TIntermAggregate * node,const TFunction * helper)154 TIntermTyped *ScalarizeTraverser::createHelperCall(TIntermAggregate *node, const TFunction *helper)
155 {
156 TIntermSequence callArgs;
157
158 const TIntermSequence &arguments = *node->getSequence();
159 for (TIntermNode *arg : arguments)
160 {
161 // Note: createConstructor makes sure the arg is visited even if not constructor.
162 callArgs.push_back(createConstructor(arg->getAsTyped()));
163 }
164
165 return TIntermAggregate::CreateFunctionCall(*helper, &callArgs);
166 }
167
addHelperDefinition(const TFunction * helper,TIntermBlock * body)168 void ScalarizeTraverser::addHelperDefinition(const TFunction *helper, TIntermBlock *body)
169 {
170 mFunctionsToAdd.push_back(
171 new TIntermFunctionDefinition(new TIntermFunctionPrototype(helper), body));
172 }
173
createConstructor(TIntermTyped * typed)174 TIntermTyped *ScalarizeTraverser::createConstructor(TIntermTyped *typed)
175 {
176 if (!shouldScalarize(typed))
177 {
178 typed->traverse(this);
179 return typed;
180 }
181
182 TIntermAggregate *node = typed->getAsAggregate();
183 const TType &type = node->getType();
184 const TIntermSequence &arguments = *node->getSequence();
185 const TType &arg0Type = arguments[0]->getAsTyped()->getType();
186
187 const TFunction *helper = createHelper(node);
188 TIntermSequence constructorArgs;
189
190 if (type.isVector())
191 {
192 if (arguments.size() == 1 && arg0Type.isScalar())
193 {
194 createConstructorVectorFromScalar(node, helper, &constructorArgs);
195 }
196 createConstructorVectorFromMultiple(node, helper, &constructorArgs);
197 }
198 else
199 {
200 ASSERT(type.isMatrix());
201
202 if (arg0Type.isScalar() && arguments.size() == 1)
203 {
204 createConstructorMatrixFromScalar(node, helper, &constructorArgs);
205 }
206 if (arg0Type.isMatrix())
207 {
208 createConstructorMatrixFromMatrix(node, helper, &constructorArgs);
209 }
210 createConstructorMatrixFromVectors(node, helper, &constructorArgs);
211 }
212
213 TIntermBlock *body = new TIntermBlock;
214 body->appendStatement(
215 new TIntermBranch(EOpReturn, TIntermAggregate::CreateConstructor(type, &constructorArgs)));
216 addHelperDefinition(helper, body);
217
218 return createHelperCall(node, helper);
219 }
220
221 // Extract enough scalar arguments from the arguments of helper to produce enough arguments for the
222 // constructor call (given in componentCount).
extractComponents(const TFunction * helper,size_t componentCount,TIntermSequence * componentsOut)223 void ScalarizeTraverser::extractComponents(const TFunction *helper,
224 size_t componentCount,
225 TIntermSequence *componentsOut)
226 {
227 for (size_t argumentIndex = 0;
228 argumentIndex < helper->getParamCount() && componentsOut->size() < componentCount;
229 ++argumentIndex)
230 {
231 TIntermTyped *argument = new TIntermSymbol(helper->getParam(argumentIndex));
232 const TType &argumentType = argument->getType();
233
234 if (argumentType.isScalar())
235 {
236 // For scalar parameters, there's nothing to do
237 componentsOut->push_back(argument);
238 continue;
239 }
240 if (argumentType.isVector())
241 {
242 // For vector parameters, take components out of the vector one by one.
243 for (uint8_t componentIndex = 0; componentIndex < argumentType.getNominalSize() &&
244 componentsOut->size() < componentCount;
245 ++componentIndex)
246 {
247 componentsOut->push_back(
248 new TIntermSwizzle(argument->deepCopy(), {componentIndex}));
249 }
250 continue;
251 }
252
253 ASSERT(argumentType.isMatrix());
254
255 // For matrix parameters, take components out of the matrix one by one in column-major
256 // order.
257 for (uint8_t columnIndex = 0;
258 columnIndex < argumentType.getCols() && componentsOut->size() < componentCount;
259 ++columnIndex)
260 {
261 TIntermTyped *col = new TIntermBinary(EOpIndexDirect, argument->deepCopy(),
262 CreateIndexNode(columnIndex));
263
264 for (uint8_t componentIndex = 0;
265 componentIndex < argumentType.getRows() && componentsOut->size() < componentCount;
266 ++componentIndex)
267 {
268 componentsOut->push_back(new TIntermSwizzle(col->deepCopy(), {componentIndex}));
269 }
270 }
271 }
272 }
273
createConstructorVectorFromScalar(TIntermAggregate * node,const TFunction * helper,TIntermSequence * constructorArgsOut)274 void ScalarizeTraverser::createConstructorVectorFromScalar(TIntermAggregate *node,
275 const TFunction *helper,
276 TIntermSequence *constructorArgsOut)
277 {
278 ASSERT(helper->getParamCount() == 1);
279 TIntermTyped *scalar = new TIntermSymbol(helper->getParam(0));
280 const TType &type = node->getType();
281
282 // Replicate the single scalar argument as many times as necessary.
283 for (size_t index = 0; index < type.getNominalSize(); ++index)
284 {
285 constructorArgsOut->push_back(scalar->deepCopy());
286 }
287 }
288
createConstructorVectorFromMultiple(TIntermAggregate * node,const TFunction * helper,TIntermSequence * constructorArgsOut)289 void ScalarizeTraverser::createConstructorVectorFromMultiple(TIntermAggregate *node,
290 const TFunction *helper,
291 TIntermSequence *constructorArgsOut)
292 {
293 extractComponents(helper, node->getType().getNominalSize(), constructorArgsOut);
294 }
295
createConstructorMatrixFromScalar(TIntermAggregate * node,const TFunction * helper,TIntermSequence * constructorArgsOut)296 void ScalarizeTraverser::createConstructorMatrixFromScalar(TIntermAggregate *node,
297 const TFunction *helper,
298 TIntermSequence *constructorArgsOut)
299 {
300 ASSERT(helper->getParamCount() == 1);
301 TIntermTyped *scalar = new TIntermSymbol(helper->getParam(0));
302 const TType &type = node->getType();
303
304 // Create the scalar over the diagonal. Every other element is 0.
305 for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
306 {
307 for (uint8_t rowIndex = 0; rowIndex < type.getRows(); ++rowIndex)
308 {
309 if (columnIndex == rowIndex)
310 {
311 constructorArgsOut->push_back(scalar->deepCopy());
312 }
313 else
314 {
315 ASSERT(type.getBasicType() == EbtFloat);
316 constructorArgsOut->push_back(CreateFloatNode(0, type.getPrecision()));
317 }
318 }
319 }
320 }
321
createConstructorMatrixFromVectors(TIntermAggregate * node,const TFunction * helper,TIntermSequence * constructorArgsOut)322 void ScalarizeTraverser::createConstructorMatrixFromVectors(TIntermAggregate *node,
323 const TFunction *helper,
324 TIntermSequence *constructorArgsOut)
325 {
326 const TType &type = node->getType();
327 extractComponents(helper, type.getCols() * type.getRows(), constructorArgsOut);
328 }
329
createConstructorMatrixFromMatrix(TIntermAggregate * node,const TFunction * helper,TIntermSequence * constructorArgsOut)330 void ScalarizeTraverser::createConstructorMatrixFromMatrix(TIntermAggregate *node,
331 const TFunction *helper,
332 TIntermSequence *constructorArgsOut)
333 {
334 ASSERT(helper->getParamCount() == 1);
335 TIntermTyped *matrix = new TIntermSymbol(helper->getParam(0));
336 const TType &type = node->getType();
337
338 // The result is the identity matrix with the size of the result, superimposed by the input
339 for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
340 {
341 for (uint8_t rowIndex = 0; rowIndex < type.getRows(); ++rowIndex)
342 {
343 if (columnIndex < matrix->getType().getCols() && rowIndex < matrix->getType().getRows())
344 {
345 TIntermTyped *col = new TIntermBinary(EOpIndexDirect, matrix->deepCopy(),
346 CreateIndexNode(columnIndex));
347 constructorArgsOut->push_back(
348 new TIntermSwizzle(col, {static_cast<int>(rowIndex)}));
349 }
350 else
351 {
352 ASSERT(type.getBasicType() == EbtFloat);
353 constructorArgsOut->push_back(
354 CreateFloatNode(columnIndex == rowIndex ? 1.0f : 0.0f, type.getPrecision()));
355 }
356 }
357 }
358 }
359
update(TCompiler * compiler,TIntermBlock * root)360 bool ScalarizeTraverser::update(TCompiler *compiler, TIntermBlock *root)
361 {
362 // Insert any added function definitions at the tope of the block
363 root->insertChildNodes(0, mFunctionsToAdd);
364
365 // Apply updates and validate
366 return updateTree(compiler, root);
367 }
368 } // namespace
369
ScalarizeVecAndMatConstructorArgs(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)370 bool ScalarizeVecAndMatConstructorArgs(TCompiler *compiler,
371 TIntermBlock *root,
372 TSymbolTable *symbolTable)
373 {
374 ScalarizeTraverser scalarizer(symbolTable);
375 root->traverse(&scalarizer);
376 return scalarizer.update(compiler, root);
377 }
378 } // namespace sh
379