• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//
// Copyright 2020 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
6 
7 #include <cstring>
8 #include <numeric>
9 #include <unordered_map>
10 #include <unordered_set>
11 
12 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
13 
14 using namespace sh;
15 
16 ////////////////////////////////////////////////////////////////////////////////
17 
ViewDeclaration(TIntermDeclaration & declNode)18 Declaration sh::ViewDeclaration(TIntermDeclaration &declNode)
19 {
20     ASSERT(declNode.getChildCount() == 1);
21     TIntermNode *childNode = declNode.getChildNode(0);
22     ASSERT(childNode);
23     TIntermSymbol *symbolNode;
24     if ((symbolNode = childNode->getAsSymbolNode()))
25     {
26         return {*symbolNode, nullptr};
27     }
28     else
29     {
30         TIntermBinary *initNode = childNode->getAsBinaryNode();
31         ASSERT(initNode);
32         ASSERT(initNode->getOp() == TOperator::EOpInitialize);
33         symbolNode = initNode->getLeft()->getAsSymbolNode();
34         ASSERT(symbolNode);
35         return {*symbolNode, initNode->getRight()};
36     }
37 }
38 
CreateStructTypeVariable(TSymbolTable & symbolTable,const TStructure & structure)39 const TVariable &sh::CreateStructTypeVariable(TSymbolTable &symbolTable,
40                                               const TStructure &structure)
41 {
42     TType *type    = new TType(&structure, true);
43     TVariable *var = new TVariable(&symbolTable, ImmutableString(""), type, SymbolType::Empty);
44     return *var;
45 }
46 
CreateInstanceVariable(TSymbolTable & symbolTable,const TStructure & structure,const Name & name,TQualifier qualifier,const TSpan<const unsigned int> * arraySizes)47 const TVariable &sh::CreateInstanceVariable(TSymbolTable &symbolTable,
48                                             const TStructure &structure,
49                                             const Name &name,
50                                             TQualifier qualifier,
51                                             const TSpan<const unsigned int> *arraySizes)
52 {
53     TType *type = new TType(&structure, false);
54     type->setQualifier(qualifier);
55     if (arraySizes)
56     {
57         type->makeArrays(*arraySizes);
58     }
59     TVariable *var = new TVariable(&symbolTable, name.rawName(), type, name.symbolType());
60     return *var;
61 }
62 
AcquireFunctionExtras(TFunction & dest,const TFunction & src)63 static void AcquireFunctionExtras(TFunction &dest, const TFunction &src)
64 {
65     if (src.isDefined())
66     {
67         dest.setDefined();
68     }
69 
70     if (src.hasPrototypeDeclaration())
71     {
72         dest.setHasPrototypeDeclaration();
73     }
74 }
75 
// Returns a newly allocated sequence that contains `node` followed by every
// element of `seq`, in order. Only the node pointers are copied; the nodes
// themselves are shared with `seq`, not deep-copied.
TIntermSequence &sh::CloneSequenceAndPrepend(const TIntermSequence &seq, TIntermNode &node)
{
    TIntermSequence *newSeq = new TIntermSequence();

    // The final size is known up front, so allocate once instead of growing
    // the vector element by element.
    newSeq->reserve(seq.size() + 1);
    newSeq->push_back(&node);
    newSeq->insert(newSeq->end(), seq.begin(), seq.end());

    return *newSeq;
}
88 
AddParametersFrom(TFunction & dest,const TFunction & src)89 void sh::AddParametersFrom(TFunction &dest, const TFunction &src)
90 {
91     const size_t paramCount = src.getParamCount();
92     for (size_t i = 0; i < paramCount; ++i)
93     {
94         const TVariable *var = src.getParam(i);
95         dest.addParameter(var);
96     }
97 }
98 
CloneFunction(TSymbolTable & symbolTable,IdGen & idGen,const TFunction & oldFunc)99 const TFunction &sh::CloneFunction(TSymbolTable &symbolTable,
100                                    IdGen &idGen,
101                                    const TFunction &oldFunc)
102 {
103     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined);
104 
105     Name newName = idGen.createNewName(Name(oldFunc));
106 
107     TFunction &newFunc =
108         *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
109                        &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
110 
111     AcquireFunctionExtras(newFunc, oldFunc);
112     AddParametersFrom(newFunc, oldFunc);
113 
114     return newFunc;
115 }
116 
CloneFunctionAndPrependParam(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TVariable & newParam)117 const TFunction &sh::CloneFunctionAndPrependParam(TSymbolTable &symbolTable,
118                                                   IdGen *idGen,
119                                                   const TFunction &oldFunc,
120                                                   const TVariable &newParam)
121 {
122     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
123            oldFunc.symbolType() == SymbolType::AngleInternal);
124 
125     Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
126 
127     TFunction &newFunc =
128         *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
129                        &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
130 
131     AcquireFunctionExtras(newFunc, oldFunc);
132     newFunc.addParameter(&newParam);
133     AddParametersFrom(newFunc, oldFunc);
134 
135     return newFunc;
136 }
137 
CloneFunctionAndAppendParams(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const std::vector<const TVariable * > & newParams)138 const TFunction &sh::CloneFunctionAndAppendParams(TSymbolTable &symbolTable,
139                                                   IdGen *idGen,
140                                                   const TFunction &oldFunc,
141                                                   const std::vector<const TVariable *> &newParams)
142 {
143     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
144            oldFunc.symbolType() == SymbolType::AngleInternal);
145 
146     Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
147 
148     TFunction &newFunc =
149         *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
150                        &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
151 
152     AcquireFunctionExtras(newFunc, oldFunc);
153     AddParametersFrom(newFunc, oldFunc);
154     for (const TVariable *param : newParams)
155     {
156         newFunc.addParameter(param);
157     }
158 
159     return newFunc;
160 }
161 
CloneFunctionAndChangeReturnType(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TStructure & newReturn)162 const TFunction &sh::CloneFunctionAndChangeReturnType(TSymbolTable &symbolTable,
163                                                       IdGen *idGen,
164                                                       const TFunction &oldFunc,
165                                                       const TStructure &newReturn)
166 {
167     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined);
168 
169     Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
170 
171     TType *newReturnType = new TType(&newReturn, true);
172     TFunction &newFunc   = *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
173                                         newReturnType, oldFunc.isKnownToNotHaveSideEffects());
174 
175     AcquireFunctionExtras(newFunc, oldFunc);
176     AddParametersFrom(newFunc, oldFunc);
177 
178     return newFunc;
179 }
180 
GetArg(const TIntermAggregate & call,size_t index)181 TIntermTyped &sh::GetArg(const TIntermAggregate &call, size_t index)
182 {
183     ASSERT(index < call.getChildCount());
184     TIntermNode *arg = call.getChildNode(index);
185     ASSERT(arg);
186     TIntermTyped *targ = arg->getAsTyped();
187     ASSERT(targ);
188     return *targ;
189 }
190 
SetArg(TIntermAggregate & call,size_t index,TIntermTyped & arg)191 void sh::SetArg(TIntermAggregate &call, size_t index, TIntermTyped &arg)
192 {
193     ASSERT(index < call.getChildCount());
194     (*call.getSequence())[index] = &arg;
195 }
196 
GetFieldIndex(const TStructure & structure,const ImmutableString & fieldName)197 int sh::GetFieldIndex(const TStructure &structure, const ImmutableString &fieldName)
198 {
199     const TFieldList &fieldList = structure.fields();
200 
201     int i = 0;
202     for (TField *field : fieldList)
203     {
204         if (field->name() == fieldName)
205         {
206             return i;
207         }
208         ++i;
209     }
210 
211     return -1;
212 }
213 
AccessField(const TVariable & structInstanceVar,const ImmutableString & fieldName)214 TIntermBinary &sh::AccessField(const TVariable &structInstanceVar, const ImmutableString &fieldName)
215 {
216     return AccessField(*new TIntermSymbol(&structInstanceVar), fieldName);
217 }
218 
AccessField(TIntermTyped & object,const ImmutableString & fieldName)219 TIntermBinary &sh::AccessField(TIntermTyped &object, const ImmutableString &fieldName)
220 {
221     const TStructure *structure = object.getType().getStruct();
222     ASSERT(structure);
223 
224     const int index = GetFieldIndex(*structure, fieldName);
225     ASSERT(index >= 0);
226     return AccessFieldByIndex(object, index);
227 }
228 
AccessFieldByIndex(TIntermTyped & object,int index)229 TIntermBinary &sh::AccessFieldByIndex(TIntermTyped &object, int index)
230 {
231 #if defined(ANGLE_ENABLE_ASSERTS)
232     const TType &type = object.getType();
233     ASSERT(!type.isArray());
234     const TStructure *structure = type.getStruct();
235     ASSERT(structure);
236     ASSERT(0 <= index);
237     ASSERT(static_cast<size_t>(index) < structure->fields().size());
238 #endif
239 
240     return *new TIntermBinary(
241         TOperator::EOpIndexDirectStruct, &object,
242         new TIntermConstantUnion(new TConstantUnion(index), *new TType(TBasicType::EbtInt)));
243 }
244 
AccessIndex(TIntermTyped & indexableNode,int index)245 TIntermBinary &sh::AccessIndex(TIntermTyped &indexableNode, int index)
246 {
247 #if defined(ANGLE_ENABLE_ASSERTS)
248     const TType &type = indexableNode.getType();
249     ASSERT(type.isArray() || type.isVector() || type.isMatrix());
250 #endif
251 
252     TIntermBinary *accessNode = new TIntermBinary(
253         TOperator::EOpIndexDirect, &indexableNode,
254         new TIntermConstantUnion(new TConstantUnion(index), *new TType(TBasicType::EbtInt)));
255     return *accessNode;
256 }
257 
AccessIndex(TIntermTyped & node,const int * index)258 TIntermTyped &sh::AccessIndex(TIntermTyped &node, const int *index)
259 {
260     if (index)
261     {
262         return AccessIndex(node, *index);
263     }
264     return node;
265 }
266 
SubVector(TIntermTyped & vectorNode,int begin,int end)267 TIntermTyped &sh::SubVector(TIntermTyped &vectorNode, int begin, int end)
268 {
269     ASSERT(vectorNode.getType().isVector());
270     ASSERT(0 <= begin);
271     ASSERT(end <= 4);
272     ASSERT(begin <= end);
273     if (begin == 0 && end == vectorNode.getType().getNominalSize())
274     {
275         return vectorNode;
276     }
277     TVector<int> offsets(static_cast<size_t>(end - begin));
278     std::iota(offsets.begin(), offsets.end(), begin);
279     TIntermSwizzle *swizzle = new TIntermSwizzle(vectorNode.deepCopy(), offsets);
280     return *swizzle;
281 }
282 
IsScalarBasicType(const TType & type)283 bool sh::IsScalarBasicType(const TType &type)
284 {
285     if (!type.isScalar())
286     {
287         return false;
288     }
289     return HasScalarBasicType(type);
290 }
291 
IsVectorBasicType(const TType & type)292 bool sh::IsVectorBasicType(const TType &type)
293 {
294     if (!type.isVector())
295     {
296         return false;
297     }
298     return HasScalarBasicType(type);
299 }
300 
HasScalarBasicType(TBasicType type)301 bool sh::HasScalarBasicType(TBasicType type)
302 {
303     switch (type)
304     {
305         case TBasicType::EbtFloat:
306         case TBasicType::EbtDouble:
307         case TBasicType::EbtInt:
308         case TBasicType::EbtUInt:
309         case TBasicType::EbtBool:
310             return true;
311 
312         default:
313             return false;
314     }
315 }
316 
HasScalarBasicType(const TType & type)317 bool sh::HasScalarBasicType(const TType &type)
318 {
319     return HasScalarBasicType(type.getBasicType());
320 }
321 
// Returns a newly allocated copy of `type`.
TType &sh::CloneType(const TType &type)
{
    return *new TType(type);
}
327 
// Returns a newly allocated copy of `type` converted with
// TType::toArrayBaseType(), i.e. with its array dimensions dropped.
TType &sh::InnermostType(const TType &type)
{
    auto *baseType = new TType(type);
    baseType->toArrayBaseType();
    return *baseType;
}
334 
// Returns a newly allocated copy of the matrix type `matrixType` converted with
// TType::toMatrixColumnType(). The matrix must have a scalar basic type
// (float, double, int, uint or bool).
TType &sh::DropColumns(const TType &matrixType)
{
    ASSERT(matrixType.isMatrix());
    ASSERT(HasScalarBasicType(matrixType));

    auto *columnType = new TType(matrixType);
    columnType->toMatrixColumnType();
    return *columnType;
}
344 
// Returns a newly allocated copy of the array type `arrayType` converted with
// TType::toArrayElementType(), i.e. with one array dimension removed.
TType &sh::DropOuterDimension(const TType &arrayType)
{
    ASSERT(arrayType.isArray());

    auto *elementType = new TType(arrayType);
    elementType->toArrayElementType();
    return *elementType;
}
353 
// Returns a newly allocated copy of `type` with its primary size set to
// `primary` (2..4) and its secondary size set to `secondary` (1..4). `type`
// must have a scalar basic type. Ranges are checked by ASSERT only.
static TType &SetTypeDimsImpl(const TType &type, int primary, int secondary)
{
    ASSERT(primary > 1 && primary <= 4);
    ASSERT(secondary >= 1 && secondary <= 4);
    ASSERT(HasScalarBasicType(type));

    auto *resized = new TType(type);
    resized->setPrimarySize(primary);
    resized->setSecondarySize(secondary);
    return *resized;
}
365 
// Returns a copy of the rank-0 or vector type `type` with primary size `newDim`
// and secondary size 1. SetTypeDimsImpl asserts that `newDim` is in 2..4.
TType &sh::SetVectorDim(const TType &type, int newDim)
{
    ASSERT(type.isVector() || type.isRank0());
    const int secondarySize = 1;
    return SetTypeDimsImpl(type, newDim, secondarySize);
}
371 
// Returns a copy of the matrix type `matrixType` that keeps its column count
// (as the primary size) and uses `newDim` (2..4) as the secondary size.
TType &sh::SetMatrixRowDim(const TType &matrixType, int newDim)
{
    ASSERT(matrixType.isMatrix());
    ASSERT(newDim > 1 && newDim <= 4);
    const int columnCount = matrixType.getCols();
    return SetTypeDimsImpl(matrixType, columnCount, newDim);
}
378 
HasMatrixField(const TStructure & structure)379 bool sh::HasMatrixField(const TStructure &structure)
380 {
381     for (const TField *field : structure.fields())
382     {
383         const TType &type = *field->type();
384         if (type.isMatrix())
385         {
386             return true;
387         }
388     }
389     return false;
390 }
391 
HasArrayField(const TStructure & structure)392 bool sh::HasArrayField(const TStructure &structure)
393 {
394     for (const TField *field : structure.fields())
395     {
396         const TType &type = *field->type();
397         if (type.isArray())
398         {
399             return true;
400         }
401     }
402     return false;
403 }
404 
CoerceSimple(TBasicType toBasicType,TIntermTyped & fromNode,bool needsExplicitBoolCast)405 TIntermTyped &sh::CoerceSimple(TBasicType toBasicType,
406                                TIntermTyped &fromNode,
407                                bool needsExplicitBoolCast)
408 {
409     const TType &fromType = fromNode.getType();
410 
411     ASSERT(HasScalarBasicType(toBasicType));
412     ASSERT(HasScalarBasicType(fromType));
413     ASSERT(!fromType.isArray());
414 
415     const TBasicType fromBasicType = fromType.getBasicType();
416 
417     if (toBasicType != fromBasicType)
418     {
419         if (toBasicType == TBasicType::EbtBool && fromNode.isVector() && needsExplicitBoolCast)
420         {
421             switch (fromBasicType)
422             {
423                 case TBasicType::EbtFloat:
424                 case TBasicType::EbtDouble:
425                 case TBasicType::EbtInt:
426                 case TBasicType::EbtUInt:
427                 {
428                     TIntermSequence *argsSequence = new TIntermSequence();
429                     for (int i = 0; i < fromType.getNominalSize(); i++)
430                     {
431                         TIntermTyped &fromTypeSwizzle     = SubVector(fromNode, i, i + 1);
432                         TIntermAggregate *boolConstructor = TIntermAggregate::CreateConstructor(
433                             *new TType(toBasicType, 1, 1), new TIntermSequence{&fromTypeSwizzle});
434                         argsSequence->push_back(boolConstructor);
435                     }
436                     return *TIntermAggregate::CreateConstructor(
437                         *new TType(toBasicType, fromType.getNominalSize(),
438                                    fromType.getSecondarySize()),
439                         argsSequence);
440                 }
441 
442                 default:
443                     break;  // No explicit conversion needed
444             }
445         }
446 
447         return *TIntermAggregate::CreateConstructor(
448             *new TType(toBasicType, fromType.getNominalSize(), fromType.getSecondarySize()),
449             new TIntermSequence{&fromNode});
450     }
451     return fromNode;
452 }
453 
CoerceSimple(const TType & toType,TIntermTyped & fromNode,bool needsExplicitBoolCast)454 TIntermTyped &sh::CoerceSimple(const TType &toType,
455                                TIntermTyped &fromNode,
456                                bool needsExplicitBoolCast)
457 {
458     const TType &fromType = fromNode.getType();
459 
460     ASSERT(HasScalarBasicType(toType));
461     ASSERT(HasScalarBasicType(fromType));
462     ASSERT(toType.getNominalSize() == fromType.getNominalSize());
463     ASSERT(toType.getSecondarySize() == fromType.getSecondarySize());
464     ASSERT(!toType.isArray());
465     ASSERT(!fromType.isArray());
466 
467     const TBasicType toBasicType   = toType.getBasicType();
468     const TBasicType fromBasicType = fromType.getBasicType();
469 
470     if (toBasicType != fromBasicType)
471     {
472         if (toBasicType == TBasicType::EbtBool && fromNode.isVector() && needsExplicitBoolCast)
473         {
474             switch (fromBasicType)
475             {
476                 case TBasicType::EbtFloat:
477                 case TBasicType::EbtDouble:
478                 case TBasicType::EbtInt:
479                 case TBasicType::EbtUInt:
480                 {
481                     TIntermSequence *argsSequence = new TIntermSequence();
482                     for (int i = 0; i < fromType.getNominalSize(); i++)
483                     {
484                         TIntermTyped &fromTypeSwizzle     = SubVector(fromNode, i, i + 1);
485                         TIntermAggregate *boolConstructor = TIntermAggregate::CreateConstructor(
486                             *new TType(toBasicType, 1, 1), new TIntermSequence{&fromTypeSwizzle});
487                         argsSequence->push_back(boolConstructor);
488                     }
489                     return *TIntermAggregate::CreateConstructor(
490                         *new TType(toBasicType, fromType.getNominalSize(),
491                                    fromType.getSecondarySize()),
492                         new TIntermSequence{*argsSequence});
493                 }
494 
495                 default:
496                     break;  // No explicit conversion needed
497             }
498         }
499 
500         return *TIntermAggregate::CreateConstructor(toType, new TIntermSequence{&fromNode});
501     }
502     return fromNode;
503 }
504 
AsType(SymbolEnv & symbolEnv,const TType & toType,TIntermTyped & fromNode)505 TIntermTyped &sh::AsType(SymbolEnv &symbolEnv, const TType &toType, TIntermTyped &fromNode)
506 {
507     const TType &fromType = fromNode.getType();
508 
509     ASSERT(HasScalarBasicType(toType));
510     ASSERT(HasScalarBasicType(fromType));
511     ASSERT(!toType.isArray());
512     ASSERT(!fromType.isArray());
513 
514     if (toType == fromType)
515     {
516         return fromNode;
517     }
518     TemplateArg targ(toType);
519     return symbolEnv.callFunctionOverload(Name("as_type", SymbolType::BuiltIn), toType,
520                                           *new TIntermSequence{&fromNode}, 1, &targ);
521 }
522