• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Scalarize vector and matrix constructor args, so that vectors built from components don't have
7 // matrix arguments, and matrices built from components don't have vector arguments. This avoids
8 // driver bugs around vector and matrix constructors.
9 //
10 
11 #include "compiler/translator/tree_ops/glsl/ScalarizeVecAndMatConstructorArgs.h"
12 
13 #include "angle_gl.h"
14 #include "common/angleutils.h"
15 #include "compiler/translator/Compiler.h"
16 #include "compiler/translator/tree_util/IntermNode_util.h"
17 #include "compiler/translator/tree_util/IntermTraverse.h"
18 
19 namespace sh
20 {
21 
22 namespace
23 {
24 // Traverser that converts a vector or matrix constructor to one that only uses scalars.  To support
25 // all the various places such a constructor could be found, a helper function is created for each
26 // such constructor.  The helper function takes the constructor arguments and creates the object.
27 //
28 // Constructors that are transformed are:
29 //
30 // - vecN(scalar): translates to vecN(scalar, ..., scalar)
31 // - vecN(vec1, vec2, ...): translates to vecN(vec1.x, vec1.y, vec2.x, ...)
32 // - vecN(matrix): translates to vecN(matrix[0][0], matrix[0][1], ...)
33 // - matNxM(scalar): translates to matNxM(scalar, 0, ..., 0
34 //                                        0, scalar, ..., 0
35 //                                        ...
36 //                                        0, 0, ..., scalar)
37 // - matNxM(vec1, vec2, ...): translates to matNxM(vec1.x, vec1.y, vec2.x, ...)
38 // - matNxM(matrixAxB): translates to matNxM(matrix[0][0], matrix[0][1], ..., 0
39 //                                           matrix[1][0], matrix[1][1], ..., 0
40 //                                           ...
41 //                                           0,            0,            ..., 1)
42 //
43 class ScalarizeTraverser : public TIntermTraverser
44 {
45   public:
ScalarizeTraverser(TSymbolTable * symbolTable)46     ScalarizeTraverser(TSymbolTable *symbolTable)
47         : TIntermTraverser(true, false, false, symbolTable)
48     {}
49 
50     bool update(TCompiler *compiler, TIntermBlock *root);
51 
52   protected:
53     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
54 
55   private:
56     bool shouldScalarize(TIntermTyped *node);
57 
58     // Create a helper function that takes the same arguments as the constructor it replaces.
59     const TFunction *createHelper(TIntermAggregate *node);
60     TIntermTyped *createHelperCall(TIntermAggregate *node, const TFunction *helper);
61     void addHelperDefinition(const TFunction *helper, TIntermBlock *body);
62 
63     // If given a constructor, convert it to a function call.  Recursively processes constructor
64     // arguments.  Otherwise, recursively visit the node.
65     TIntermTyped *createConstructor(TIntermTyped *node);
66 
67     void extractComponents(const TFunction *helper,
68                            size_t componentCount,
69                            TIntermSequence *componentsOut);
70 
71     void createConstructorVectorFromScalar(TIntermAggregate *node,
72                                            const TFunction *helper,
73                                            TIntermSequence *constructorArgsOut);
74     void createConstructorVectorFromMultiple(TIntermAggregate *node,
75                                              const TFunction *helper,
76                                              TIntermSequence *constructorArgsOut);
77     void createConstructorMatrixFromScalar(TIntermAggregate *node,
78                                            const TFunction *helper,
79                                            TIntermSequence *constructorArgsOut);
80     void createConstructorMatrixFromVectors(TIntermAggregate *node,
81                                             const TFunction *helper,
82                                             TIntermSequence *constructorArgsOut);
83     void createConstructorMatrixFromMatrix(TIntermAggregate *node,
84                                            const TFunction *helper,
85                                            TIntermSequence *constructorArgsOut);
86 
87     TIntermSequence mFunctionsToAdd;
88 };
89 
visitAggregate(Visit visit,TIntermAggregate * node)90 bool ScalarizeTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
91 {
92     if (!shouldScalarize(node))
93     {
94         return true;
95     }
96 
97     TIntermTyped *replacement = createConstructor(node);
98     if (replacement != node)
99     {
100         queueReplacement(replacement, OriginalNode::IS_DROPPED);
101     }
102     // createConstructor already visits children
103     return false;
104 }
105 
shouldScalarize(TIntermTyped * typed)106 bool ScalarizeTraverser::shouldScalarize(TIntermTyped *typed)
107 {
108     TIntermAggregate *node = typed->getAsAggregate();
109     if (node == nullptr || node->getOp() != EOpConstruct)
110     {
111         return false;
112     }
113 
114     const TType &type                = node->getType();
115     const TIntermSequence &arguments = *node->getSequence();
116     const TType &arg0Type            = arguments[0]->getAsTyped()->getType();
117 
118     const bool isSingleVectorCast = arguments.size() == 1 && type.isVector() &&
119                                     arg0Type.isVector() &&
120                                     type.getNominalSize() == arg0Type.getNominalSize();
121     const bool isSingleMatrixCast = arguments.size() == 1 && type.isMatrix() &&
122                                     arg0Type.isMatrix() && type.getCols() == arg0Type.getCols() &&
123                                     type.getRows() == arg0Type.getRows();
124 
125     // Skip non-vector non-matrix constructors, as well as trivial constructors.
126     if (type.isArray() || type.getStruct() != nullptr || type.isScalar() || isSingleVectorCast ||
127         isSingleMatrixCast)
128     {
129         return false;
130     }
131 
132     return true;
133 }
134 
createHelper(TIntermAggregate * node)135 const TFunction *ScalarizeTraverser::createHelper(TIntermAggregate *node)
136 {
137     TFunction *helper = new TFunction(mSymbolTable, kEmptyImmutableString,
138                                       SymbolType::AngleInternal, &node->getType(), true);
139 
140     const TIntermSequence &arguments = *node->getSequence();
141     for (TIntermNode *arg : arguments)
142     {
143         TType *argType = new TType(arg->getAsTyped()->getType());
144         argType->setQualifier(EvqParamIn);
145 
146         TVariable *argVar =
147             new TVariable(mSymbolTable, kEmptyImmutableString, argType, SymbolType::AngleInternal);
148         helper->addParameter(argVar);
149     }
150 
151     return helper;
152 }
153 
createHelperCall(TIntermAggregate * node,const TFunction * helper)154 TIntermTyped *ScalarizeTraverser::createHelperCall(TIntermAggregate *node, const TFunction *helper)
155 {
156     TIntermSequence callArgs;
157 
158     const TIntermSequence &arguments = *node->getSequence();
159     for (TIntermNode *arg : arguments)
160     {
161         // Note: createConstructor makes sure the arg is visited even if not constructor.
162         callArgs.push_back(createConstructor(arg->getAsTyped()));
163     }
164 
165     return TIntermAggregate::CreateFunctionCall(*helper, &callArgs);
166 }
167 
addHelperDefinition(const TFunction * helper,TIntermBlock * body)168 void ScalarizeTraverser::addHelperDefinition(const TFunction *helper, TIntermBlock *body)
169 {
170     mFunctionsToAdd.push_back(
171         new TIntermFunctionDefinition(new TIntermFunctionPrototype(helper), body));
172 }
173 
createConstructor(TIntermTyped * typed)174 TIntermTyped *ScalarizeTraverser::createConstructor(TIntermTyped *typed)
175 {
176     if (!shouldScalarize(typed))
177     {
178         typed->traverse(this);
179         return typed;
180     }
181 
182     TIntermAggregate *node           = typed->getAsAggregate();
183     const TType &type                = node->getType();
184     const TIntermSequence &arguments = *node->getSequence();
185     const TType &arg0Type            = arguments[0]->getAsTyped()->getType();
186 
187     const TFunction *helper = createHelper(node);
188     TIntermSequence constructorArgs;
189 
190     if (type.isVector())
191     {
192         if (arguments.size() == 1 && arg0Type.isScalar())
193         {
194             createConstructorVectorFromScalar(node, helper, &constructorArgs);
195         }
196         createConstructorVectorFromMultiple(node, helper, &constructorArgs);
197     }
198     else
199     {
200         ASSERT(type.isMatrix());
201 
202         if (arg0Type.isScalar() && arguments.size() == 1)
203         {
204             createConstructorMatrixFromScalar(node, helper, &constructorArgs);
205         }
206         if (arg0Type.isMatrix())
207         {
208             createConstructorMatrixFromMatrix(node, helper, &constructorArgs);
209         }
210         createConstructorMatrixFromVectors(node, helper, &constructorArgs);
211     }
212 
213     TIntermBlock *body = new TIntermBlock;
214     body->appendStatement(
215         new TIntermBranch(EOpReturn, TIntermAggregate::CreateConstructor(type, &constructorArgs)));
216     addHelperDefinition(helper, body);
217 
218     return createHelperCall(node, helper);
219 }
220 
221 // Extract enough scalar arguments from the arguments of helper to produce enough arguments for the
222 // constructor call (given in componentCount).
extractComponents(const TFunction * helper,size_t componentCount,TIntermSequence * componentsOut)223 void ScalarizeTraverser::extractComponents(const TFunction *helper,
224                                            size_t componentCount,
225                                            TIntermSequence *componentsOut)
226 {
227     for (size_t argumentIndex = 0;
228          argumentIndex < helper->getParamCount() && componentsOut->size() < componentCount;
229          ++argumentIndex)
230     {
231         TIntermTyped *argument    = new TIntermSymbol(helper->getParam(argumentIndex));
232         const TType &argumentType = argument->getType();
233 
234         if (argumentType.isScalar())
235         {
236             // For scalar parameters, there's nothing to do
237             componentsOut->push_back(argument);
238             continue;
239         }
240         if (argumentType.isVector())
241         {
242             // For vector parameters, take components out of the vector one by one.
243             for (uint8_t componentIndex = 0; componentIndex < argumentType.getNominalSize() &&
244                                              componentsOut->size() < componentCount;
245                  ++componentIndex)
246             {
247                 componentsOut->push_back(
248                     new TIntermSwizzle(argument->deepCopy(), {componentIndex}));
249             }
250             continue;
251         }
252 
253         ASSERT(argumentType.isMatrix());
254 
255         // For matrix parameters, take components out of the matrix one by one in column-major
256         // order.
257         for (uint8_t columnIndex = 0;
258              columnIndex < argumentType.getCols() && componentsOut->size() < componentCount;
259              ++columnIndex)
260         {
261             TIntermTyped *col = new TIntermBinary(EOpIndexDirect, argument->deepCopy(),
262                                                   CreateIndexNode(columnIndex));
263 
264             for (uint8_t componentIndex = 0;
265                  componentIndex < argumentType.getRows() && componentsOut->size() < componentCount;
266                  ++componentIndex)
267             {
268                 componentsOut->push_back(new TIntermSwizzle(col->deepCopy(), {componentIndex}));
269             }
270         }
271     }
272 }
273 
createConstructorVectorFromScalar(TIntermAggregate * node,const TFunction * helper,TIntermSequence * constructorArgsOut)274 void ScalarizeTraverser::createConstructorVectorFromScalar(TIntermAggregate *node,
275                                                            const TFunction *helper,
276                                                            TIntermSequence *constructorArgsOut)
277 {
278     ASSERT(helper->getParamCount() == 1);
279     TIntermTyped *scalar = new TIntermSymbol(helper->getParam(0));
280     const TType &type    = node->getType();
281 
282     // Replicate the single scalar argument as many times as necessary.
283     for (size_t index = 0; index < type.getNominalSize(); ++index)
284     {
285         constructorArgsOut->push_back(scalar->deepCopy());
286     }
287 }
288 
createConstructorVectorFromMultiple(TIntermAggregate * node,const TFunction * helper,TIntermSequence * constructorArgsOut)289 void ScalarizeTraverser::createConstructorVectorFromMultiple(TIntermAggregate *node,
290                                                              const TFunction *helper,
291                                                              TIntermSequence *constructorArgsOut)
292 {
293     extractComponents(helper, node->getType().getNominalSize(), constructorArgsOut);
294 }
295 
createConstructorMatrixFromScalar(TIntermAggregate * node,const TFunction * helper,TIntermSequence * constructorArgsOut)296 void ScalarizeTraverser::createConstructorMatrixFromScalar(TIntermAggregate *node,
297                                                            const TFunction *helper,
298                                                            TIntermSequence *constructorArgsOut)
299 {
300     ASSERT(helper->getParamCount() == 1);
301     TIntermTyped *scalar = new TIntermSymbol(helper->getParam(0));
302     const TType &type    = node->getType();
303 
304     // Create the scalar over the diagonal.  Every other element is 0.
305     for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
306     {
307         for (uint8_t rowIndex = 0; rowIndex < type.getRows(); ++rowIndex)
308         {
309             if (columnIndex == rowIndex)
310             {
311                 constructorArgsOut->push_back(scalar->deepCopy());
312             }
313             else
314             {
315                 ASSERT(type.getBasicType() == EbtFloat);
316                 constructorArgsOut->push_back(CreateFloatNode(0, type.getPrecision()));
317             }
318         }
319     }
320 }
321 
createConstructorMatrixFromVectors(TIntermAggregate * node,const TFunction * helper,TIntermSequence * constructorArgsOut)322 void ScalarizeTraverser::createConstructorMatrixFromVectors(TIntermAggregate *node,
323                                                             const TFunction *helper,
324                                                             TIntermSequence *constructorArgsOut)
325 {
326     const TType &type = node->getType();
327     extractComponents(helper, type.getCols() * type.getRows(), constructorArgsOut);
328 }
329 
createConstructorMatrixFromMatrix(TIntermAggregate * node,const TFunction * helper,TIntermSequence * constructorArgsOut)330 void ScalarizeTraverser::createConstructorMatrixFromMatrix(TIntermAggregate *node,
331                                                            const TFunction *helper,
332                                                            TIntermSequence *constructorArgsOut)
333 {
334     ASSERT(helper->getParamCount() == 1);
335     TIntermTyped *matrix = new TIntermSymbol(helper->getParam(0));
336     const TType &type    = node->getType();
337 
338     // The result is the identity matrix with the size of the result, superimposed by the input
339     for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
340     {
341         for (uint8_t rowIndex = 0; rowIndex < type.getRows(); ++rowIndex)
342         {
343             if (columnIndex < matrix->getType().getCols() && rowIndex < matrix->getType().getRows())
344             {
345                 TIntermTyped *col = new TIntermBinary(EOpIndexDirect, matrix->deepCopy(),
346                                                       CreateIndexNode(columnIndex));
347                 constructorArgsOut->push_back(
348                     new TIntermSwizzle(col, {static_cast<int>(rowIndex)}));
349             }
350             else
351             {
352                 ASSERT(type.getBasicType() == EbtFloat);
353                 constructorArgsOut->push_back(
354                     CreateFloatNode(columnIndex == rowIndex ? 1.0f : 0.0f, type.getPrecision()));
355             }
356         }
357     }
358 }
359 
update(TCompiler * compiler,TIntermBlock * root)360 bool ScalarizeTraverser::update(TCompiler *compiler, TIntermBlock *root)
361 {
362     // Insert any added function definitions at the tope of the block
363     root->insertChildNodes(0, mFunctionsToAdd);
364 
365     // Apply updates and validate
366     return updateTree(compiler, root);
367 }
368 }  // namespace
369 
ScalarizeVecAndMatConstructorArgs(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)370 bool ScalarizeVecAndMatConstructorArgs(TCompiler *compiler,
371                                        TIntermBlock *root,
372                                        TSymbolTable *symbolTable)
373 {
374     ScalarizeTraverser scalarizer(symbolTable);
375     root->traverse(&scalarizer);
376     return scalarizer.update(compiler, root);
377 }
378 }  // namespace sh
379