1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include <cstring>
8 #include <numeric>
9 #include <unordered_map>
10 #include <unordered_set>
11
12 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
13
14 using namespace sh;
15
16 ////////////////////////////////////////////////////////////////////////////////
17
ViewDeclaration(TIntermDeclaration & declNode)18 Declaration sh::ViewDeclaration(TIntermDeclaration &declNode)
19 {
20 ASSERT(declNode.getChildCount() == 1);
21 TIntermNode *childNode = declNode.getChildNode(0);
22 ASSERT(childNode);
23 TIntermSymbol *symbolNode;
24 if ((symbolNode = childNode->getAsSymbolNode()))
25 {
26 return {*symbolNode, nullptr};
27 }
28 else
29 {
30 TIntermBinary *initNode = childNode->getAsBinaryNode();
31 ASSERT(initNode);
32 ASSERT(initNode->getOp() == TOperator::EOpInitialize);
33 symbolNode = initNode->getLeft()->getAsSymbolNode();
34 ASSERT(symbolNode);
35 return {*symbolNode, initNode->getRight()};
36 }
37 }
38
CreateStructTypeVariable(TSymbolTable & symbolTable,const TStructure & structure)39 const TVariable &sh::CreateStructTypeVariable(TSymbolTable &symbolTable,
40 const TStructure &structure)
41 {
42 TType *type = new TType(&structure, true);
43 TVariable *var = new TVariable(&symbolTable, ImmutableString(""), type, SymbolType::Empty);
44 return *var;
45 }
46
CreateInstanceVariable(TSymbolTable & symbolTable,const TStructure & structure,const Name & name,TQualifier qualifier,const TSpan<const unsigned int> * arraySizes)47 const TVariable &sh::CreateInstanceVariable(TSymbolTable &symbolTable,
48 const TStructure &structure,
49 const Name &name,
50 TQualifier qualifier,
51 const TSpan<const unsigned int> *arraySizes)
52 {
53 TType *type = new TType(&structure, false);
54 type->setQualifier(qualifier);
55 if (arraySizes)
56 {
57 type->makeArrays(*arraySizes);
58 }
59 TVariable *var = new TVariable(&symbolTable, name.rawName(), type, name.symbolType());
60 return *var;
61 }
62
AcquireFunctionExtras(TFunction & dest,const TFunction & src)63 static void AcquireFunctionExtras(TFunction &dest, const TFunction &src)
64 {
65 if (src.isDefined())
66 {
67 dest.setDefined();
68 }
69
70 if (src.hasPrototypeDeclaration())
71 {
72 dest.setHasPrototypeDeclaration();
73 }
74 }
75
// Returns a newly allocated sequence holding `node` followed by every entry of `seq`.
// Only the node pointers are copied; the nodes themselves are shared with `seq`.
TIntermSequence &sh::CloneSequenceAndPrepend(const TIntermSequence &seq, TIntermNode &node)
{
    TIntermSequence *const result = new TIntermSequence();
    result->push_back(&node);
    result->insert(result->end(), seq.begin(), seq.end());
    return *result;
}
88
AddParametersFrom(TFunction & dest,const TFunction & src)89 void sh::AddParametersFrom(TFunction &dest, const TFunction &src)
90 {
91 const size_t paramCount = src.getParamCount();
92 for (size_t i = 0; i < paramCount; ++i)
93 {
94 const TVariable *var = src.getParam(i);
95 dest.addParameter(var);
96 }
97 }
98
CloneFunction(TSymbolTable & symbolTable,IdGen & idGen,const TFunction & oldFunc)99 const TFunction &sh::CloneFunction(TSymbolTable &symbolTable,
100 IdGen &idGen,
101 const TFunction &oldFunc)
102 {
103 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined);
104
105 Name newName = idGen.createNewName(Name(oldFunc));
106
107 TFunction &newFunc =
108 *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
109 &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
110
111 AcquireFunctionExtras(newFunc, oldFunc);
112 AddParametersFrom(newFunc, oldFunc);
113
114 return newFunc;
115 }
116
CloneFunctionAndPrependParam(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TVariable & newParam)117 const TFunction &sh::CloneFunctionAndPrependParam(TSymbolTable &symbolTable,
118 IdGen *idGen,
119 const TFunction &oldFunc,
120 const TVariable &newParam)
121 {
122 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
123 oldFunc.symbolType() == SymbolType::AngleInternal);
124
125 Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
126
127 TFunction &newFunc =
128 *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
129 &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
130
131 AcquireFunctionExtras(newFunc, oldFunc);
132 newFunc.addParameter(&newParam);
133 AddParametersFrom(newFunc, oldFunc);
134
135 return newFunc;
136 }
137
CloneFunctionAndAppendParams(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const std::vector<const TVariable * > & newParams)138 const TFunction &sh::CloneFunctionAndAppendParams(TSymbolTable &symbolTable,
139 IdGen *idGen,
140 const TFunction &oldFunc,
141 const std::vector<const TVariable *> &newParams)
142 {
143 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
144 oldFunc.symbolType() == SymbolType::AngleInternal);
145
146 Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
147
148 TFunction &newFunc =
149 *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
150 &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
151
152 AcquireFunctionExtras(newFunc, oldFunc);
153 AddParametersFrom(newFunc, oldFunc);
154 for (const TVariable *param : newParams)
155 {
156 newFunc.addParameter(param);
157 }
158
159 return newFunc;
160 }
161
CloneFunctionAndChangeReturnType(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TStructure & newReturn)162 const TFunction &sh::CloneFunctionAndChangeReturnType(TSymbolTable &symbolTable,
163 IdGen *idGen,
164 const TFunction &oldFunc,
165 const TStructure &newReturn)
166 {
167 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined);
168
169 Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
170
171 TType *newReturnType = new TType(&newReturn, true);
172 TFunction &newFunc = *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
173 newReturnType, oldFunc.isKnownToNotHaveSideEffects());
174
175 AcquireFunctionExtras(newFunc, oldFunc);
176 AddParametersFrom(newFunc, oldFunc);
177
178 return newFunc;
179 }
180
GetArg(const TIntermAggregate & call,size_t index)181 TIntermTyped &sh::GetArg(const TIntermAggregate &call, size_t index)
182 {
183 ASSERT(index < call.getChildCount());
184 TIntermNode *arg = call.getChildNode(index);
185 ASSERT(arg);
186 TIntermTyped *targ = arg->getAsTyped();
187 ASSERT(targ);
188 return *targ;
189 }
190
SetArg(TIntermAggregate & call,size_t index,TIntermTyped & arg)191 void sh::SetArg(TIntermAggregate &call, size_t index, TIntermTyped &arg)
192 {
193 ASSERT(index < call.getChildCount());
194 (*call.getSequence())[index] = &arg;
195 }
196
GetFieldIndex(const TStructure & structure,const ImmutableString & fieldName)197 int sh::GetFieldIndex(const TStructure &structure, const ImmutableString &fieldName)
198 {
199 const TFieldList &fieldList = structure.fields();
200
201 int i = 0;
202 for (TField *field : fieldList)
203 {
204 if (field->name() == fieldName)
205 {
206 return i;
207 }
208 ++i;
209 }
210
211 return -1;
212 }
213
AccessField(const TVariable & structInstanceVar,const ImmutableString & fieldName)214 TIntermBinary &sh::AccessField(const TVariable &structInstanceVar, const ImmutableString &fieldName)
215 {
216 return AccessField(*new TIntermSymbol(&structInstanceVar), fieldName);
217 }
218
AccessField(TIntermTyped & object,const ImmutableString & fieldName)219 TIntermBinary &sh::AccessField(TIntermTyped &object, const ImmutableString &fieldName)
220 {
221 const TStructure *structure = object.getType().getStruct();
222 ASSERT(structure);
223
224 const int index = GetFieldIndex(*structure, fieldName);
225 ASSERT(index >= 0);
226 return AccessFieldByIndex(object, index);
227 }
228
AccessFieldByIndex(TIntermTyped & object,int index)229 TIntermBinary &sh::AccessFieldByIndex(TIntermTyped &object, int index)
230 {
231 #if defined(ANGLE_ENABLE_ASSERTS)
232 const TType &type = object.getType();
233 ASSERT(!type.isArray());
234 const TStructure *structure = type.getStruct();
235 ASSERT(structure);
236 ASSERT(0 <= index);
237 ASSERT(static_cast<size_t>(index) < structure->fields().size());
238 #endif
239
240 return *new TIntermBinary(
241 TOperator::EOpIndexDirectStruct, &object,
242 new TIntermConstantUnion(new TConstantUnion(index), *new TType(TBasicType::EbtInt)));
243 }
244
AccessIndex(TIntermTyped & indexableNode,int index)245 TIntermBinary &sh::AccessIndex(TIntermTyped &indexableNode, int index)
246 {
247 #if defined(ANGLE_ENABLE_ASSERTS)
248 const TType &type = indexableNode.getType();
249 ASSERT(type.isArray() || type.isVector() || type.isMatrix());
250 #endif
251
252 TIntermBinary *accessNode = new TIntermBinary(
253 TOperator::EOpIndexDirect, &indexableNode,
254 new TIntermConstantUnion(new TConstantUnion(index), *new TType(TBasicType::EbtInt)));
255 return *accessNode;
256 }
257
AccessIndex(TIntermTyped & node,const int * index)258 TIntermTyped &sh::AccessIndex(TIntermTyped &node, const int *index)
259 {
260 if (index)
261 {
262 return AccessIndex(node, *index);
263 }
264 return node;
265 }
266
SubVector(TIntermTyped & vectorNode,int begin,int end)267 TIntermTyped &sh::SubVector(TIntermTyped &vectorNode, int begin, int end)
268 {
269 ASSERT(vectorNode.getType().isVector());
270 ASSERT(0 <= begin);
271 ASSERT(end <= 4);
272 ASSERT(begin <= end);
273 if (begin == 0 && end == vectorNode.getType().getNominalSize())
274 {
275 return vectorNode;
276 }
277 TVector<int> offsets(static_cast<size_t>(end - begin));
278 std::iota(offsets.begin(), offsets.end(), begin);
279 TIntermSwizzle *swizzle = new TIntermSwizzle(vectorNode.deepCopy(), offsets);
280 return *swizzle;
281 }
282
IsScalarBasicType(const TType & type)283 bool sh::IsScalarBasicType(const TType &type)
284 {
285 if (!type.isScalar())
286 {
287 return false;
288 }
289 return HasScalarBasicType(type);
290 }
291
IsVectorBasicType(const TType & type)292 bool sh::IsVectorBasicType(const TType &type)
293 {
294 if (!type.isVector())
295 {
296 return false;
297 }
298 return HasScalarBasicType(type);
299 }
300
HasScalarBasicType(TBasicType type)301 bool sh::HasScalarBasicType(TBasicType type)
302 {
303 switch (type)
304 {
305 case TBasicType::EbtFloat:
306 case TBasicType::EbtDouble:
307 case TBasicType::EbtInt:
308 case TBasicType::EbtUInt:
309 case TBasicType::EbtBool:
310 return true;
311
312 default:
313 return false;
314 }
315 }
316
HasScalarBasicType(const TType & type)317 bool sh::HasScalarBasicType(const TType &type)
318 {
319 return HasScalarBasicType(type.getBasicType());
320 }
321
// Returns a newly allocated copy of `type`.
TType &sh::CloneType(const TType &type)
{
    return *new TType(type);
}
327
// Returns a newly allocated copy of `type` that has been reduced with toArrayBaseType().
// `type` itself is not modified.
TType &sh::InnermostType(const TType &type)
{
    TType *const baseType = new TType(type);
    baseType->toArrayBaseType();
    return *baseType;
}
334
// Returns a newly allocated copy of `matrixType` that has been converted with
// toMatrixColumnType(). `matrixType` must be a matrix with a scalar basic type and is not
// modified.
TType &sh::DropColumns(const TType &matrixType)
{
    ASSERT(matrixType.isMatrix());
    ASSERT(HasScalarBasicType(matrixType));

    TType *const columnType = new TType(matrixType);
    columnType->toMatrixColumnType();
    return *columnType;
}
344
// Returns a newly allocated copy of `arrayType` that has been converted with
// toArrayElementType(). `arrayType` must be an array and is not modified.
TType &sh::DropOuterDimension(const TType &arrayType)
{
    ASSERT(arrayType.isArray());

    TType *const elementType = new TType(arrayType);
    elementType->toArrayElementType();
    return *elementType;
}
353
// Returns a newly allocated copy of `type` with its primary and secondary sizes replaced.
// `primary` must be in 2..4, `secondary` in 1..4, and `type` must have a scalar basic type.
static TType &SetTypeDimsImpl(const TType &type, int primary, int secondary)
{
    ASSERT(1 < primary && primary <= 4);
    ASSERT(1 <= secondary && secondary <= 4);
    ASSERT(HasScalarBasicType(type));

    TType *const resized = new TType(type);
    resized->setPrimarySize(primary);
    resized->setSecondarySize(secondary);
    return *resized;
}
365
// Returns a copy of the rank-0 or vector `type` with primary size `newDim` and secondary size 1.
TType &sh::SetVectorDim(const TType &type, int newDim)
{
    ASSERT(type.isRank0() || type.isVector());
    const int secondarySize = 1;
    return SetTypeDimsImpl(type, newDim, secondarySize);
}
371
// Returns a copy of `matrixType` that keeps its column count as the primary size and uses
// `newDim` (2..4) as the secondary size.
TType &sh::SetMatrixRowDim(const TType &matrixType, int newDim)
{
    ASSERT(matrixType.isMatrix());
    ASSERT(1 < newDim && newDim <= 4);
    return SetTypeDimsImpl(matrixType, matrixType.getCols(), newDim);
}
378
HasMatrixField(const TStructure & structure)379 bool sh::HasMatrixField(const TStructure &structure)
380 {
381 for (const TField *field : structure.fields())
382 {
383 const TType &type = *field->type();
384 if (type.isMatrix())
385 {
386 return true;
387 }
388 }
389 return false;
390 }
391
HasArrayField(const TStructure & structure)392 bool sh::HasArrayField(const TStructure &structure)
393 {
394 for (const TField *field : structure.fields())
395 {
396 const TType &type = *field->type();
397 if (type.isArray())
398 {
399 return true;
400 }
401 }
402 return false;
403 }
404
CoerceSimple(TBasicType toBasicType,TIntermTyped & fromNode,bool needsExplicitBoolCast)405 TIntermTyped &sh::CoerceSimple(TBasicType toBasicType,
406 TIntermTyped &fromNode,
407 bool needsExplicitBoolCast)
408 {
409 const TType &fromType = fromNode.getType();
410
411 ASSERT(HasScalarBasicType(toBasicType));
412 ASSERT(HasScalarBasicType(fromType));
413 ASSERT(!fromType.isArray());
414
415 const TBasicType fromBasicType = fromType.getBasicType();
416
417 if (toBasicType != fromBasicType)
418 {
419 if (toBasicType == TBasicType::EbtBool && fromNode.isVector() && needsExplicitBoolCast)
420 {
421 switch (fromBasicType)
422 {
423 case TBasicType::EbtFloat:
424 case TBasicType::EbtDouble:
425 case TBasicType::EbtInt:
426 case TBasicType::EbtUInt:
427 {
428 TIntermSequence *argsSequence = new TIntermSequence();
429 for (int i = 0; i < fromType.getNominalSize(); i++)
430 {
431 TIntermTyped &fromTypeSwizzle = SubVector(fromNode, i, i + 1);
432 TIntermAggregate *boolConstructor = TIntermAggregate::CreateConstructor(
433 *new TType(toBasicType, 1, 1), new TIntermSequence{&fromTypeSwizzle});
434 argsSequence->push_back(boolConstructor);
435 }
436 return *TIntermAggregate::CreateConstructor(
437 *new TType(toBasicType, fromType.getNominalSize(),
438 fromType.getSecondarySize()),
439 argsSequence);
440 }
441
442 default:
443 break; // No explicit conversion needed
444 }
445 }
446
447 return *TIntermAggregate::CreateConstructor(
448 *new TType(toBasicType, fromType.getNominalSize(), fromType.getSecondarySize()),
449 new TIntermSequence{&fromNode});
450 }
451 return fromNode;
452 }
453
CoerceSimple(const TType & toType,TIntermTyped & fromNode,bool needsExplicitBoolCast)454 TIntermTyped &sh::CoerceSimple(const TType &toType,
455 TIntermTyped &fromNode,
456 bool needsExplicitBoolCast)
457 {
458 const TType &fromType = fromNode.getType();
459
460 ASSERT(HasScalarBasicType(toType));
461 ASSERT(HasScalarBasicType(fromType));
462 ASSERT(toType.getNominalSize() == fromType.getNominalSize());
463 ASSERT(toType.getSecondarySize() == fromType.getSecondarySize());
464 ASSERT(!toType.isArray());
465 ASSERT(!fromType.isArray());
466
467 const TBasicType toBasicType = toType.getBasicType();
468 const TBasicType fromBasicType = fromType.getBasicType();
469
470 if (toBasicType != fromBasicType)
471 {
472 if (toBasicType == TBasicType::EbtBool && fromNode.isVector() && needsExplicitBoolCast)
473 {
474 switch (fromBasicType)
475 {
476 case TBasicType::EbtFloat:
477 case TBasicType::EbtDouble:
478 case TBasicType::EbtInt:
479 case TBasicType::EbtUInt:
480 {
481 TIntermSequence *argsSequence = new TIntermSequence();
482 for (int i = 0; i < fromType.getNominalSize(); i++)
483 {
484 TIntermTyped &fromTypeSwizzle = SubVector(fromNode, i, i + 1);
485 TIntermAggregate *boolConstructor = TIntermAggregate::CreateConstructor(
486 *new TType(toBasicType, 1, 1), new TIntermSequence{&fromTypeSwizzle});
487 argsSequence->push_back(boolConstructor);
488 }
489 return *TIntermAggregate::CreateConstructor(
490 *new TType(toBasicType, fromType.getNominalSize(),
491 fromType.getSecondarySize()),
492 new TIntermSequence{*argsSequence});
493 }
494
495 default:
496 break; // No explicit conversion needed
497 }
498 }
499
500 return *TIntermAggregate::CreateConstructor(toType, new TIntermSequence{&fromNode});
501 }
502 return fromNode;
503 }
504
AsType(SymbolEnv & symbolEnv,const TType & toType,TIntermTyped & fromNode)505 TIntermTyped &sh::AsType(SymbolEnv &symbolEnv, const TType &toType, TIntermTyped &fromNode)
506 {
507 const TType &fromType = fromNode.getType();
508
509 ASSERT(HasScalarBasicType(toType));
510 ASSERT(HasScalarBasicType(fromType));
511 ASSERT(!toType.isArray());
512 ASSERT(!fromType.isArray());
513
514 if (toType == fromType)
515 {
516 return fromNode;
517 }
518 TemplateArg targ(toType);
519 return symbolEnv.callFunctionOverload(Name("as_type", SymbolType::BuiltIn), toType,
520 *new TIntermSequence{&fromNode}, 1, &targ);
521 }
522