1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <cstring>
8 #include <numeric>
9 #include <unordered_map>
10 #include <unordered_set>
11 
12 #include "compiler/translator/msl/AstHelpers.h"
13 
14 using namespace sh;
15 
16 ////////////////////////////////////////////////////////////////////////////////
17 
ViewDeclaration(TIntermDeclaration & declNode)18 Declaration sh::ViewDeclaration(TIntermDeclaration &declNode)
19 {
20     ASSERT(declNode.getChildCount() == 1);
21     TIntermNode *childNode = declNode.getChildNode(0);
22     ASSERT(childNode);
23     TIntermSymbol *symbolNode;
24     if ((symbolNode = childNode->getAsSymbolNode()))
25     {
26         return {*symbolNode, nullptr};
27     }
28     else
29     {
30         TIntermBinary *initNode = childNode->getAsBinaryNode();
31         ASSERT(initNode);
32         ASSERT(initNode->getOp() == TOperator::EOpInitialize);
33         symbolNode = initNode->getLeft()->getAsSymbolNode();
34         ASSERT(symbolNode);
35         return {*symbolNode, initNode->getRight()};
36     }
37 }
38 
CreateStructTypeVariable(TSymbolTable & symbolTable,const TStructure & structure)39 const TVariable &sh::CreateStructTypeVariable(TSymbolTable &symbolTable,
40                                               const TStructure &structure)
41 {
42     TType *type    = new TType(&structure, true);
43     TVariable *var = new TVariable(&symbolTable, ImmutableString(""), type, SymbolType::Empty);
44     return *var;
45 }
46 
CreateInstanceVariable(TSymbolTable & symbolTable,const TStructure & structure,const Name & name,TQualifier qualifier,const TSpan<const unsigned int> * arraySizes)47 const TVariable &sh::CreateInstanceVariable(TSymbolTable &symbolTable,
48                                             const TStructure &structure,
49                                             const Name &name,
50                                             TQualifier qualifier,
51                                             const TSpan<const unsigned int> *arraySizes)
52 {
53     TType *type = new TType(&structure, false);
54     type->setQualifier(qualifier);
55     if (arraySizes)
56     {
57         type->makeArrays(*arraySizes);
58     }
59     TVariable *var = new TVariable(&symbolTable, name.rawName(), type, name.symbolType());
60     return *var;
61 }
62 
AcquireFunctionExtras(TFunction & dest,const TFunction & src)63 static void AcquireFunctionExtras(TFunction &dest, const TFunction &src)
64 {
65     if (src.isDefined())
66     {
67         dest.setDefined();
68     }
69 
70     if (src.hasPrototypeDeclaration())
71     {
72         dest.setHasPrototypeDeclaration();
73     }
74 }
75 
// Builds a new sequence consisting of `node` followed by every node pointer of
// `seq` in its original order. Only the pointers are copied; the nodes
// themselves are shared with `seq`.
TIntermSequence &sh::CloneSequenceAndPrepend(const TIntermSequence &seq, TIntermNode &node)
{
    TIntermSequence *result = new TIntermSequence();
    result->push_back(&node);
    result->insert(result->end(), seq.begin(), seq.end());
    return *result;
}
88 
AddParametersFrom(TFunction & dest,const TFunction & src)89 void sh::AddParametersFrom(TFunction &dest, const TFunction &src)
90 {
91     const size_t paramCount = src.getParamCount();
92     for (size_t i = 0; i < paramCount; ++i)
93     {
94         const TVariable *var = src.getParam(i);
95         dest.addParameter(var);
96     }
97 }
98 
CloneFunction(TSymbolTable & symbolTable,IdGen & idGen,const TFunction & oldFunc)99 const TFunction &sh::CloneFunction(TSymbolTable &symbolTable,
100                                    IdGen &idGen,
101                                    const TFunction &oldFunc)
102 {
103     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined);
104 
105     Name newName = idGen.createNewName(Name(oldFunc));
106 
107     TFunction &newFunc =
108         *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
109                        &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
110 
111     AcquireFunctionExtras(newFunc, oldFunc);
112     AddParametersFrom(newFunc, oldFunc);
113 
114     return newFunc;
115 }
116 
CloneFunctionAndPrependParam(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TVariable & newParam)117 const TFunction &sh::CloneFunctionAndPrependParam(TSymbolTable &symbolTable,
118                                                   IdGen *idGen,
119                                                   const TFunction &oldFunc,
120                                                   const TVariable &newParam)
121 {
122     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
123            oldFunc.symbolType() == SymbolType::AngleInternal);
124 
125     Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
126 
127     TFunction &newFunc =
128         *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
129                        &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
130 
131     AcquireFunctionExtras(newFunc, oldFunc);
132     newFunc.addParameter(&newParam);
133     AddParametersFrom(newFunc, oldFunc);
134 
135     return newFunc;
136 }
137 
CloneFunctionAndPrependTwoParams(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TVariable & newParam1,const TVariable & newParam2)138 const TFunction &sh::CloneFunctionAndPrependTwoParams(TSymbolTable &symbolTable,
139                                                       IdGen *idGen,
140                                                       const TFunction &oldFunc,
141                                                       const TVariable &newParam1,
142                                                       const TVariable &newParam2)
143 {
144     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
145            oldFunc.symbolType() == SymbolType::AngleInternal);
146 
147     Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
148 
149     TFunction &newFunc =
150         *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
151                        &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
152 
153     AcquireFunctionExtras(newFunc, oldFunc);
154     newFunc.addParameter(&newParam1);
155     newFunc.addParameter(&newParam2);
156     AddParametersFrom(newFunc, oldFunc);
157 
158     return newFunc;
159 }
160 
CloneFunctionAndAppendParams(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const std::vector<const TVariable * > & newParams)161 const TFunction &sh::CloneFunctionAndAppendParams(TSymbolTable &symbolTable,
162                                                   IdGen *idGen,
163                                                   const TFunction &oldFunc,
164                                                   const std::vector<const TVariable *> &newParams)
165 {
166     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
167            oldFunc.symbolType() == SymbolType::AngleInternal);
168 
169     Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
170 
171     TFunction &newFunc =
172         *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
173                        &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
174 
175     AcquireFunctionExtras(newFunc, oldFunc);
176     AddParametersFrom(newFunc, oldFunc);
177     for (const TVariable *param : newParams)
178     {
179         newFunc.addParameter(param);
180     }
181 
182     return newFunc;
183 }
184 
CloneFunctionAndChangeReturnType(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TStructure & newReturn)185 const TFunction &sh::CloneFunctionAndChangeReturnType(TSymbolTable &symbolTable,
186                                                       IdGen *idGen,
187                                                       const TFunction &oldFunc,
188                                                       const TStructure &newReturn)
189 {
190     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined);
191 
192     Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
193 
194     TType *newReturnType = new TType(&newReturn, true);
195     TFunction &newFunc   = *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
196                                           newReturnType, oldFunc.isKnownToNotHaveSideEffects());
197 
198     AcquireFunctionExtras(newFunc, oldFunc);
199     AddParametersFrom(newFunc, oldFunc);
200 
201     return newFunc;
202 }
203 
GetArg(const TIntermAggregate & call,size_t index)204 TIntermTyped &sh::GetArg(const TIntermAggregate &call, size_t index)
205 {
206     ASSERT(index < call.getChildCount());
207     TIntermNode *arg = call.getChildNode(index);
208     ASSERT(arg);
209     TIntermTyped *targ = arg->getAsTyped();
210     ASSERT(targ);
211     return *targ;
212 }
213 
SetArg(TIntermAggregate & call,size_t index,TIntermTyped & arg)214 void sh::SetArg(TIntermAggregate &call, size_t index, TIntermTyped &arg)
215 {
216     ASSERT(index < call.getChildCount());
217     (*call.getSequence())[index] = &arg;
218 }
219 
GetFieldIndex(const TStructure & structure,const ImmutableString & fieldName)220 int sh::GetFieldIndex(const TStructure &structure, const ImmutableString &fieldName)
221 {
222     const TFieldList &fieldList = structure.fields();
223 
224     int i = 0;
225     for (TField *field : fieldList)
226     {
227         if (field->name() == fieldName)
228         {
229             return i;
230         }
231         ++i;
232     }
233 
234     return -1;
235 }
236 
AccessField(const TVariable & structInstanceVar,const ImmutableString & fieldName)237 TIntermBinary &sh::AccessField(const TVariable &structInstanceVar, const ImmutableString &fieldName)
238 {
239     return AccessField(*new TIntermSymbol(&structInstanceVar), fieldName);
240 }
241 
AccessField(TIntermTyped & object,const ImmutableString & fieldName)242 TIntermBinary &sh::AccessField(TIntermTyped &object, const ImmutableString &fieldName)
243 {
244     const TStructure *structure = object.getType().getStruct();
245     ASSERT(structure);
246 
247     const int index = GetFieldIndex(*structure, fieldName);
248     ASSERT(index >= 0);
249     return AccessFieldByIndex(object, index);
250 }
251 
AccessFieldByIndex(TIntermTyped & object,int index)252 TIntermBinary &sh::AccessFieldByIndex(TIntermTyped &object, int index)
253 {
254 #if defined(ANGLE_ENABLE_ASSERTS)
255     const TType &type = object.getType();
256     ASSERT(!type.isArray());
257     const TStructure *structure = type.getStruct();
258     ASSERT(structure);
259     ASSERT(0 <= index);
260     ASSERT(static_cast<size_t>(index) < structure->fields().size());
261 #endif
262 
263     return *new TIntermBinary(
264         TOperator::EOpIndexDirectStruct, &object,
265         new TIntermConstantUnion(new TConstantUnion(index), *new TType(TBasicType::EbtInt)));
266 }
267 
AccessIndex(TIntermTyped & indexableNode,int index)268 TIntermBinary &sh::AccessIndex(TIntermTyped &indexableNode, int index)
269 {
270 #if defined(ANGLE_ENABLE_ASSERTS)
271     const TType &type = indexableNode.getType();
272     ASSERT(type.isArray() || type.isVector() || type.isMatrix());
273 #endif
274 
275     TIntermBinary *accessNode = new TIntermBinary(
276         TOperator::EOpIndexDirect, &indexableNode,
277         new TIntermConstantUnion(new TConstantUnion(index), *new TType(TBasicType::EbtInt)));
278     return *accessNode;
279 }
280 
AccessIndex(TIntermTyped & node,const int * index)281 TIntermTyped &sh::AccessIndex(TIntermTyped &node, const int *index)
282 {
283     if (index)
284     {
285         return AccessIndex(node, *index);
286     }
287     return node;
288 }
289 
SubVector(TIntermTyped & vectorNode,int begin,int end)290 TIntermTyped &sh::SubVector(TIntermTyped &vectorNode, int begin, int end)
291 {
292     ASSERT(vectorNode.getType().isVector());
293     ASSERT(0 <= begin);
294     ASSERT(end <= 4);
295     ASSERT(begin <= end);
296     if (begin == 0 && end == vectorNode.getType().getNominalSize())
297     {
298         return vectorNode;
299     }
300     TVector<int> offsets(static_cast<size_t>(end - begin));
301     std::iota(offsets.begin(), offsets.end(), begin);
302     TIntermSwizzle *swizzle = new TIntermSwizzle(vectorNode.deepCopy(), offsets);
303     return *swizzle;
304 }
305 
IsScalarBasicType(const TType & type)306 bool sh::IsScalarBasicType(const TType &type)
307 {
308     if (!type.isScalar())
309     {
310         return false;
311     }
312     return HasScalarBasicType(type);
313 }
314 
IsVectorBasicType(const TType & type)315 bool sh::IsVectorBasicType(const TType &type)
316 {
317     if (!type.isVector())
318     {
319         return false;
320     }
321     return HasScalarBasicType(type);
322 }
323 
HasScalarBasicType(TBasicType type)324 bool sh::HasScalarBasicType(TBasicType type)
325 {
326     switch (type)
327     {
328         case TBasicType::EbtFloat:
329         case TBasicType::EbtDouble:
330         case TBasicType::EbtInt:
331         case TBasicType::EbtUInt:
332         case TBasicType::EbtBool:
333             return true;
334 
335         default:
336             return false;
337     }
338 }
339 
HasScalarBasicType(const TType & type)340 bool sh::HasScalarBasicType(const TType &type)
341 {
342     return HasScalarBasicType(type.getBasicType());
343 }
344 
// Returns a newly allocated copy of `type`.
TType &sh::CloneType(const TType &type)
{
    return *new TType(type);
}
350 
// Returns a newly allocated copy of `type` converted to its array base type via
// TType::toArrayBaseType(). `type` itself is not modified.
TType &sh::InnermostType(const TType &type)
{
    TType *baseType = new TType(type);
    baseType->toArrayBaseType();
    return *baseType;
}
357 
// Returns a newly allocated copy of `matrixType` converted to its column type
// via TType::toMatrixColumnType(). The input must be a matrix with a scalar
// basic type (checked by assertions only).
TType &sh::DropColumns(const TType &matrixType)
{
    ASSERT(matrixType.isMatrix());
    ASSERT(HasScalarBasicType(matrixType));

    TType *columnType = new TType(matrixType);
    columnType->toMatrixColumnType();
    return *columnType;
}
367 
// Returns a newly allocated copy of `arrayType` converted to its element type
// via TType::toArrayElementType(). The input must be an array (checked by
// assertion only).
TType &sh::DropOuterDimension(const TType &arrayType)
{
    ASSERT(arrayType.isArray());

    TType *elementType = new TType(arrayType);
    elementType->toArrayElementType();
    return *elementType;
}
376 
// Returns a newly allocated copy of `type` with its primary and secondary sizes
// replaced. Asserted preconditions: primary in (1, 4], secondary in [1, 4], and
// `type` has a scalar basic type.
static TType &SetTypeDimsImpl(const TType &type, int primary, int secondary)
{
    ASSERT(1 < primary && primary <= 4);
    ASSERT(1 <= secondary && secondary <= 4);
    ASSERT(HasScalarBasicType(type));

    TType *resized = new TType(type);
    resized->setPrimarySize(primary);
    resized->setSecondarySize(secondary);
    return *resized;
}
388 
// Returns a copy of `type` with primary size `newDim` and secondary size 1.
// `type` must satisfy isRank0() or isVector() (checked by assertion only).
TType &sh::SetVectorDim(const TType &type, int newDim)
{
    ASSERT(type.isRank0() || type.isVector());

    const int secondarySize = 1;
    return SetTypeDimsImpl(type, newDim, secondarySize);
}
394 
// Returns a copy of `matrixType` that keeps its column count (as the primary
// size) and uses `newDim` as the secondary size. Asserted preconditions: the
// input is a matrix and newDim is in (1, 4].
TType &sh::SetMatrixRowDim(const TType &matrixType, int newDim)
{
    ASSERT(matrixType.isMatrix());
    ASSERT(1 < newDim && newDim <= 4);

    const int columnCount = matrixType.getCols();
    return SetTypeDimsImpl(matrixType, columnCount, newDim);
}
401 
HasMatrixField(const TStructure & structure)402 bool sh::HasMatrixField(const TStructure &structure)
403 {
404     for (const TField *field : structure.fields())
405     {
406         const TType &type = *field->type();
407         if (type.isMatrix())
408         {
409             return true;
410         }
411     }
412     return false;
413 }
414 
HasArrayField(const TStructure & structure)415 bool sh::HasArrayField(const TStructure &structure)
416 {
417     for (const TField *field : structure.fields())
418     {
419         const TType &type = *field->type();
420         if (type.isArray())
421         {
422             return true;
423         }
424     }
425     return false;
426 }
427 
CoerceSimple(TBasicType toBasicType,TIntermTyped & fromNode,bool needsExplicitBoolCast)428 TIntermTyped &sh::CoerceSimple(TBasicType toBasicType,
429                                TIntermTyped &fromNode,
430                                bool needsExplicitBoolCast)
431 {
432     const TType &fromType = fromNode.getType();
433 
434     ASSERT(HasScalarBasicType(toBasicType));
435     ASSERT(HasScalarBasicType(fromType));
436     ASSERT(!fromType.isArray());
437 
438     const TBasicType fromBasicType = fromType.getBasicType();
439 
440     if (toBasicType != fromBasicType)
441     {
442         if (toBasicType == TBasicType::EbtBool && fromNode.isVector() && needsExplicitBoolCast)
443         {
444             switch (fromBasicType)
445             {
446                 case TBasicType::EbtFloat:
447                 case TBasicType::EbtDouble:
448                 case TBasicType::EbtInt:
449                 case TBasicType::EbtUInt:
450                 {
451                     TIntermSequence *argsSequence = new TIntermSequence();
452                     for (uint8_t i = 0; i < fromType.getNominalSize(); i++)
453                     {
454                         TIntermTyped &fromTypeSwizzle     = SubVector(fromNode, i, i + 1);
455                         TIntermAggregate *boolConstructor = TIntermAggregate::CreateConstructor(
456                             *new TType(toBasicType, 1, 1), new TIntermSequence{&fromTypeSwizzle});
457                         argsSequence->push_back(boolConstructor);
458                     }
459                     return *TIntermAggregate::CreateConstructor(
460                         *new TType(toBasicType, fromType.getNominalSize(),
461                                    fromType.getSecondarySize()),
462                         argsSequence);
463                 }
464 
465                 default:
466                     break;  // No explicit conversion needed
467             }
468         }
469 
470         return *TIntermAggregate::CreateConstructor(
471             *new TType(toBasicType, fromType.getNominalSize(), fromType.getSecondarySize()),
472             new TIntermSequence{&fromNode});
473     }
474     return fromNode;
475 }
476 
CoerceSimple(const TType & toType,TIntermTyped & fromNode,bool needsExplicitBoolCast)477 TIntermTyped &sh::CoerceSimple(const TType &toType,
478                                TIntermTyped &fromNode,
479                                bool needsExplicitBoolCast)
480 {
481     const TType &fromType = fromNode.getType();
482 
483     ASSERT(HasScalarBasicType(toType));
484     ASSERT(HasScalarBasicType(fromType));
485     ASSERT(toType.getNominalSize() == fromType.getNominalSize());
486     ASSERT(toType.getSecondarySize() == fromType.getSecondarySize());
487     ASSERT(!toType.isArray());
488     ASSERT(!fromType.isArray());
489 
490     const TBasicType toBasicType   = toType.getBasicType();
491     const TBasicType fromBasicType = fromType.getBasicType();
492 
493     if (toBasicType != fromBasicType)
494     {
495         if (toBasicType == TBasicType::EbtBool && fromNode.isVector() && needsExplicitBoolCast)
496         {
497             switch (fromBasicType)
498             {
499                 case TBasicType::EbtFloat:
500                 case TBasicType::EbtDouble:
501                 case TBasicType::EbtInt:
502                 case TBasicType::EbtUInt:
503                 {
504                     TIntermSequence *argsSequence = new TIntermSequence();
505                     for (uint8_t i = 0; i < fromType.getNominalSize(); i++)
506                     {
507                         TIntermTyped &fromTypeSwizzle     = SubVector(fromNode, i, i + 1);
508                         TIntermAggregate *boolConstructor = TIntermAggregate::CreateConstructor(
509                             *new TType(toBasicType, 1, 1), new TIntermSequence{&fromTypeSwizzle});
510                         argsSequence->push_back(boolConstructor);
511                     }
512                     return *TIntermAggregate::CreateConstructor(
513                         *new TType(toBasicType, fromType.getNominalSize(),
514                                    fromType.getSecondarySize()),
515                         new TIntermSequence{*argsSequence});
516                 }
517 
518                 default:
519                     break;  // No explicit conversion needed
520             }
521         }
522 
523         return *TIntermAggregate::CreateConstructor(toType, new TIntermSequence{&fromNode});
524     }
525     return fromNode;
526 }
527 
AsType(SymbolEnv & symbolEnv,const TType & toType,TIntermTyped & fromNode)528 TIntermTyped &sh::AsType(SymbolEnv &symbolEnv, const TType &toType, TIntermTyped &fromNode)
529 {
530     const TType &fromType = fromNode.getType();
531 
532     ASSERT(HasScalarBasicType(toType));
533     ASSERT(HasScalarBasicType(fromType));
534     ASSERT(!toType.isArray());
535     ASSERT(!fromType.isArray());
536 
537     if (toType == fromType)
538     {
539         return fromNode;
540     }
541     TemplateArg targ(toType);
542     return symbolEnv.callFunctionOverload(Name("as_type", SymbolType::BuiltIn), toType,
543                                           *new TIntermSequence{&fromNode}, 1, &targ);
544 }
545