• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/tree_ops/ConvertUnsupportedConstructorsToFunctionCalls.h"
8 #include "compiler/translator/ImmutableString.h"
9 #include "compiler/translator/Symbol.h"
10 #include "compiler/translator/tree_util/FindFunction.h"
11 #include "compiler/translator/tree_util/IntermNode_util.h"
12 #include "compiler/translator/tree_util/IntermRebuild.h"
13 
14 using namespace sh;
15 
16 namespace
17 {
18 
AppendMatrixElementArgument(TIntermSymbol * parameter,int colIndex,int rowIndex,TIntermSequence * returnCtorArgs)19 void AppendMatrixElementArgument(TIntermSymbol *parameter,
20                                  int colIndex,
21                                  int rowIndex,
22                                  TIntermSequence *returnCtorArgs)
23 {
24     TIntermBinary *matColN =
25         new TIntermBinary(EOpIndexDirect, parameter->deepCopy(), CreateIndexNode(colIndex));
26     TIntermSwizzle *matElem = new TIntermSwizzle(matColN, {rowIndex});
27     returnCtorArgs->push_back(matElem);
28 }
29 
30 // Adds the argument to sequence for a scalar constructor.
31 // Given scalar(scalarA) appends scalarA
32 // Given scalar(vecA) appends vecA.x
33 // Given scalar(matA) appends matA[0].x
AppendScalarFromNonScalarArguments(TFunction & function,TIntermSequence * returnCtorArgs)34 void AppendScalarFromNonScalarArguments(TFunction &function, TIntermSequence *returnCtorArgs)
35 {
36     const TVariable *var = function.getParam(0);
37     TIntermSymbol *arg0  = new TIntermSymbol(var);
38 
39     const TType &type = arg0->getType();
40 
41     if (type.isScalar())
42     {
43         returnCtorArgs->push_back(arg0);
44     }
45     else if (type.isVector())
46     {
47         TIntermSwizzle *vecX = new TIntermSwizzle(arg0, {0});
48         returnCtorArgs->push_back(vecX);
49     }
50     else if (type.isMatrix())
51     {
52         AppendMatrixElementArgument(arg0, 0, 0, returnCtorArgs);
53     }
54 }
55 
56 // Adds the arguments to sequence for a vector constructor from a scalar.
57 // Given vecN(scalarA) appends scalarA, scalarA, ... n times
AppendVectorFromScalarArgument(const TType & type,TFunction & function,TIntermSequence * returnCtorArgs)58 void AppendVectorFromScalarArgument(const TType &type,
59                                     TFunction &function,
60                                     TIntermSequence *returnCtorArgs)
61 {
62     const uint8_t vectorSize = type.getNominalSize();
63     const TVariable *var     = function.getParam(0);
64     TIntermSymbol *v         = new TIntermSymbol(var);
65     for (uint8_t i = 0; i < vectorSize; ++i)
66     {
67         returnCtorArgs->push_back(v->deepCopy());
68     }
69 }
70 
71 // Adds the arguments to sequence for a vector or matrix constructor from the available arguments
72 // applying arguments in order until the requested number of values have been extracted from the
73 // given arguments or until there are no more arguments.
AppendValuesFromMultipleArguments(int numValuesNeeded,TFunction & function,TIntermSequence * returnCtorArgs)74 void AppendValuesFromMultipleArguments(int numValuesNeeded,
75                                        TFunction &function,
76                                        TIntermSequence *returnCtorArgs)
77 {
78     size_t numParameters = function.getParamCount();
79     size_t paramIndex    = 0;
80     uint8_t colIndex     = 0;
81     uint8_t rowIndex     = 0;
82 
83     for (int i = 0; i < numValuesNeeded && paramIndex < numParameters; ++i)
84     {
85         const TVariable *p       = function.getParam(paramIndex);
86         TIntermSymbol *parameter = new TIntermSymbol(p);
87         if (parameter->isScalar())
88         {
89             returnCtorArgs->push_back(parameter);
90             ++paramIndex;
91         }
92         else if (parameter->isVector())
93         {
94             TIntermSwizzle *vecS = new TIntermSwizzle(parameter->deepCopy(), {rowIndex++});
95             returnCtorArgs->push_back(vecS);
96             if (rowIndex == parameter->getNominalSize())
97             {
98                 ++paramIndex;
99                 rowIndex = 0;
100             }
101         }
102         else if (parameter->isMatrix())
103         {
104             AppendMatrixElementArgument(parameter, colIndex, rowIndex++, returnCtorArgs);
105             if (rowIndex == parameter->getSecondarySize())
106             {
107                 rowIndex = 0;
108                 ++colIndex;
109                 if (colIndex == parameter->getNominalSize())
110                 {
111                     colIndex = 0;
112                     ++paramIndex;
113                 }
114             }
115         }
116     }
117 }
118 
119 // Adds the arguments for a matrix constructor from a scalar
120 // putting the scalar along the diagonal and 0 everywhere else.
AppendMatrixFromScalarArgument(const TType & type,TFunction & function,TIntermSequence * returnCtorArgs)121 void AppendMatrixFromScalarArgument(const TType &type,
122                                     TFunction &function,
123                                     TIntermSequence *returnCtorArgs)
124 {
125     const TVariable *var  = function.getParam(0);
126     TIntermSymbol *v      = new TIntermSymbol(var);
127     const uint8_t numCols = type.getNominalSize();
128     const uint8_t numRows = type.getSecondarySize();
129     for (uint8_t col = 0; col < numCols; ++col)
130     {
131         for (uint8_t row = 0; row < numRows; ++row)
132         {
133             if (col == row)
134             {
135                 returnCtorArgs->push_back(v->deepCopy());
136             }
137             else
138             {
139                 returnCtorArgs->push_back(CreateFloatNode(0.0f, sh::EbpUndefined));
140             }
141         }
142     }
143 }
144 
145 // Add the argument for a matrix constructor from a matrix
146 // copying elements from the same column/row and otherwise
147 // initialize to the identity matrix.
AppendMatrixFromMatrixArgument(const TType & type,TFunction & function,TIntermSequence * returnCtorArgs)148 void AppendMatrixFromMatrixArgument(const TType &type,
149                                     TFunction &function,
150                                     TIntermSequence *returnCtorArgs)
151 {
152     const TVariable *var  = function.getParam(0);
153     TIntermSymbol *v      = new TIntermSymbol(var);
154     const uint8_t dstCols = type.getNominalSize();
155     const uint8_t dstRows = type.getSecondarySize();
156     const uint8_t srcCols = v->getNominalSize();
157     const uint8_t srcRows = v->getSecondarySize();
158     for (uint8_t dstCol = 0; dstCol < dstCols; ++dstCol)
159     {
160         for (uint8_t dstRow = 0; dstRow < dstRows; ++dstRow)
161         {
162             if (dstRow < srcRows && dstCol < srcCols)
163             {
164                 AppendMatrixElementArgument(v, dstCol, dstRow, returnCtorArgs);
165             }
166             else
167             {
168                 returnCtorArgs->push_back(
169                     CreateFloatNode(dstRow == dstCol ? 1.0f : 0.0f, sh::EbpUndefined));
170             }
171         }
172     }
173 }
174 
175 class Rebuild : public TIntermRebuild
176 {
177   public:
Rebuild(TCompiler & compiler)178     explicit Rebuild(TCompiler &compiler) : TIntermRebuild(compiler, false, true) {}
visitAggregatePost(TIntermAggregate & node)179     PostResult visitAggregatePost(TIntermAggregate &node) override
180     {
181         if (!node.isConstructor())
182         {
183             return node;
184         }
185 
186         TIntermSequence &arguments = *node.getSequence();
187         if (arguments.empty())
188         {
189             return node;
190         }
191 
192         const TType &type     = node.getType();
193         const TType &arg0Type = arguments[0]->getAsTyped()->getType();
194 
195         if (!type.isScalar() && !type.isVector() && !type.isMatrix())
196         {
197             return node;
198         }
199 
200         if (type.isArray())
201         {
202             return node;
203         }
204 
205         // check for type_ctor(sameType)
206         // scalar(scalar) -> passthrough
207         // vecN(vecN) -> passthrough
208         // matN(matN) -> passthrough
209         if (arguments.size() == 1 && arg0Type == type)
210         {
211             return node;
212         }
213 
214         // The following are simple casts:
215         //
216         // - basic(s) (where basic is int, uint, float or bool, and s is scalar).
217         // - gvecN(vN) (where the argument is a single vector with the same number of components).
218         // - matNxM(mNxM) (where the argument is a single matrix with the same dimensions).  Note
219         // that
220         //   matrices are always float, so there's no actual cast and this would be a no-op.
221         //
222         const bool isSingleScalarCast =
223             arguments.size() == 1 && type.isScalar() && arg0Type.isScalar();
224         const bool isSingleVectorCast = arguments.size() == 1 && type.isVector() &&
225                                         arg0Type.isVector() &&
226                                         type.getNominalSize() == arg0Type.getNominalSize();
227         const bool isSingleMatrixCast =
228             arguments.size() == 1 && type.isMatrix() && arg0Type.isMatrix() &&
229             type.getCols() == arg0Type.getCols() && type.getRows() == arg0Type.getRows();
230         if (isSingleScalarCast || isSingleVectorCast || isSingleMatrixCast)
231         {
232             return node;
233         }
234 
235         // Cases we need to handle:
236         // scalar(vec)
237         // scalar(mat)
238         // vecN(scalar)
239         // vecN(vecM)
240         // vecN(a,...)
241         // matN(scalar) -> diag
242         // matN(vec) -> fail!
243         // manN(matM) -> corner + ident
244         // matN(a, ...)
245 
246         // Build a function and pass all the constructor's arguments to it.
247         TIntermBlock *body  = new TIntermBlock;
248         TFunction *function = new TFunction(&mSymbolTable, ImmutableString(""),
249                                             SymbolType::AngleInternal, &type, true);
250 
251         for (size_t i = 0; i < arguments.size(); ++i)
252         {
253             TIntermTyped &arg = *arguments[i]->getAsTyped();
254             TType *argType    = new TType(arg.getBasicType(), arg.getPrecision(), EvqParamIn,
255                                        arg.getNominalSize(), arg.getSecondarySize());
256             TVariable *var    = CreateTempVariable(&mSymbolTable, argType);
257             function->addParameter(var);
258         }
259 
260         // Build a return statement for the function that
261         // converts the arguments into the required type.
262         TIntermSequence *returnCtorArgs = new TIntermSequence();
263 
264         if (type.isScalar())
265         {
266             AppendScalarFromNonScalarArguments(*function, returnCtorArgs);
267         }
268         else if (type.isVector())
269         {
270             if (arguments.size() == 1 && arg0Type.isScalar())
271             {
272                 AppendVectorFromScalarArgument(type, *function, returnCtorArgs);
273             }
274             else
275             {
276                 AppendValuesFromMultipleArguments(type.getNominalSize(), *function, returnCtorArgs);
277             }
278         }
279         else if (type.isMatrix())
280         {
281             if (arguments.size() == 1 && arg0Type.isScalar())
282             {
283                 // MSL already handles this case
284                 AppendMatrixFromScalarArgument(type, *function, returnCtorArgs);
285             }
286             else if (arg0Type.isMatrix())
287             {
288                 AppendMatrixFromMatrixArgument(type, *function, returnCtorArgs);
289             }
290             else
291             {
292                 AppendValuesFromMultipleArguments(type.getNominalSize() * type.getSecondarySize(),
293                                                   *function, returnCtorArgs);
294             }
295         }
296 
297         TIntermBranch *returnStatement =
298             new TIntermBranch(EOpReturn, TIntermAggregate::CreateConstructor(type, returnCtorArgs));
299         body->appendStatement(returnStatement);
300 
301         TIntermFunctionDefinition *functionDefinition =
302             CreateInternalFunctionDefinitionNode(*function, body);
303         mFunctionDefs.push_back(functionDefinition);
304 
305         TIntermTyped *functionCall = TIntermAggregate::CreateFunctionCall(*function, &arguments);
306 
307         return *functionCall;
308     }
309 
rewrite(TIntermBlock & root)310     bool rewrite(TIntermBlock &root)
311     {
312         if (!rebuildInPlace(root))
313         {
314             return true;
315         }
316 
317         size_t firstFunctionIndex = FindFirstFunctionDefinitionIndex(&root);
318         for (TIntermFunctionDefinition *functionDefinition : mFunctionDefs)
319         {
320             root.insertChildNodes(firstFunctionIndex, TIntermSequence({functionDefinition}));
321         }
322 
323         return mCompiler.validateAST(&root);
324     }
325 
326   private:
327     TVector<TIntermFunctionDefinition *> mFunctionDefs;
328 };
329 
330 }  // anonymous namespace
331 
ConvertUnsupportedConstructorsToFunctionCalls(TCompiler & compiler,TIntermBlock & root)332 bool sh::ConvertUnsupportedConstructorsToFunctionCalls(TCompiler &compiler, TIntermBlock &root)
333 {
334     return Rebuild(compiler).rewrite(root);
335 }
336