• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <cstring>
8 #include <numeric>
9 #include <unordered_map>
10 #include <unordered_set>
11 
12 #include "compiler/translator/msl/AstHelpers.h"
13 
14 using namespace sh;
15 
16 ////////////////////////////////////////////////////////////////////////////////
17 
CreateStructTypeVariable(TSymbolTable & symbolTable,const TStructure & structure)18 const TVariable &sh::CreateStructTypeVariable(TSymbolTable &symbolTable,
19                                               const TStructure &structure)
20 {
21     TType *type    = new TType(&structure, true);
22     TVariable *var = new TVariable(&symbolTable, ImmutableString(""), type, SymbolType::Empty);
23     return *var;
24 }
25 
CreateInstanceVariable(TSymbolTable & symbolTable,const TStructure & structure,const Name & name,TQualifier qualifier,const TSpan<const unsigned int> * arraySizes)26 const TVariable &sh::CreateInstanceVariable(TSymbolTable &symbolTable,
27                                             const TStructure &structure,
28                                             const Name &name,
29                                             TQualifier qualifier,
30                                             const TSpan<const unsigned int> *arraySizes)
31 {
32     TType *type = new TType(&structure, false);
33     type->setQualifier(qualifier);
34     if (arraySizes)
35     {
36         type->makeArrays(*arraySizes);
37     }
38     TVariable *var = new TVariable(&symbolTable, name.rawName(), type, name.symbolType());
39     return *var;
40 }
41 
AcquireFunctionExtras(TFunction & dest,const TFunction & src)42 static void AcquireFunctionExtras(TFunction &dest, const TFunction &src)
43 {
44     if (src.isDefined())
45     {
46         dest.setDefined();
47     }
48 
49     if (src.hasPrototypeDeclaration())
50     {
51         dest.setHasPrototypeDeclaration();
52     }
53 }
54 
// Returns a newly allocated sequence containing `node` followed by every
// element of `seq`, in order. Only node pointers are copied: the nodes are
// shared with `seq`, which is not modified.
TIntermSequence &sh::CloneSequenceAndPrepend(const TIntermSequence &seq, TIntermNode &node)
{
    auto *prepended = new TIntermSequence{&node};
    prepended->insert(prepended->end(), seq.begin(), seq.end());
    return *prepended;
}
67 
AddParametersFrom(TFunction & dest,const TFunction & src)68 void sh::AddParametersFrom(TFunction &dest, const TFunction &src)
69 {
70     const size_t paramCount = src.getParamCount();
71     for (size_t i = 0; i < paramCount; ++i)
72     {
73         const TVariable *var = src.getParam(i);
74         dest.addParameter(var);
75     }
76 }
77 
CloneFunction(TSymbolTable & symbolTable,IdGen & idGen,const TFunction & oldFunc)78 const TFunction &sh::CloneFunction(TSymbolTable &symbolTable,
79                                    IdGen &idGen,
80                                    const TFunction &oldFunc)
81 {
82     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined);
83 
84     Name newName = idGen.createNewName(Name(oldFunc));
85 
86     TFunction &newFunc =
87         *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
88                        &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
89 
90     AcquireFunctionExtras(newFunc, oldFunc);
91     AddParametersFrom(newFunc, oldFunc);
92 
93     return newFunc;
94 }
95 
CloneFunctionAndPrependParam(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TVariable & newParam)96 const TFunction &sh::CloneFunctionAndPrependParam(TSymbolTable &symbolTable,
97                                                   IdGen *idGen,
98                                                   const TFunction &oldFunc,
99                                                   const TVariable &newParam)
100 {
101     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
102            oldFunc.symbolType() == SymbolType::AngleInternal);
103 
104     Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
105 
106     TFunction &newFunc =
107         *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
108                        &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
109 
110     AcquireFunctionExtras(newFunc, oldFunc);
111     newFunc.addParameter(&newParam);
112     AddParametersFrom(newFunc, oldFunc);
113 
114     return newFunc;
115 }
116 
CloneFunctionAndPrependTwoParams(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TVariable & newParam1,const TVariable & newParam2)117 const TFunction &sh::CloneFunctionAndPrependTwoParams(TSymbolTable &symbolTable,
118                                                       IdGen *idGen,
119                                                       const TFunction &oldFunc,
120                                                       const TVariable &newParam1,
121                                                       const TVariable &newParam2)
122 {
123     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
124            oldFunc.symbolType() == SymbolType::AngleInternal);
125 
126     Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
127 
128     TFunction &newFunc =
129         *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
130                        &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
131 
132     AcquireFunctionExtras(newFunc, oldFunc);
133     newFunc.addParameter(&newParam1);
134     newFunc.addParameter(&newParam2);
135     AddParametersFrom(newFunc, oldFunc);
136 
137     return newFunc;
138 }
139 
CloneFunctionAndAppendParams(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const std::vector<const TVariable * > & newParams)140 const TFunction &sh::CloneFunctionAndAppendParams(TSymbolTable &symbolTable,
141                                                   IdGen *idGen,
142                                                   const TFunction &oldFunc,
143                                                   const std::vector<const TVariable *> &newParams)
144 {
145     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
146            oldFunc.symbolType() == SymbolType::AngleInternal);
147 
148     Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
149 
150     TFunction &newFunc =
151         *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
152                        &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
153 
154     AcquireFunctionExtras(newFunc, oldFunc);
155     AddParametersFrom(newFunc, oldFunc);
156     for (const TVariable *param : newParams)
157     {
158         newFunc.addParameter(param);
159     }
160 
161     return newFunc;
162 }
163 
CloneFunctionAndChangeReturnType(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TStructure & newReturn)164 const TFunction &sh::CloneFunctionAndChangeReturnType(TSymbolTable &symbolTable,
165                                                       IdGen *idGen,
166                                                       const TFunction &oldFunc,
167                                                       const TStructure &newReturn)
168 {
169     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined);
170 
171     Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
172 
173     TType *newReturnType = new TType(&newReturn, true);
174     TFunction &newFunc   = *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
175                                           newReturnType, oldFunc.isKnownToNotHaveSideEffects());
176 
177     AcquireFunctionExtras(newFunc, oldFunc);
178     AddParametersFrom(newFunc, oldFunc);
179 
180     return newFunc;
181 }
182 
GetArg(const TIntermAggregate & call,size_t index)183 TIntermTyped &sh::GetArg(const TIntermAggregate &call, size_t index)
184 {
185     ASSERT(index < call.getChildCount());
186     TIntermNode *arg = call.getChildNode(index);
187     ASSERT(arg);
188     TIntermTyped *targ = arg->getAsTyped();
189     ASSERT(targ);
190     return *targ;
191 }
192 
SetArg(TIntermAggregate & call,size_t index,TIntermTyped & arg)193 void sh::SetArg(TIntermAggregate &call, size_t index, TIntermTyped &arg)
194 {
195     ASSERT(index < call.getChildCount());
196     (*call.getSequence())[index] = &arg;
197 }
198 
AccessField(const TVariable & structInstanceVar,const Name & name)199 TIntermBinary &sh::AccessField(const TVariable &structInstanceVar, const Name &name)
200 {
201     return AccessField(*new TIntermSymbol(&structInstanceVar), name);
202 }
203 
AccessField(TIntermTyped & object,const Name & name)204 TIntermBinary &sh::AccessField(TIntermTyped &object, const Name &name)
205 {
206     const TStructure *structure = object.getType().getStruct();
207     ASSERT(structure);
208     const TFieldList &fieldList = structure->fields();
209     for (int i = 0; i < static_cast<int>(fieldList.size()); ++i)
210     {
211         TField *current = fieldList[i];
212         if (Name(*current) == name)
213         {
214             return AccessFieldByIndex(object, i);
215         }
216     }
217     UNREACHABLE();
218     return AccessFieldByIndex(object, -1);
219 }
220 
AccessFieldByIndex(TIntermTyped & object,int index)221 TIntermBinary &sh::AccessFieldByIndex(TIntermTyped &object, int index)
222 {
223 #if defined(ANGLE_ENABLE_ASSERTS)
224     const TType &type = object.getType();
225     ASSERT(!type.isArray());
226     const TStructure *structure = type.getStruct();
227     ASSERT(structure);
228     ASSERT(0 <= index);
229     ASSERT(static_cast<size_t>(index) < structure->fields().size());
230 #endif
231 
232     return *new TIntermBinary(
233         TOperator::EOpIndexDirectStruct, &object,
234         new TIntermConstantUnion(new TConstantUnion(index), *new TType(TBasicType::EbtInt)));
235 }
236 
AccessIndex(TIntermTyped & indexableNode,int index)237 TIntermBinary &sh::AccessIndex(TIntermTyped &indexableNode, int index)
238 {
239 #if defined(ANGLE_ENABLE_ASSERTS)
240     const TType &type = indexableNode.getType();
241     ASSERT(type.isArray() || type.isVector() || type.isMatrix());
242 #endif
243 
244     TIntermBinary *accessNode = new TIntermBinary(
245         TOperator::EOpIndexDirect, &indexableNode,
246         new TIntermConstantUnion(new TConstantUnion(index), *new TType(TBasicType::EbtInt)));
247     return *accessNode;
248 }
249 
AccessIndex(TIntermTyped & node,const int * index)250 TIntermTyped &sh::AccessIndex(TIntermTyped &node, const int *index)
251 {
252     if (index)
253     {
254         return AccessIndex(node, *index);
255     }
256     return node;
257 }
258 
SubVector(TIntermTyped & vectorNode,int begin,int end)259 TIntermTyped &sh::SubVector(TIntermTyped &vectorNode, int begin, int end)
260 {
261     ASSERT(vectorNode.getType().isVector());
262     ASSERT(0 <= begin);
263     ASSERT(end <= 4);
264     ASSERT(begin <= end);
265     if (begin == 0 && end == vectorNode.getType().getNominalSize())
266     {
267         return vectorNode;
268     }
269     TVector<int> offsets(static_cast<size_t>(end - begin));
270     std::iota(offsets.begin(), offsets.end(), begin);
271     TIntermSwizzle *swizzle = new TIntermSwizzle(vectorNode.deepCopy(), offsets);
272     return *swizzle;
273 }
274 
IsScalarBasicType(const TType & type)275 bool sh::IsScalarBasicType(const TType &type)
276 {
277     if (!type.isScalar())
278     {
279         return false;
280     }
281     return HasScalarBasicType(type);
282 }
283 
IsVectorBasicType(const TType & type)284 bool sh::IsVectorBasicType(const TType &type)
285 {
286     if (!type.isVector())
287     {
288         return false;
289     }
290     return HasScalarBasicType(type);
291 }
292 
HasScalarBasicType(TBasicType type)293 bool sh::HasScalarBasicType(TBasicType type)
294 {
295     switch (type)
296     {
297         case TBasicType::EbtFloat:
298         case TBasicType::EbtDouble:
299         case TBasicType::EbtInt:
300         case TBasicType::EbtUInt:
301         case TBasicType::EbtBool:
302             return true;
303 
304         default:
305             return false;
306     }
307 }
308 
HasScalarBasicType(const TType & type)309 bool sh::HasScalarBasicType(const TType &type)
310 {
311     return HasScalarBasicType(type.getBasicType());
312 }
313 
// Returns a newly allocated copy of `type`.
TType &sh::CloneType(const TType &type)
{
    return *new TType(type);
}
319 
// Returns a newly allocated copy of `type` with TType::toArrayBaseType()
// applied, stripping its array dimensions.
TType &sh::InnermostType(const TType &type)
{
    auto *baseType = new TType(type);
    baseType->toArrayBaseType();
    return *baseType;
}
326 
// Returns a newly allocated copy of the matrix type `matrixType` converted by
// TType::toMatrixColumnType() (i.e. the type of one matrix column).
TType &sh::DropColumns(const TType &matrixType)
{
    ASSERT(matrixType.isMatrix());
    ASSERT(HasScalarBasicType(matrixType));

    auto *columnType = new TType(matrixType);
    columnType->toMatrixColumnType();
    return *columnType;
}
336 
// Returns a newly allocated copy of the array type `arrayType` converted by
// TType::toArrayElementType(), i.e. with one array dimension removed.
TType &sh::DropOuterDimension(const TType &arrayType)
{
    ASSERT(arrayType.isArray());

    auto *elementType = new TType(arrayType);
    elementType->toArrayElementType();
    return *elementType;
}
345 
// Returns a newly allocated copy of `type` with its primary and secondary sizes
// replaced. `primary` must be in (1, 4] and `secondary` in [1, 4].
static TType &SetTypeDimsImpl(const TType &type, int primary, int secondary)
{
    ASSERT(1 < primary && primary <= 4);
    ASSERT(1 <= secondary && secondary <= 4);
    ASSERT(HasScalarBasicType(type));

    auto *resized = new TType(type);
    resized->setPrimarySize(primary);
    resized->setSecondarySize(secondary);
    return *resized;
}
357 
// Returns a copy of the scalar or vector type `type` resized to `newDim`
// components (secondary size fixed at 1).
TType &sh::SetVectorDim(const TType &type, int newDim)
{
    ASSERT(type.isRank0() || type.isVector());
    constexpr int kVectorSecondarySize = 1;
    return SetTypeDimsImpl(type, newDim, kVectorSecondarySize);
}
363 
// Returns a copy of `matrixType` keeping its column count and using `newDim`
// as the new secondary size (row count).
TType &sh::SetMatrixRowDim(const TType &matrixType, int newDim)
{
    ASSERT(matrixType.isMatrix());
    ASSERT(1 < newDim && newDim <= 4);
    const int columnCount = matrixType.getCols();
    return SetTypeDimsImpl(matrixType, columnCount, newDim);
}
370 
HasMatrixField(const TStructure & structure)371 bool sh::HasMatrixField(const TStructure &structure)
372 {
373     for (const TField *field : structure.fields())
374     {
375         const TType &type = *field->type();
376         if (type.isMatrix())
377         {
378             return true;
379         }
380     }
381     return false;
382 }
383 
HasArrayField(const TStructure & structure)384 bool sh::HasArrayField(const TStructure &structure)
385 {
386     for (const TField *field : structure.fields())
387     {
388         const TType &type = *field->type();
389         if (type.isArray())
390         {
391             return true;
392         }
393     }
394     return false;
395 }
396 
CoerceSimple(TBasicType toBasicType,TIntermTyped & fromNode,bool needsExplicitBoolCast)397 TIntermTyped &sh::CoerceSimple(TBasicType toBasicType,
398                                TIntermTyped &fromNode,
399                                bool needsExplicitBoolCast)
400 {
401     const TType &fromType = fromNode.getType();
402 
403     ASSERT(HasScalarBasicType(toBasicType));
404     ASSERT(HasScalarBasicType(fromType));
405     ASSERT(!fromType.isArray());
406 
407     const TBasicType fromBasicType = fromType.getBasicType();
408 
409     if (toBasicType != fromBasicType)
410     {
411         if (toBasicType == TBasicType::EbtBool && fromNode.isVector() && needsExplicitBoolCast)
412         {
413             switch (fromBasicType)
414             {
415                 case TBasicType::EbtFloat:
416                 case TBasicType::EbtDouble:
417                 case TBasicType::EbtInt:
418                 case TBasicType::EbtUInt:
419                 {
420                     TIntermSequence *argsSequence = new TIntermSequence();
421                     for (uint8_t i = 0; i < fromType.getNominalSize(); i++)
422                     {
423                         TIntermTyped &fromTypeSwizzle     = SubVector(fromNode, i, i + 1);
424                         TIntermAggregate *boolConstructor = TIntermAggregate::CreateConstructor(
425                             *new TType(toBasicType, 1, 1), new TIntermSequence{&fromTypeSwizzle});
426                         argsSequence->push_back(boolConstructor);
427                     }
428                     return *TIntermAggregate::CreateConstructor(
429                         *new TType(toBasicType, fromType.getNominalSize(),
430                                    fromType.getSecondarySize()),
431                         argsSequence);
432                 }
433 
434                 default:
435                     break;  // No explicit conversion needed
436             }
437         }
438 
439         return *TIntermAggregate::CreateConstructor(
440             *new TType(toBasicType, fromType.getNominalSize(), fromType.getSecondarySize()),
441             new TIntermSequence{&fromNode});
442     }
443     return fromNode;
444 }
445 
CoerceSimple(const TType & toType,TIntermTyped & fromNode,bool needsExplicitBoolCast)446 TIntermTyped &sh::CoerceSimple(const TType &toType,
447                                TIntermTyped &fromNode,
448                                bool needsExplicitBoolCast)
449 {
450     const TType &fromType = fromNode.getType();
451 
452     ASSERT(HasScalarBasicType(toType));
453     ASSERT(HasScalarBasicType(fromType));
454     ASSERT(toType.getNominalSize() == fromType.getNominalSize());
455     ASSERT(toType.getSecondarySize() == fromType.getSecondarySize());
456     ASSERT(!toType.isArray());
457     ASSERT(!fromType.isArray());
458 
459     const TBasicType toBasicType   = toType.getBasicType();
460     const TBasicType fromBasicType = fromType.getBasicType();
461 
462     if (toBasicType != fromBasicType)
463     {
464         if (toBasicType == TBasicType::EbtBool && fromNode.isVector() && needsExplicitBoolCast)
465         {
466             switch (fromBasicType)
467             {
468                 case TBasicType::EbtFloat:
469                 case TBasicType::EbtDouble:
470                 case TBasicType::EbtInt:
471                 case TBasicType::EbtUInt:
472                 {
473                     TIntermSequence *argsSequence = new TIntermSequence();
474                     for (uint8_t i = 0; i < fromType.getNominalSize(); i++)
475                     {
476                         TIntermTyped &fromTypeSwizzle     = SubVector(fromNode, i, i + 1);
477                         TIntermAggregate *boolConstructor = TIntermAggregate::CreateConstructor(
478                             *new TType(toBasicType, 1, 1), new TIntermSequence{&fromTypeSwizzle});
479                         argsSequence->push_back(boolConstructor);
480                     }
481                     return *TIntermAggregate::CreateConstructor(
482                         *new TType(toBasicType, fromType.getNominalSize(),
483                                    fromType.getSecondarySize()),
484                         new TIntermSequence{*argsSequence});
485                 }
486 
487                 default:
488                     break;  // No explicit conversion needed
489             }
490         }
491 
492         return *TIntermAggregate::CreateConstructor(toType, new TIntermSequence{&fromNode});
493     }
494     return fromNode;
495 }
496 
AsType(SymbolEnv & symbolEnv,const TType & toType,TIntermTyped & fromNode)497 TIntermTyped &sh::AsType(SymbolEnv &symbolEnv, const TType &toType, TIntermTyped &fromNode)
498 {
499     const TType &fromType = fromNode.getType();
500 
501     ASSERT(HasScalarBasicType(toType));
502     ASSERT(HasScalarBasicType(fromType));
503     ASSERT(!toType.isArray());
504     ASSERT(!fromType.isArray());
505 
506     if (toType == fromType)
507     {
508         return fromNode;
509     }
510     TemplateArg targ(toType);
511     return symbolEnv.callFunctionOverload(Name("as_type", SymbolType::BuiltIn), toType,
512                                           *new TIntermSequence{&fromNode}, 1, &targ);
513 }
514