1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include <cstring>
8 #include <numeric>
9 #include <unordered_map>
10 #include <unordered_set>
11
12 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
13
14 using namespace sh;
15
16 ////////////////////////////////////////////////////////////////////////////////
17
ViewDeclaration(TIntermDeclaration & declNode)18 Declaration sh::ViewDeclaration(TIntermDeclaration &declNode)
19 {
20 ASSERT(declNode.getChildCount() == 1);
21 TIntermNode *childNode = declNode.getChildNode(0);
22 ASSERT(childNode);
23 TIntermSymbol *symbolNode;
24 if ((symbolNode = childNode->getAsSymbolNode()))
25 {
26 return {*symbolNode, nullptr};
27 }
28 else
29 {
30 TIntermBinary *initNode = childNode->getAsBinaryNode();
31 ASSERT(initNode);
32 ASSERT(initNode->getOp() == TOperator::EOpInitialize);
33 symbolNode = initNode->getLeft()->getAsSymbolNode();
34 ASSERT(symbolNode);
35 return {*symbolNode, initNode->getRight()};
36 }
37 }
38
CreateStructTypeVariable(TSymbolTable & symbolTable,const TStructure & structure)39 const TVariable &sh::CreateStructTypeVariable(TSymbolTable &symbolTable,
40 const TStructure &structure)
41 {
42 auto *type = new TType(&structure, true);
43 auto *var = new TVariable(&symbolTable, ImmutableString(""), type, SymbolType::Empty);
44 return *var;
45 }
46
CreateInstanceVariable(TSymbolTable & symbolTable,const TStructure & structure,const Name & name,TQualifier qualifier,const TSpan<const unsigned int> * arraySizes)47 const TVariable &sh::CreateInstanceVariable(TSymbolTable &symbolTable,
48 const TStructure &structure,
49 const Name &name,
50 TQualifier qualifier,
51 const TSpan<const unsigned int> *arraySizes)
52 {
53 auto *type = new TType(&structure, false);
54 type->setQualifier(qualifier);
55 if (arraySizes)
56 {
57 type->makeArrays(*arraySizes);
58 }
59 auto *var = new TVariable(&symbolTable, name.rawName(), type, name.symbolType());
60 return *var;
61 }
62
AcquireFunctionExtras(TFunction & dest,const TFunction & src)63 static void AcquireFunctionExtras(TFunction &dest, const TFunction &src)
64 {
65 if (src.isDefined())
66 {
67 dest.setDefined();
68 }
69
70 if (src.hasPrototypeDeclaration())
71 {
72 dest.setHasPrototypeDeclaration();
73 }
74 }
75
// Returns a newly allocated sequence consisting of `node` followed by every element of `seq`.
// This is a shallow copy: the node pointers from `seq` are reused, not deep-cloned.
TIntermSequence &sh::CloneSequenceAndPrepend(const TIntermSequence &seq, TIntermNode &node)
{
    auto *prepended = new TIntermSequence();
    prepended->push_back(&node);
    prepended->insert(prepended->end(), seq.begin(), seq.end());
    return *prepended;
}
88
AddParametersFrom(TFunction & dest,const TFunction & src)89 void sh::AddParametersFrom(TFunction &dest, const TFunction &src)
90 {
91 const size_t paramCount = src.getParamCount();
92 for (size_t i = 0; i < paramCount; ++i)
93 {
94 const TVariable *var = src.getParam(i);
95 dest.addParameter(var);
96 }
97 }
98
CloneFunction(TSymbolTable & symbolTable,IdGen & idGen,const TFunction & oldFunc)99 const TFunction &sh::CloneFunction(TSymbolTable &symbolTable,
100 IdGen &idGen,
101 const TFunction &oldFunc)
102 {
103 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined);
104
105 Name newName = idGen.createNewName(Name(oldFunc));
106
107 auto &newFunc = *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
108 &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
109
110 AcquireFunctionExtras(newFunc, oldFunc);
111 AddParametersFrom(newFunc, oldFunc);
112
113 return newFunc;
114 }
115
CloneFunctionAndPrependParam(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TVariable & newParam)116 const TFunction &sh::CloneFunctionAndPrependParam(TSymbolTable &symbolTable,
117 IdGen *idGen,
118 const TFunction &oldFunc,
119 const TVariable &newParam)
120 {
121 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
122 oldFunc.symbolType() == SymbolType::AngleInternal);
123
124 Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
125
126 auto &newFunc = *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
127 &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
128
129 AcquireFunctionExtras(newFunc, oldFunc);
130 newFunc.addParameter(&newParam);
131 AddParametersFrom(newFunc, oldFunc);
132
133 return newFunc;
134 }
135
CloneFunctionAndAppendParams(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const std::vector<const TVariable * > & newParams)136 const TFunction &sh::CloneFunctionAndAppendParams(TSymbolTable &symbolTable,
137 IdGen *idGen,
138 const TFunction &oldFunc,
139 const std::vector<const TVariable *> &newParams)
140 {
141 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
142 oldFunc.symbolType() == SymbolType::AngleInternal);
143
144 Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
145
146 auto &newFunc = *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
147 &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
148
149 AcquireFunctionExtras(newFunc, oldFunc);
150 AddParametersFrom(newFunc, oldFunc);
151 for (auto *param : newParams)
152 {
153 newFunc.addParameter(param);
154 }
155
156 return newFunc;
157 }
158
CloneFunctionAndChangeReturnType(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TStructure & newReturn)159 const TFunction &sh::CloneFunctionAndChangeReturnType(TSymbolTable &symbolTable,
160 IdGen *idGen,
161 const TFunction &oldFunc,
162 const TStructure &newReturn)
163 {
164 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined);
165
166 Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
167
168 auto *newReturnType = new TType(&newReturn, true);
169 auto &newFunc = *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
170 newReturnType, oldFunc.isKnownToNotHaveSideEffects());
171
172 AcquireFunctionExtras(newFunc, oldFunc);
173 AddParametersFrom(newFunc, oldFunc);
174
175 return newFunc;
176 }
177
GetArg(const TIntermAggregate & call,size_t index)178 TIntermTyped &sh::GetArg(const TIntermAggregate &call, size_t index)
179 {
180 ASSERT(index < call.getChildCount());
181 auto *arg = call.getChildNode(index);
182 ASSERT(arg);
183 auto *targ = arg->getAsTyped();
184 ASSERT(targ);
185 return *targ;
186 }
187
SetArg(TIntermAggregate & call,size_t index,TIntermTyped & arg)188 void sh::SetArg(TIntermAggregate &call, size_t index, TIntermTyped &arg)
189 {
190 ASSERT(index < call.getChildCount());
191 (*call.getSequence())[index] = &arg;
192 }
193
GetFieldIndex(const TStructure & structure,const ImmutableString & fieldName)194 int sh::GetFieldIndex(const TStructure &structure, const ImmutableString &fieldName)
195 {
196 const TFieldList &fieldList = structure.fields();
197
198 int i = 0;
199 for (TField *field : fieldList)
200 {
201 if (field->name() == fieldName)
202 {
203 return i;
204 }
205 ++i;
206 }
207
208 return -1;
209 }
210
AccessField(const TVariable & structInstanceVar,const ImmutableString & fieldName)211 TIntermBinary &sh::AccessField(const TVariable &structInstanceVar, const ImmutableString &fieldName)
212 {
213 return AccessField(*new TIntermSymbol(&structInstanceVar), fieldName);
214 }
215
AccessField(TIntermTyped & object,const ImmutableString & fieldName)216 TIntermBinary &sh::AccessField(TIntermTyped &object, const ImmutableString &fieldName)
217 {
218 const TStructure *structure = object.getType().getStruct();
219 ASSERT(structure);
220
221 const int index = GetFieldIndex(*structure, fieldName);
222 ASSERT(index >= 0);
223 return AccessFieldByIndex(object, index);
224 }
225
AccessFieldByIndex(TIntermTyped & object,int index)226 TIntermBinary &sh::AccessFieldByIndex(TIntermTyped &object, int index)
227 {
228 #if defined(ANGLE_ENABLE_ASSERTS)
229 const TType &type = object.getType();
230 ASSERT(!type.isArray());
231 const TStructure *structure = type.getStruct();
232 ASSERT(structure);
233 ASSERT(0 <= index);
234 ASSERT(static_cast<size_t>(index) < structure->fields().size());
235 #endif
236
237 return *new TIntermBinary(
238 TOperator::EOpIndexDirectStruct, &object,
239 new TIntermConstantUnion(new TConstantUnion(index), *new TType(TBasicType::EbtInt)));
240 }
241
AccessIndex(TIntermTyped & indexableNode,int index)242 TIntermBinary &sh::AccessIndex(TIntermTyped &indexableNode, int index)
243 {
244 #if defined(ANGLE_ENABLE_ASSERTS)
245 const TType &type = indexableNode.getType();
246 ASSERT(type.isArray() || type.isVector() || type.isMatrix());
247 #endif
248
249 auto *accessNode = new TIntermBinary(
250 TOperator::EOpIndexDirect, &indexableNode,
251 new TIntermConstantUnion(new TConstantUnion(index), *new TType(TBasicType::EbtInt)));
252 return *accessNode;
253 }
254
AccessIndex(TIntermTyped & node,const int * index)255 TIntermTyped &sh::AccessIndex(TIntermTyped &node, const int *index)
256 {
257 if (index)
258 {
259 return AccessIndex(node, *index);
260 }
261 return node;
262 }
263
SubVector(TIntermTyped & vectorNode,int begin,int end)264 TIntermTyped &sh::SubVector(TIntermTyped &vectorNode, int begin, int end)
265 {
266 ASSERT(vectorNode.getType().isVector());
267 ASSERT(0 <= begin);
268 ASSERT(end <= 4);
269 ASSERT(begin <= end);
270 if (begin == 0 && end == vectorNode.getType().getNominalSize())
271 {
272 return vectorNode;
273 }
274 TVector<int> offsets(static_cast<size_t>(end - begin));
275 std::iota(offsets.begin(), offsets.end(), begin);
276 auto *swizzle = new TIntermSwizzle(&vectorNode, offsets);
277 return *swizzle;
278 }
279
IsScalarBasicType(const TType & type)280 bool sh::IsScalarBasicType(const TType &type)
281 {
282 if (!type.isScalar())
283 {
284 return false;
285 }
286 return HasScalarBasicType(type);
287 }
288
IsVectorBasicType(const TType & type)289 bool sh::IsVectorBasicType(const TType &type)
290 {
291 if (!type.isVector())
292 {
293 return false;
294 }
295 return HasScalarBasicType(type);
296 }
297
HasScalarBasicType(TBasicType type)298 bool sh::HasScalarBasicType(TBasicType type)
299 {
300 switch (type)
301 {
302 case TBasicType::EbtFloat:
303 case TBasicType::EbtDouble:
304 case TBasicType::EbtInt:
305 case TBasicType::EbtUInt:
306 case TBasicType::EbtBool:
307 return true;
308
309 default:
310 return false;
311 }
312 }
313
HasScalarBasicType(const TType & type)314 bool sh::HasScalarBasicType(const TType &type)
315 {
316 return HasScalarBasicType(type.getBasicType());
317 }
318
InitType(TType & type)319 static void InitType(TType &type)
320 {
321 if (type.isArray())
322 {
323 auto sizes = type.getArraySizes();
324 type.toArrayBaseType();
325 type.makeArrays(sizes);
326 }
327 }
328
// Returns a newly allocated copy of `type` whose array sizes have been re-applied via
// InitType.
TType &sh::CloneType(const TType &type)
{
    TType *copy = new TType(type);
    InitType(*copy);
    return *copy;
}
335
// Returns a newly allocated copy of `type` with all array dimensions removed, i.e. the
// array's base element type.
TType &sh::InnermostType(const TType &type)
{
    TType *baseType = new TType(type);
    baseType->toArrayBaseType();
    InitType(*baseType);
    return *baseType;
}
343
// For a matrix type, returns a new vector type with `getRows()` components (the type of one
// column). Basic type, precision, qualifier, and array sizes are kept from `matrixType`.
TType &sh::DropColumns(const TType &matrixType)
{
    ASSERT(matrixType.isMatrix());
    ASSERT(HasScalarBasicType(matrixType));

    // Passed as a typed variable rather than a bare nullptr literal, mirroring the other
    // constructor calls in this file.
    const char *const noMangledName = nullptr;
    TType *columnType =
        new TType(matrixType.getBasicType(), matrixType.getPrecision(), matrixType.getQualifier(),
                  matrixType.getRows(), 1, matrixType.getArraySizes(), noMangledName);
    InitType(*columnType);
    return *columnType;
}
356
// For an array type, returns a new type with the last entry of its array-size list removed.
// Basic type, precision, qualifier, and nominal/secondary sizes are kept.
// NOTE(review): the constructor used here receives no TStructure, so this presumably only
// supports arrays of basic (non-struct) element types — confirm before using with structs.
TType &sh::DropOuterDimension(const TType &arrayType)
{
    ASSERT(arrayType.isArray());

    const auto &allSizes          = arrayType.getArraySizes();
    const auto remainingSizes     = allSizes.subspan(0, allSizes.size() - 1);
    const char *const noMangledName = nullptr;

    TType *elementType =
        new TType(arrayType.getBasicType(), arrayType.getPrecision(), arrayType.getQualifier(),
                  arrayType.getNominalSize(), arrayType.getSecondarySize(), remainingSizes,
                  noMangledName);
    InitType(*elementType);
    return *elementType;
}
370
// Returns a new type like `type` but with primary size `primary` (2..4) and secondary size
// `secondary` (1..4). Basic type, precision, qualifier, and array sizes are preserved.
static TType &SetTypeDimsImpl(const TType &type, int primary, int secondary)
{
    ASSERT(primary > 1 && primary <= 4);
    ASSERT(secondary >= 1 && secondary <= 4);
    ASSERT(HasScalarBasicType(type));

    const char *const noMangledName = nullptr;
    TType *resized = new TType(type.getBasicType(), type.getPrecision(), type.getQualifier(),
                               primary, secondary, type.getArraySizes(), noMangledName);
    InitType(*resized);
    return *resized;
}
383
// Returns a vector type with `newDim` components built from a scalar or vector `type`.
// `newDim` must be in 2..4 (asserted by SetTypeDimsImpl).
TType &sh::SetVectorDim(const TType &type, int newDim)
{
    ASSERT(type.isRank0() || type.isVector());
    const int singleRow = 1;
    return SetTypeDimsImpl(type, newDim, singleRow);
}
389
// Returns a matrix type with the same column count as `matrixType` but `newDim` (2..4) rows.
TType &sh::SetMatrixRowDim(const TType &matrixType, int newDim)
{
    ASSERT(matrixType.isMatrix());
    ASSERT(newDim > 1 && newDim <= 4);
    const int columnCount = matrixType.getCols();
    return SetTypeDimsImpl(matrixType, columnCount, newDim);
}
396
HasMatrixField(const TStructure & structure)397 bool sh::HasMatrixField(const TStructure &structure)
398 {
399 for (const TField *field : structure.fields())
400 {
401 const TType &type = *field->type();
402 if (type.isMatrix())
403 {
404 return true;
405 }
406 }
407 return false;
408 }
409
HasArrayField(const TStructure & structure)410 bool sh::HasArrayField(const TStructure &structure)
411 {
412 for (const TField *field : structure.fields())
413 {
414 const TType &type = *field->type();
415 if (type.isArray())
416 {
417 return true;
418 }
419 }
420 return false;
421 }
422
CoerceSimple(TBasicType toType,TIntermTyped & fromNode)423 TIntermTyped &sh::CoerceSimple(TBasicType toType, TIntermTyped &fromNode)
424 {
425 const TType &fromType = fromNode.getType();
426
427 ASSERT(HasScalarBasicType(toType));
428 ASSERT(HasScalarBasicType(fromType));
429 ASSERT(!fromType.isArray());
430
431 if (toType != fromType.getBasicType())
432 {
433 return *TIntermAggregate::CreateConstructor(
434 *new TType(toType, fromType.getNominalSize(), fromType.getSecondarySize()),
435 new TIntermSequence{&fromNode});
436 }
437 return fromNode;
438 }
439
CoerceSimple(const TType & toType,TIntermTyped & fromNode)440 TIntermTyped &sh::CoerceSimple(const TType &toType, TIntermTyped &fromNode)
441 {
442 const TType &fromType = fromNode.getType();
443
444 ASSERT(HasScalarBasicType(toType));
445 ASSERT(HasScalarBasicType(fromType));
446 ASSERT(toType.getNominalSize() == fromType.getNominalSize());
447 ASSERT(toType.getSecondarySize() == fromType.getSecondarySize());
448 ASSERT(!toType.isArray());
449 ASSERT(!fromType.isArray());
450
451 if (toType.getBasicType() != fromType.getBasicType())
452 {
453 return *TIntermAggregate::CreateConstructor(toType, new TIntermSequence{&fromNode});
454 }
455 return fromNode;
456 }
457
AsType(SymbolEnv & symbolEnv,const TType & toType,TIntermTyped & fromNode)458 TIntermTyped &sh::AsType(SymbolEnv &symbolEnv, const TType &toType, TIntermTyped &fromNode)
459 {
460 const TType &fromType = fromNode.getType();
461
462 ASSERT(HasScalarBasicType(toType));
463 ASSERT(HasScalarBasicType(fromType));
464 ASSERT(!toType.isArray());
465 ASSERT(!fromType.isArray());
466
467 if (toType == fromType)
468 {
469 return fromNode;
470 }
471 TemplateArg targ(toType);
472 return symbolEnv.callFunctionOverload(Name("as_type", SymbolType::BuiltIn), toType,
473 *new TIntermSequence{&fromNode}, 1, &targ);
474 }
475