1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include <cstring>
8 #include <numeric>
9 #include <unordered_map>
10 #include <unordered_set>
11
12 #include "compiler/translator/msl/AstHelpers.h"
13
14 using namespace sh;
15
16 ////////////////////////////////////////////////////////////////////////////////
17
ViewDeclaration(TIntermDeclaration & declNode)18 Declaration sh::ViewDeclaration(TIntermDeclaration &declNode)
19 {
20 ASSERT(declNode.getChildCount() == 1);
21 TIntermNode *childNode = declNode.getChildNode(0);
22 ASSERT(childNode);
23 TIntermSymbol *symbolNode;
24 if ((symbolNode = childNode->getAsSymbolNode()))
25 {
26 return {*symbolNode, nullptr};
27 }
28 else
29 {
30 TIntermBinary *initNode = childNode->getAsBinaryNode();
31 ASSERT(initNode);
32 ASSERT(initNode->getOp() == TOperator::EOpInitialize);
33 symbolNode = initNode->getLeft()->getAsSymbolNode();
34 ASSERT(symbolNode);
35 return {*symbolNode, initNode->getRight()};
36 }
37 }
38
CreateStructTypeVariable(TSymbolTable & symbolTable,const TStructure & structure)39 const TVariable &sh::CreateStructTypeVariable(TSymbolTable &symbolTable,
40 const TStructure &structure)
41 {
42 TType *type = new TType(&structure, true);
43 TVariable *var = new TVariable(&symbolTable, ImmutableString(""), type, SymbolType::Empty);
44 return *var;
45 }
46
CreateInstanceVariable(TSymbolTable & symbolTable,const TStructure & structure,const Name & name,TQualifier qualifier,const TSpan<const unsigned int> * arraySizes)47 const TVariable &sh::CreateInstanceVariable(TSymbolTable &symbolTable,
48 const TStructure &structure,
49 const Name &name,
50 TQualifier qualifier,
51 const TSpan<const unsigned int> *arraySizes)
52 {
53 TType *type = new TType(&structure, false);
54 type->setQualifier(qualifier);
55 if (arraySizes)
56 {
57 type->makeArrays(*arraySizes);
58 }
59 TVariable *var = new TVariable(&symbolTable, name.rawName(), type, name.symbolType());
60 return *var;
61 }
62
AcquireFunctionExtras(TFunction & dest,const TFunction & src)63 static void AcquireFunctionExtras(TFunction &dest, const TFunction &src)
64 {
65 if (src.isDefined())
66 {
67 dest.setDefined();
68 }
69
70 if (src.hasPrototypeDeclaration())
71 {
72 dest.setHasPrototypeDeclaration();
73 }
74 }
75
// Returns a newly allocated sequence containing `node` followed by every element of
// `seq`. This is a shallow copy: the node pointers are shared, not deep-copied.
TIntermSequence &sh::CloneSequenceAndPrepend(const TIntermSequence &seq, TIntermNode &node)
{
    auto *result = new TIntermSequence();
    result->push_back(&node);
    result->insert(result->end(), seq.begin(), seq.end());
    return *result;
}
88
AddParametersFrom(TFunction & dest,const TFunction & src)89 void sh::AddParametersFrom(TFunction &dest, const TFunction &src)
90 {
91 const size_t paramCount = src.getParamCount();
92 for (size_t i = 0; i < paramCount; ++i)
93 {
94 const TVariable *var = src.getParam(i);
95 dest.addParameter(var);
96 }
97 }
98
CloneFunction(TSymbolTable & symbolTable,IdGen & idGen,const TFunction & oldFunc)99 const TFunction &sh::CloneFunction(TSymbolTable &symbolTable,
100 IdGen &idGen,
101 const TFunction &oldFunc)
102 {
103 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined);
104
105 Name newName = idGen.createNewName(Name(oldFunc));
106
107 TFunction &newFunc =
108 *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
109 &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
110
111 AcquireFunctionExtras(newFunc, oldFunc);
112 AddParametersFrom(newFunc, oldFunc);
113
114 return newFunc;
115 }
116
CloneFunctionAndPrependParam(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TVariable & newParam)117 const TFunction &sh::CloneFunctionAndPrependParam(TSymbolTable &symbolTable,
118 IdGen *idGen,
119 const TFunction &oldFunc,
120 const TVariable &newParam)
121 {
122 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
123 oldFunc.symbolType() == SymbolType::AngleInternal);
124
125 Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
126
127 TFunction &newFunc =
128 *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
129 &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
130
131 AcquireFunctionExtras(newFunc, oldFunc);
132 newFunc.addParameter(&newParam);
133 AddParametersFrom(newFunc, oldFunc);
134
135 return newFunc;
136 }
137
CloneFunctionAndPrependTwoParams(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TVariable & newParam1,const TVariable & newParam2)138 const TFunction &sh::CloneFunctionAndPrependTwoParams(TSymbolTable &symbolTable,
139 IdGen *idGen,
140 const TFunction &oldFunc,
141 const TVariable &newParam1,
142 const TVariable &newParam2)
143 {
144 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
145 oldFunc.symbolType() == SymbolType::AngleInternal);
146
147 Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
148
149 TFunction &newFunc =
150 *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
151 &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
152
153 AcquireFunctionExtras(newFunc, oldFunc);
154 newFunc.addParameter(&newParam1);
155 newFunc.addParameter(&newParam2);
156 AddParametersFrom(newFunc, oldFunc);
157
158 return newFunc;
159 }
160
CloneFunctionAndAppendParams(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const std::vector<const TVariable * > & newParams)161 const TFunction &sh::CloneFunctionAndAppendParams(TSymbolTable &symbolTable,
162 IdGen *idGen,
163 const TFunction &oldFunc,
164 const std::vector<const TVariable *> &newParams)
165 {
166 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
167 oldFunc.symbolType() == SymbolType::AngleInternal);
168
169 Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
170
171 TFunction &newFunc =
172 *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
173 &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
174
175 AcquireFunctionExtras(newFunc, oldFunc);
176 AddParametersFrom(newFunc, oldFunc);
177 for (const TVariable *param : newParams)
178 {
179 newFunc.addParameter(param);
180 }
181
182 return newFunc;
183 }
184
CloneFunctionAndChangeReturnType(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TStructure & newReturn)185 const TFunction &sh::CloneFunctionAndChangeReturnType(TSymbolTable &symbolTable,
186 IdGen *idGen,
187 const TFunction &oldFunc,
188 const TStructure &newReturn)
189 {
190 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined);
191
192 Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
193
194 TType *newReturnType = new TType(&newReturn, true);
195 TFunction &newFunc = *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
196 newReturnType, oldFunc.isKnownToNotHaveSideEffects());
197
198 AcquireFunctionExtras(newFunc, oldFunc);
199 AddParametersFrom(newFunc, oldFunc);
200
201 return newFunc;
202 }
203
GetArg(const TIntermAggregate & call,size_t index)204 TIntermTyped &sh::GetArg(const TIntermAggregate &call, size_t index)
205 {
206 ASSERT(index < call.getChildCount());
207 TIntermNode *arg = call.getChildNode(index);
208 ASSERT(arg);
209 TIntermTyped *targ = arg->getAsTyped();
210 ASSERT(targ);
211 return *targ;
212 }
213
SetArg(TIntermAggregate & call,size_t index,TIntermTyped & arg)214 void sh::SetArg(TIntermAggregate &call, size_t index, TIntermTyped &arg)
215 {
216 ASSERT(index < call.getChildCount());
217 (*call.getSequence())[index] = &arg;
218 }
219
GetFieldIndex(const TStructure & structure,const ImmutableString & fieldName)220 int sh::GetFieldIndex(const TStructure &structure, const ImmutableString &fieldName)
221 {
222 const TFieldList &fieldList = structure.fields();
223
224 int i = 0;
225 for (TField *field : fieldList)
226 {
227 if (field->name() == fieldName)
228 {
229 return i;
230 }
231 ++i;
232 }
233
234 return -1;
235 }
236
AccessField(const TVariable & structInstanceVar,const ImmutableString & fieldName)237 TIntermBinary &sh::AccessField(const TVariable &structInstanceVar, const ImmutableString &fieldName)
238 {
239 return AccessField(*new TIntermSymbol(&structInstanceVar), fieldName);
240 }
241
AccessField(TIntermTyped & object,const ImmutableString & fieldName)242 TIntermBinary &sh::AccessField(TIntermTyped &object, const ImmutableString &fieldName)
243 {
244 const TStructure *structure = object.getType().getStruct();
245 ASSERT(structure);
246
247 const int index = GetFieldIndex(*structure, fieldName);
248 ASSERT(index >= 0);
249 return AccessFieldByIndex(object, index);
250 }
251
AccessFieldByIndex(TIntermTyped & object,int index)252 TIntermBinary &sh::AccessFieldByIndex(TIntermTyped &object, int index)
253 {
254 #if defined(ANGLE_ENABLE_ASSERTS)
255 const TType &type = object.getType();
256 ASSERT(!type.isArray());
257 const TStructure *structure = type.getStruct();
258 ASSERT(structure);
259 ASSERT(0 <= index);
260 ASSERT(static_cast<size_t>(index) < structure->fields().size());
261 #endif
262
263 return *new TIntermBinary(
264 TOperator::EOpIndexDirectStruct, &object,
265 new TIntermConstantUnion(new TConstantUnion(index), *new TType(TBasicType::EbtInt)));
266 }
267
AccessIndex(TIntermTyped & indexableNode,int index)268 TIntermBinary &sh::AccessIndex(TIntermTyped &indexableNode, int index)
269 {
270 #if defined(ANGLE_ENABLE_ASSERTS)
271 const TType &type = indexableNode.getType();
272 ASSERT(type.isArray() || type.isVector() || type.isMatrix());
273 #endif
274
275 TIntermBinary *accessNode = new TIntermBinary(
276 TOperator::EOpIndexDirect, &indexableNode,
277 new TIntermConstantUnion(new TConstantUnion(index), *new TType(TBasicType::EbtInt)));
278 return *accessNode;
279 }
280
AccessIndex(TIntermTyped & node,const int * index)281 TIntermTyped &sh::AccessIndex(TIntermTyped &node, const int *index)
282 {
283 if (index)
284 {
285 return AccessIndex(node, *index);
286 }
287 return node;
288 }
289
SubVector(TIntermTyped & vectorNode,int begin,int end)290 TIntermTyped &sh::SubVector(TIntermTyped &vectorNode, int begin, int end)
291 {
292 ASSERT(vectorNode.getType().isVector());
293 ASSERT(0 <= begin);
294 ASSERT(end <= 4);
295 ASSERT(begin <= end);
296 if (begin == 0 && end == vectorNode.getType().getNominalSize())
297 {
298 return vectorNode;
299 }
300 TVector<int> offsets(static_cast<size_t>(end - begin));
301 std::iota(offsets.begin(), offsets.end(), begin);
302 TIntermSwizzle *swizzle = new TIntermSwizzle(vectorNode.deepCopy(), offsets);
303 return *swizzle;
304 }
305
IsScalarBasicType(const TType & type)306 bool sh::IsScalarBasicType(const TType &type)
307 {
308 if (!type.isScalar())
309 {
310 return false;
311 }
312 return HasScalarBasicType(type);
313 }
314
IsVectorBasicType(const TType & type)315 bool sh::IsVectorBasicType(const TType &type)
316 {
317 if (!type.isVector())
318 {
319 return false;
320 }
321 return HasScalarBasicType(type);
322 }
323
HasScalarBasicType(TBasicType type)324 bool sh::HasScalarBasicType(TBasicType type)
325 {
326 switch (type)
327 {
328 case TBasicType::EbtFloat:
329 case TBasicType::EbtDouble:
330 case TBasicType::EbtInt:
331 case TBasicType::EbtUInt:
332 case TBasicType::EbtBool:
333 return true;
334
335 default:
336 return false;
337 }
338 }
339
HasScalarBasicType(const TType & type)340 bool sh::HasScalarBasicType(const TType &type)
341 {
342 return HasScalarBasicType(type.getBasicType());
343 }
344
// Returns a newly `new`-allocated copy of `type`.
TType &sh::CloneType(const TType &type)
{
    return *new TType(type);
}
350
// Returns a new copy of `type` with toArrayBaseType() applied. Presumably this yields
// the element type with every array dimension removed; confirm against TType.
TType &sh::InnermostType(const TType &type)
{
    auto *innermost = new TType(type);
    innermost->toArrayBaseType();
    return *innermost;
}
357
// Returns a new copy of the matrix type `matrixType` converted with
// toMatrixColumnType(), i.e. the type of one of its columns. Asserts that the input is
// a matrix with a scalar basic type.
TType &sh::DropColumns(const TType &matrixType)
{
    ASSERT(matrixType.isMatrix());
    ASSERT(HasScalarBasicType(matrixType));

    auto *columnType = new TType(matrixType);
    columnType->toMatrixColumnType();
    return *columnType;
}
367
// Returns a new copy of the array type `arrayType` converted with
// toArrayElementType(), i.e. the type of one element of its outermost dimension.
TType &sh::DropOuterDimension(const TType &arrayType)
{
    ASSERT(arrayType.isArray());

    auto *elementType = new TType(arrayType);
    elementType->toArrayElementType();
    return *elementType;
}
376
// Returns a new copy of `type` with its primary and secondary sizes replaced.
// Asserted preconditions: primary is in (1, 4], secondary is in [1, 4], and the basic
// type is scalar numeric/boolean.
static TType &SetTypeDimsImpl(const TType &type, int primary, int secondary)
{
    ASSERT(1 < primary && primary <= 4);
    ASSERT(1 <= secondary && secondary <= 4);
    ASSERT(HasScalarBasicType(type));

    auto *resized = new TType(type);
    resized->setPrimarySize(primary);
    resized->setSecondarySize(secondary);
    return *resized;
}
388
// Returns a new vector type with `newDim` components and the same basic type as
// `type`, which must be a rank-0 type or a vector.
TType &sh::SetVectorDim(const TType &type, int newDim)
{
    ASSERT(type.isRank0() || type.isVector());
    const int secondaryForVector = 1;
    return SetTypeDimsImpl(type, newDim, secondaryForVector);
}
394
// Returns a new matrix type with the same column count as `matrixType` and `newDim`
// rows. `newDim` must be in (1, 4].
TType &sh::SetMatrixRowDim(const TType &matrixType, int newDim)
{
    ASSERT(matrixType.isMatrix());
    ASSERT(1 < newDim && newDim <= 4);
    const int columnCount = matrixType.getCols();
    return SetTypeDimsImpl(matrixType, columnCount, newDim);
}
401
HasMatrixField(const TStructure & structure)402 bool sh::HasMatrixField(const TStructure &structure)
403 {
404 for (const TField *field : structure.fields())
405 {
406 const TType &type = *field->type();
407 if (type.isMatrix())
408 {
409 return true;
410 }
411 }
412 return false;
413 }
414
HasArrayField(const TStructure & structure)415 bool sh::HasArrayField(const TStructure &structure)
416 {
417 for (const TField *field : structure.fields())
418 {
419 const TType &type = *field->type();
420 if (type.isArray())
421 {
422 return true;
423 }
424 }
425 return false;
426 }
427
CoerceSimple(TBasicType toBasicType,TIntermTyped & fromNode,bool needsExplicitBoolCast)428 TIntermTyped &sh::CoerceSimple(TBasicType toBasicType,
429 TIntermTyped &fromNode,
430 bool needsExplicitBoolCast)
431 {
432 const TType &fromType = fromNode.getType();
433
434 ASSERT(HasScalarBasicType(toBasicType));
435 ASSERT(HasScalarBasicType(fromType));
436 ASSERT(!fromType.isArray());
437
438 const TBasicType fromBasicType = fromType.getBasicType();
439
440 if (toBasicType != fromBasicType)
441 {
442 if (toBasicType == TBasicType::EbtBool && fromNode.isVector() && needsExplicitBoolCast)
443 {
444 switch (fromBasicType)
445 {
446 case TBasicType::EbtFloat:
447 case TBasicType::EbtDouble:
448 case TBasicType::EbtInt:
449 case TBasicType::EbtUInt:
450 {
451 TIntermSequence *argsSequence = new TIntermSequence();
452 for (uint8_t i = 0; i < fromType.getNominalSize(); i++)
453 {
454 TIntermTyped &fromTypeSwizzle = SubVector(fromNode, i, i + 1);
455 TIntermAggregate *boolConstructor = TIntermAggregate::CreateConstructor(
456 *new TType(toBasicType, 1, 1), new TIntermSequence{&fromTypeSwizzle});
457 argsSequence->push_back(boolConstructor);
458 }
459 return *TIntermAggregate::CreateConstructor(
460 *new TType(toBasicType, fromType.getNominalSize(),
461 fromType.getSecondarySize()),
462 argsSequence);
463 }
464
465 default:
466 break; // No explicit conversion needed
467 }
468 }
469
470 return *TIntermAggregate::CreateConstructor(
471 *new TType(toBasicType, fromType.getNominalSize(), fromType.getSecondarySize()),
472 new TIntermSequence{&fromNode});
473 }
474 return fromNode;
475 }
476
CoerceSimple(const TType & toType,TIntermTyped & fromNode,bool needsExplicitBoolCast)477 TIntermTyped &sh::CoerceSimple(const TType &toType,
478 TIntermTyped &fromNode,
479 bool needsExplicitBoolCast)
480 {
481 const TType &fromType = fromNode.getType();
482
483 ASSERT(HasScalarBasicType(toType));
484 ASSERT(HasScalarBasicType(fromType));
485 ASSERT(toType.getNominalSize() == fromType.getNominalSize());
486 ASSERT(toType.getSecondarySize() == fromType.getSecondarySize());
487 ASSERT(!toType.isArray());
488 ASSERT(!fromType.isArray());
489
490 const TBasicType toBasicType = toType.getBasicType();
491 const TBasicType fromBasicType = fromType.getBasicType();
492
493 if (toBasicType != fromBasicType)
494 {
495 if (toBasicType == TBasicType::EbtBool && fromNode.isVector() && needsExplicitBoolCast)
496 {
497 switch (fromBasicType)
498 {
499 case TBasicType::EbtFloat:
500 case TBasicType::EbtDouble:
501 case TBasicType::EbtInt:
502 case TBasicType::EbtUInt:
503 {
504 TIntermSequence *argsSequence = new TIntermSequence();
505 for (uint8_t i = 0; i < fromType.getNominalSize(); i++)
506 {
507 TIntermTyped &fromTypeSwizzle = SubVector(fromNode, i, i + 1);
508 TIntermAggregate *boolConstructor = TIntermAggregate::CreateConstructor(
509 *new TType(toBasicType, 1, 1), new TIntermSequence{&fromTypeSwizzle});
510 argsSequence->push_back(boolConstructor);
511 }
512 return *TIntermAggregate::CreateConstructor(
513 *new TType(toBasicType, fromType.getNominalSize(),
514 fromType.getSecondarySize()),
515 new TIntermSequence{*argsSequence});
516 }
517
518 default:
519 break; // No explicit conversion needed
520 }
521 }
522
523 return *TIntermAggregate::CreateConstructor(toType, new TIntermSequence{&fromNode});
524 }
525 return fromNode;
526 }
527
AsType(SymbolEnv & symbolEnv,const TType & toType,TIntermTyped & fromNode)528 TIntermTyped &sh::AsType(SymbolEnv &symbolEnv, const TType &toType, TIntermTyped &fromNode)
529 {
530 const TType &fromType = fromNode.getType();
531
532 ASSERT(HasScalarBasicType(toType));
533 ASSERT(HasScalarBasicType(fromType));
534 ASSERT(!toType.isArray());
535 ASSERT(!fromType.isArray());
536
537 if (toType == fromType)
538 {
539 return fromNode;
540 }
541 TemplateArg targ(toType);
542 return symbolEnv.callFunctionOverload(Name("as_type", SymbolType::BuiltIn), toType,
543 *new TIntermSequence{&fromNode}, 1, &targ);
544 }
545