1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include "compiler/translator/tree_ops/ConvertUnsupportedConstructorsToFunctionCalls.h"
8 #include "compiler/translator/ImmutableString.h"
9 #include "compiler/translator/Symbol.h"
10 #include "compiler/translator/tree_util/FindFunction.h"
11 #include "compiler/translator/tree_util/IntermNode_util.h"
12 #include "compiler/translator/tree_util/IntermRebuild.h"
13
14 using namespace sh;
15
16 namespace
17 {
18
AppendMatrixElementArgument(TIntermSymbol * parameter,int colIndex,int rowIndex,TIntermSequence * returnCtorArgs)19 void AppendMatrixElementArgument(TIntermSymbol *parameter,
20 int colIndex,
21 int rowIndex,
22 TIntermSequence *returnCtorArgs)
23 {
24 TIntermBinary *matColN =
25 new TIntermBinary(EOpIndexDirect, parameter->deepCopy(), CreateIndexNode(colIndex));
26 TIntermSwizzle *matElem = new TIntermSwizzle(matColN, {rowIndex});
27 returnCtorArgs->push_back(matElem);
28 }
29
30 // Adds the argument to sequence for a scalar constructor.
31 // Given scalar(scalarA) appends scalarA
32 // Given scalar(vecA) appends vecA.x
33 // Given scalar(matA) appends matA[0].x
AppendScalarFromNonScalarArguments(TFunction & function,TIntermSequence * returnCtorArgs)34 void AppendScalarFromNonScalarArguments(TFunction &function, TIntermSequence *returnCtorArgs)
35 {
36 const TVariable *var = function.getParam(0);
37 TIntermSymbol *arg0 = new TIntermSymbol(var);
38
39 const TType &type = arg0->getType();
40
41 if (type.isScalar())
42 {
43 returnCtorArgs->push_back(arg0);
44 }
45 else if (type.isVector())
46 {
47 TIntermSwizzle *vecX = new TIntermSwizzle(arg0, {0});
48 returnCtorArgs->push_back(vecX);
49 }
50 else if (type.isMatrix())
51 {
52 AppendMatrixElementArgument(arg0, 0, 0, returnCtorArgs);
53 }
54 }
55
56 // Adds the arguments to sequence for a vector constructor from a scalar.
57 // Given vecN(scalarA) appends scalarA, scalarA, ... n times
AppendVectorFromScalarArgument(const TType & type,TFunction & function,TIntermSequence * returnCtorArgs)58 void AppendVectorFromScalarArgument(const TType &type,
59 TFunction &function,
60 TIntermSequence *returnCtorArgs)
61 {
62 const uint8_t vectorSize = type.getNominalSize();
63 const TVariable *var = function.getParam(0);
64 TIntermSymbol *v = new TIntermSymbol(var);
65 for (uint8_t i = 0; i < vectorSize; ++i)
66 {
67 returnCtorArgs->push_back(v->deepCopy());
68 }
69 }
70
71 // Adds the arguments to sequence for a vector or matrix constructor from the available arguments
72 // applying arguments in order until the requested number of values have been extracted from the
73 // given arguments or until there are no more arguments.
AppendValuesFromMultipleArguments(int numValuesNeeded,TFunction & function,TIntermSequence * returnCtorArgs)74 void AppendValuesFromMultipleArguments(int numValuesNeeded,
75 TFunction &function,
76 TIntermSequence *returnCtorArgs)
77 {
78 size_t numParameters = function.getParamCount();
79 size_t paramIndex = 0;
80 uint8_t colIndex = 0;
81 uint8_t rowIndex = 0;
82
83 for (int i = 0; i < numValuesNeeded && paramIndex < numParameters; ++i)
84 {
85 const TVariable *p = function.getParam(paramIndex);
86 TIntermSymbol *parameter = new TIntermSymbol(p);
87 if (parameter->isScalar())
88 {
89 returnCtorArgs->push_back(parameter);
90 ++paramIndex;
91 }
92 else if (parameter->isVector())
93 {
94 TIntermSwizzle *vecS = new TIntermSwizzle(parameter->deepCopy(), {rowIndex++});
95 returnCtorArgs->push_back(vecS);
96 if (rowIndex == parameter->getNominalSize())
97 {
98 ++paramIndex;
99 rowIndex = 0;
100 }
101 }
102 else if (parameter->isMatrix())
103 {
104 AppendMatrixElementArgument(parameter, colIndex, rowIndex++, returnCtorArgs);
105 if (rowIndex == parameter->getSecondarySize())
106 {
107 rowIndex = 0;
108 ++colIndex;
109 if (colIndex == parameter->getNominalSize())
110 {
111 colIndex = 0;
112 ++paramIndex;
113 }
114 }
115 }
116 }
117 }
118
119 // Adds the arguments for a matrix constructor from a scalar
120 // putting the scalar along the diagonal and 0 everywhere else.
AppendMatrixFromScalarArgument(const TType & type,TFunction & function,TIntermSequence * returnCtorArgs)121 void AppendMatrixFromScalarArgument(const TType &type,
122 TFunction &function,
123 TIntermSequence *returnCtorArgs)
124 {
125 const TVariable *var = function.getParam(0);
126 TIntermSymbol *v = new TIntermSymbol(var);
127 const uint8_t numCols = type.getNominalSize();
128 const uint8_t numRows = type.getSecondarySize();
129 for (uint8_t col = 0; col < numCols; ++col)
130 {
131 for (uint8_t row = 0; row < numRows; ++row)
132 {
133 if (col == row)
134 {
135 returnCtorArgs->push_back(v->deepCopy());
136 }
137 else
138 {
139 returnCtorArgs->push_back(CreateFloatNode(0.0f, sh::EbpUndefined));
140 }
141 }
142 }
143 }
144
145 // Add the argument for a matrix constructor from a matrix
146 // copying elements from the same column/row and otherwise
147 // initialize to the identity matrix.
AppendMatrixFromMatrixArgument(const TType & type,TFunction & function,TIntermSequence * returnCtorArgs)148 void AppendMatrixFromMatrixArgument(const TType &type,
149 TFunction &function,
150 TIntermSequence *returnCtorArgs)
151 {
152 const TVariable *var = function.getParam(0);
153 TIntermSymbol *v = new TIntermSymbol(var);
154 const uint8_t dstCols = type.getNominalSize();
155 const uint8_t dstRows = type.getSecondarySize();
156 const uint8_t srcCols = v->getNominalSize();
157 const uint8_t srcRows = v->getSecondarySize();
158 for (uint8_t dstCol = 0; dstCol < dstCols; ++dstCol)
159 {
160 for (uint8_t dstRow = 0; dstRow < dstRows; ++dstRow)
161 {
162 if (dstRow < srcRows && dstCol < srcCols)
163 {
164 AppendMatrixElementArgument(v, dstCol, dstRow, returnCtorArgs);
165 }
166 else
167 {
168 returnCtorArgs->push_back(
169 CreateFloatNode(dstRow == dstCol ? 1.0f : 0.0f, sh::EbpUndefined));
170 }
171 }
172 }
173 }
174
175 class Rebuild : public TIntermRebuild
176 {
177 public:
Rebuild(TCompiler & compiler)178 explicit Rebuild(TCompiler &compiler) : TIntermRebuild(compiler, false, true) {}
visitAggregatePost(TIntermAggregate & node)179 PostResult visitAggregatePost(TIntermAggregate &node) override
180 {
181 if (!node.isConstructor())
182 {
183 return node;
184 }
185
186 TIntermSequence &arguments = *node.getSequence();
187 if (arguments.empty())
188 {
189 return node;
190 }
191
192 const TType &type = node.getType();
193 const TType &arg0Type = arguments[0]->getAsTyped()->getType();
194
195 if (!type.isScalar() && !type.isVector() && !type.isMatrix())
196 {
197 return node;
198 }
199
200 if (type.isArray())
201 {
202 return node;
203 }
204
205 // check for type_ctor(sameType)
206 // scalar(scalar) -> passthrough
207 // vecN(vecN) -> passthrough
208 // matN(matN) -> passthrough
209 if (arguments.size() == 1 && arg0Type == type)
210 {
211 return node;
212 }
213
214 // The following are simple casts:
215 //
216 // - basic(s) (where basic is int, uint, float or bool, and s is scalar).
217 // - gvecN(vN) (where the argument is a single vector with the same number of components).
218 // - matNxM(mNxM) (where the argument is a single matrix with the same dimensions). Note
219 // that
220 // matrices are always float, so there's no actual cast and this would be a no-op.
221 //
222 const bool isSingleScalarCast =
223 arguments.size() == 1 && type.isScalar() && arg0Type.isScalar();
224 const bool isSingleVectorCast = arguments.size() == 1 && type.isVector() &&
225 arg0Type.isVector() &&
226 type.getNominalSize() == arg0Type.getNominalSize();
227 const bool isSingleMatrixCast =
228 arguments.size() == 1 && type.isMatrix() && arg0Type.isMatrix() &&
229 type.getCols() == arg0Type.getCols() && type.getRows() == arg0Type.getRows();
230 if (isSingleScalarCast || isSingleVectorCast || isSingleMatrixCast)
231 {
232 return node;
233 }
234
235 // Cases we need to handle:
236 // scalar(vec)
237 // scalar(mat)
238 // vecN(scalar)
239 // vecN(vecM)
240 // vecN(a,...)
241 // matN(scalar) -> diag
242 // matN(vec) -> fail!
243 // manN(matM) -> corner + ident
244 // matN(a, ...)
245
246 // Build a function and pass all the constructor's arguments to it.
247 TIntermBlock *body = new TIntermBlock;
248 TFunction *function = new TFunction(&mSymbolTable, ImmutableString(""),
249 SymbolType::AngleInternal, &type, true);
250
251 for (size_t i = 0; i < arguments.size(); ++i)
252 {
253 TIntermTyped &arg = *arguments[i]->getAsTyped();
254 TType *argType = new TType(arg.getBasicType(), arg.getPrecision(), EvqParamIn,
255 arg.getNominalSize(), arg.getSecondarySize());
256 TVariable *var = CreateTempVariable(&mSymbolTable, argType);
257 function->addParameter(var);
258 }
259
260 // Build a return statement for the function that
261 // converts the arguments into the required type.
262 TIntermSequence *returnCtorArgs = new TIntermSequence();
263
264 if (type.isScalar())
265 {
266 AppendScalarFromNonScalarArguments(*function, returnCtorArgs);
267 }
268 else if (type.isVector())
269 {
270 if (arguments.size() == 1 && arg0Type.isScalar())
271 {
272 AppendVectorFromScalarArgument(type, *function, returnCtorArgs);
273 }
274 else
275 {
276 AppendValuesFromMultipleArguments(type.getNominalSize(), *function, returnCtorArgs);
277 }
278 }
279 else if (type.isMatrix())
280 {
281 if (arguments.size() == 1 && arg0Type.isScalar())
282 {
283 // MSL already handles this case
284 AppendMatrixFromScalarArgument(type, *function, returnCtorArgs);
285 }
286 else if (arg0Type.isMatrix())
287 {
288 AppendMatrixFromMatrixArgument(type, *function, returnCtorArgs);
289 }
290 else
291 {
292 AppendValuesFromMultipleArguments(type.getNominalSize() * type.getSecondarySize(),
293 *function, returnCtorArgs);
294 }
295 }
296
297 TIntermBranch *returnStatement =
298 new TIntermBranch(EOpReturn, TIntermAggregate::CreateConstructor(type, returnCtorArgs));
299 body->appendStatement(returnStatement);
300
301 TIntermFunctionDefinition *functionDefinition =
302 CreateInternalFunctionDefinitionNode(*function, body);
303 mFunctionDefs.push_back(functionDefinition);
304
305 TIntermTyped *functionCall = TIntermAggregate::CreateFunctionCall(*function, &arguments);
306
307 return *functionCall;
308 }
309
rewrite(TIntermBlock & root)310 bool rewrite(TIntermBlock &root)
311 {
312 if (!rebuildInPlace(root))
313 {
314 return true;
315 }
316
317 size_t firstFunctionIndex = FindFirstFunctionDefinitionIndex(&root);
318 for (TIntermFunctionDefinition *functionDefinition : mFunctionDefs)
319 {
320 root.insertChildNodes(firstFunctionIndex, TIntermSequence({functionDefinition}));
321 }
322
323 return mCompiler.validateAST(&root);
324 }
325
326 private:
327 TVector<TIntermFunctionDefinition *> mFunctionDefs;
328 };
329
330 } // anonymous namespace
331
ConvertUnsupportedConstructorsToFunctionCalls(TCompiler & compiler,TIntermBlock & root)332 bool sh::ConvertUnsupportedConstructorsToFunctionCalls(TCompiler &compiler, TIntermBlock &root)
333 {
334 return Rebuild(compiler).rewrite(root);
335 }
336