1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include <cstring>
8 #include <numeric>
9 #include <unordered_map>
10 #include <unordered_set>
11
12 #include "compiler/translator/msl/AstHelpers.h"
13
14 using namespace sh;
15
16 ////////////////////////////////////////////////////////////////////////////////
17
CreateStructTypeVariable(TSymbolTable & symbolTable,const TStructure & structure)18 const TVariable &sh::CreateStructTypeVariable(TSymbolTable &symbolTable,
19 const TStructure &structure)
20 {
21 TType *type = new TType(&structure, true);
22 TVariable *var = new TVariable(&symbolTable, ImmutableString(""), type, SymbolType::Empty);
23 return *var;
24 }
25
CreateInstanceVariable(TSymbolTable & symbolTable,const TStructure & structure,const Name & name,TQualifier qualifier,const TSpan<const unsigned int> * arraySizes)26 const TVariable &sh::CreateInstanceVariable(TSymbolTable &symbolTable,
27 const TStructure &structure,
28 const Name &name,
29 TQualifier qualifier,
30 const TSpan<const unsigned int> *arraySizes)
31 {
32 TType *type = new TType(&structure, false);
33 type->setQualifier(qualifier);
34 if (arraySizes)
35 {
36 type->makeArrays(*arraySizes);
37 }
38 TVariable *var = new TVariable(&symbolTable, name.rawName(), type, name.symbolType());
39 return *var;
40 }
41
AcquireFunctionExtras(TFunction & dest,const TFunction & src)42 static void AcquireFunctionExtras(TFunction &dest, const TFunction &src)
43 {
44 if (src.isDefined())
45 {
46 dest.setDefined();
47 }
48
49 if (src.hasPrototypeDeclaration())
50 {
51 dest.setHasPrototypeDeclaration();
52 }
53 }
54
// Returns a newly allocated sequence holding `node` followed by every node of `seq`, in order.
// Only the sequence is new: its node pointers are shared with `seq`, and no node is cloned.
TIntermSequence &sh::CloneSequenceAndPrepend(const TIntermSequence &seq, TIntermNode &node)
{
    TIntermSequence *newSeq = new TIntermSequence();
    // The final size is known up front, so allocate once instead of growing while appending.
    newSeq->reserve(seq.size() + 1);
    newSeq->push_back(&node);
    newSeq->insert(newSeq->end(), seq.begin(), seq.end());
    return *newSeq;
}
67
AddParametersFrom(TFunction & dest,const TFunction & src)68 void sh::AddParametersFrom(TFunction &dest, const TFunction &src)
69 {
70 const size_t paramCount = src.getParamCount();
71 for (size_t i = 0; i < paramCount; ++i)
72 {
73 const TVariable *var = src.getParam(i);
74 dest.addParameter(var);
75 }
76 }
77
CloneFunction(TSymbolTable & symbolTable,IdGen & idGen,const TFunction & oldFunc)78 const TFunction &sh::CloneFunction(TSymbolTable &symbolTable,
79 IdGen &idGen,
80 const TFunction &oldFunc)
81 {
82 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined);
83
84 Name newName = idGen.createNewName(Name(oldFunc));
85
86 TFunction &newFunc =
87 *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
88 &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
89
90 AcquireFunctionExtras(newFunc, oldFunc);
91 AddParametersFrom(newFunc, oldFunc);
92
93 return newFunc;
94 }
95
CloneFunctionAndPrependParam(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TVariable & newParam)96 const TFunction &sh::CloneFunctionAndPrependParam(TSymbolTable &symbolTable,
97 IdGen *idGen,
98 const TFunction &oldFunc,
99 const TVariable &newParam)
100 {
101 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
102 oldFunc.symbolType() == SymbolType::AngleInternal);
103
104 Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
105
106 TFunction &newFunc =
107 *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
108 &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
109
110 AcquireFunctionExtras(newFunc, oldFunc);
111 newFunc.addParameter(&newParam);
112 AddParametersFrom(newFunc, oldFunc);
113
114 return newFunc;
115 }
116
CloneFunctionAndPrependTwoParams(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TVariable & newParam1,const TVariable & newParam2)117 const TFunction &sh::CloneFunctionAndPrependTwoParams(TSymbolTable &symbolTable,
118 IdGen *idGen,
119 const TFunction &oldFunc,
120 const TVariable &newParam1,
121 const TVariable &newParam2)
122 {
123 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
124 oldFunc.symbolType() == SymbolType::AngleInternal);
125
126 Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
127
128 TFunction &newFunc =
129 *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
130 &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
131
132 AcquireFunctionExtras(newFunc, oldFunc);
133 newFunc.addParameter(&newParam1);
134 newFunc.addParameter(&newParam2);
135 AddParametersFrom(newFunc, oldFunc);
136
137 return newFunc;
138 }
139
CloneFunctionAndAppendParams(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const std::vector<const TVariable * > & newParams)140 const TFunction &sh::CloneFunctionAndAppendParams(TSymbolTable &symbolTable,
141 IdGen *idGen,
142 const TFunction &oldFunc,
143 const std::vector<const TVariable *> &newParams)
144 {
145 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
146 oldFunc.symbolType() == SymbolType::AngleInternal);
147
148 Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
149
150 TFunction &newFunc =
151 *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
152 &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
153
154 AcquireFunctionExtras(newFunc, oldFunc);
155 AddParametersFrom(newFunc, oldFunc);
156 for (const TVariable *param : newParams)
157 {
158 newFunc.addParameter(param);
159 }
160
161 return newFunc;
162 }
163
CloneFunctionAndChangeReturnType(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TStructure & newReturn)164 const TFunction &sh::CloneFunctionAndChangeReturnType(TSymbolTable &symbolTable,
165 IdGen *idGen,
166 const TFunction &oldFunc,
167 const TStructure &newReturn)
168 {
169 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined);
170
171 Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
172
173 TType *newReturnType = new TType(&newReturn, true);
174 TFunction &newFunc = *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
175 newReturnType, oldFunc.isKnownToNotHaveSideEffects());
176
177 AcquireFunctionExtras(newFunc, oldFunc);
178 AddParametersFrom(newFunc, oldFunc);
179
180 return newFunc;
181 }
182
GetArg(const TIntermAggregate & call,size_t index)183 TIntermTyped &sh::GetArg(const TIntermAggregate &call, size_t index)
184 {
185 ASSERT(index < call.getChildCount());
186 TIntermNode *arg = call.getChildNode(index);
187 ASSERT(arg);
188 TIntermTyped *targ = arg->getAsTyped();
189 ASSERT(targ);
190 return *targ;
191 }
192
SetArg(TIntermAggregate & call,size_t index,TIntermTyped & arg)193 void sh::SetArg(TIntermAggregate &call, size_t index, TIntermTyped &arg)
194 {
195 ASSERT(index < call.getChildCount());
196 (*call.getSequence())[index] = &arg;
197 }
198
AccessField(const TVariable & structInstanceVar,const Name & name)199 TIntermBinary &sh::AccessField(const TVariable &structInstanceVar, const Name &name)
200 {
201 return AccessField(*new TIntermSymbol(&structInstanceVar), name);
202 }
203
AccessField(TIntermTyped & object,const Name & name)204 TIntermBinary &sh::AccessField(TIntermTyped &object, const Name &name)
205 {
206 const TStructure *structure = object.getType().getStruct();
207 ASSERT(structure);
208 const TFieldList &fieldList = structure->fields();
209 for (int i = 0; i < static_cast<int>(fieldList.size()); ++i)
210 {
211 TField *current = fieldList[i];
212 if (Name(*current) == name)
213 {
214 return AccessFieldByIndex(object, i);
215 }
216 }
217 UNREACHABLE();
218 return AccessFieldByIndex(object, -1);
219 }
220
AccessFieldByIndex(TIntermTyped & object,int index)221 TIntermBinary &sh::AccessFieldByIndex(TIntermTyped &object, int index)
222 {
223 #if defined(ANGLE_ENABLE_ASSERTS)
224 const TType &type = object.getType();
225 ASSERT(!type.isArray());
226 const TStructure *structure = type.getStruct();
227 ASSERT(structure);
228 ASSERT(0 <= index);
229 ASSERT(static_cast<size_t>(index) < structure->fields().size());
230 #endif
231
232 return *new TIntermBinary(
233 TOperator::EOpIndexDirectStruct, &object,
234 new TIntermConstantUnion(new TConstantUnion(index), *new TType(TBasicType::EbtInt)));
235 }
236
AccessIndex(TIntermTyped & indexableNode,int index)237 TIntermBinary &sh::AccessIndex(TIntermTyped &indexableNode, int index)
238 {
239 #if defined(ANGLE_ENABLE_ASSERTS)
240 const TType &type = indexableNode.getType();
241 ASSERT(type.isArray() || type.isVector() || type.isMatrix());
242 #endif
243
244 TIntermBinary *accessNode = new TIntermBinary(
245 TOperator::EOpIndexDirect, &indexableNode,
246 new TIntermConstantUnion(new TConstantUnion(index), *new TType(TBasicType::EbtInt)));
247 return *accessNode;
248 }
249
AccessIndex(TIntermTyped & node,const int * index)250 TIntermTyped &sh::AccessIndex(TIntermTyped &node, const int *index)
251 {
252 if (index)
253 {
254 return AccessIndex(node, *index);
255 }
256 return node;
257 }
258
SubVector(TIntermTyped & vectorNode,int begin,int end)259 TIntermTyped &sh::SubVector(TIntermTyped &vectorNode, int begin, int end)
260 {
261 ASSERT(vectorNode.getType().isVector());
262 ASSERT(0 <= begin);
263 ASSERT(end <= 4);
264 ASSERT(begin <= end);
265 if (begin == 0 && end == vectorNode.getType().getNominalSize())
266 {
267 return vectorNode;
268 }
269 TVector<int> offsets(static_cast<size_t>(end - begin));
270 std::iota(offsets.begin(), offsets.end(), begin);
271 TIntermSwizzle *swizzle = new TIntermSwizzle(vectorNode.deepCopy(), offsets);
272 return *swizzle;
273 }
274
IsScalarBasicType(const TType & type)275 bool sh::IsScalarBasicType(const TType &type)
276 {
277 if (!type.isScalar())
278 {
279 return false;
280 }
281 return HasScalarBasicType(type);
282 }
283
IsVectorBasicType(const TType & type)284 bool sh::IsVectorBasicType(const TType &type)
285 {
286 if (!type.isVector())
287 {
288 return false;
289 }
290 return HasScalarBasicType(type);
291 }
292
HasScalarBasicType(TBasicType type)293 bool sh::HasScalarBasicType(TBasicType type)
294 {
295 switch (type)
296 {
297 case TBasicType::EbtFloat:
298 case TBasicType::EbtDouble:
299 case TBasicType::EbtInt:
300 case TBasicType::EbtUInt:
301 case TBasicType::EbtBool:
302 return true;
303
304 default:
305 return false;
306 }
307 }
308
HasScalarBasicType(const TType & type)309 bool sh::HasScalarBasicType(const TType &type)
310 {
311 return HasScalarBasicType(type.getBasicType());
312 }
313
// Returns a newly allocated copy of `type`.
TType &sh::CloneType(const TType &type)
{
    return *new TType(type);
}
319
// Returns a newly allocated copy of `type` converted with toArrayBaseType(),
// i.e. with its array dimensions removed.
TType &sh::InnermostType(const TType &type)
{
    TType *baseType = new TType(type);
    baseType->toArrayBaseType();
    return *baseType;
}
326
// Returns a newly allocated copy of matrix type `matrixType` converted with
// toMatrixColumnType(). `matrixType` must be a matrix with a scalar basic type.
TType &sh::DropColumns(const TType &matrixType)
{
    ASSERT(matrixType.isMatrix());
    ASSERT(HasScalarBasicType(matrixType));

    TType *columnType = new TType(matrixType);
    columnType->toMatrixColumnType();
    return *columnType;
}
336
// Returns a newly allocated copy of array type `arrayType` converted with
// toArrayElementType(), i.e. one array dimension removed.
TType &sh::DropOuterDimension(const TType &arrayType)
{
    ASSERT(arrayType.isArray());

    TType *elementType = new TType(arrayType);
    elementType->toArrayElementType();
    return *elementType;
}
345
// Returns a newly allocated copy of `type` with primary size `primary` (asserted in 2..4) and
// secondary size `secondary` (asserted in 1..4). `type` must have a scalar basic type.
static TType &SetTypeDimsImpl(const TType &type, int primary, int secondary)
{
    ASSERT(1 < primary && primary <= 4);
    ASSERT(1 <= secondary && secondary <= 4);
    ASSERT(HasScalarBasicType(type));

    TType *resized = new TType(type);
    resized->setPrimarySize(primary);
    resized->setSecondarySize(secondary);
    return *resized;
}
357
// Returns a copy of rank-0 or vector type `type` resized to a `newDim`-component vector
// (secondary size 1). SetTypeDimsImpl asserts 2 <= newDim <= 4.
TType &sh::SetVectorDim(const TType &type, int newDim)
{
    ASSERT(type.isRank0() || type.isVector());
    constexpr int kVectorSecondarySize = 1;
    return SetTypeDimsImpl(type, newDim, kVectorSecondarySize);
}
363
// Returns a copy of matrix type `matrixType` keeping its column count (passed as the primary
// size) and setting the secondary size, the row dimension, to `newDim` (asserted in 2..4).
TType &sh::SetMatrixRowDim(const TType &matrixType, int newDim)
{
    ASSERT(matrixType.isMatrix());
    ASSERT(1 < newDim && newDim <= 4);
    const int columnCount = matrixType.getCols();
    return SetTypeDimsImpl(matrixType, columnCount, newDim);
}
370
HasMatrixField(const TStructure & structure)371 bool sh::HasMatrixField(const TStructure &structure)
372 {
373 for (const TField *field : structure.fields())
374 {
375 const TType &type = *field->type();
376 if (type.isMatrix())
377 {
378 return true;
379 }
380 }
381 return false;
382 }
383
HasArrayField(const TStructure & structure)384 bool sh::HasArrayField(const TStructure &structure)
385 {
386 for (const TField *field : structure.fields())
387 {
388 const TType &type = *field->type();
389 if (type.isArray())
390 {
391 return true;
392 }
393 }
394 return false;
395 }
396
CoerceSimple(TBasicType toBasicType,TIntermTyped & fromNode,bool needsExplicitBoolCast)397 TIntermTyped &sh::CoerceSimple(TBasicType toBasicType,
398 TIntermTyped &fromNode,
399 bool needsExplicitBoolCast)
400 {
401 const TType &fromType = fromNode.getType();
402
403 ASSERT(HasScalarBasicType(toBasicType));
404 ASSERT(HasScalarBasicType(fromType));
405 ASSERT(!fromType.isArray());
406
407 const TBasicType fromBasicType = fromType.getBasicType();
408
409 if (toBasicType != fromBasicType)
410 {
411 if (toBasicType == TBasicType::EbtBool && fromNode.isVector() && needsExplicitBoolCast)
412 {
413 switch (fromBasicType)
414 {
415 case TBasicType::EbtFloat:
416 case TBasicType::EbtDouble:
417 case TBasicType::EbtInt:
418 case TBasicType::EbtUInt:
419 {
420 TIntermSequence *argsSequence = new TIntermSequence();
421 for (uint8_t i = 0; i < fromType.getNominalSize(); i++)
422 {
423 TIntermTyped &fromTypeSwizzle = SubVector(fromNode, i, i + 1);
424 TIntermAggregate *boolConstructor = TIntermAggregate::CreateConstructor(
425 *new TType(toBasicType, 1, 1), new TIntermSequence{&fromTypeSwizzle});
426 argsSequence->push_back(boolConstructor);
427 }
428 return *TIntermAggregate::CreateConstructor(
429 *new TType(toBasicType, fromType.getNominalSize(),
430 fromType.getSecondarySize()),
431 argsSequence);
432 }
433
434 default:
435 break; // No explicit conversion needed
436 }
437 }
438
439 return *TIntermAggregate::CreateConstructor(
440 *new TType(toBasicType, fromType.getNominalSize(), fromType.getSecondarySize()),
441 new TIntermSequence{&fromNode});
442 }
443 return fromNode;
444 }
445
CoerceSimple(const TType & toType,TIntermTyped & fromNode,bool needsExplicitBoolCast)446 TIntermTyped &sh::CoerceSimple(const TType &toType,
447 TIntermTyped &fromNode,
448 bool needsExplicitBoolCast)
449 {
450 const TType &fromType = fromNode.getType();
451
452 ASSERT(HasScalarBasicType(toType));
453 ASSERT(HasScalarBasicType(fromType));
454 ASSERT(toType.getNominalSize() == fromType.getNominalSize());
455 ASSERT(toType.getSecondarySize() == fromType.getSecondarySize());
456 ASSERT(!toType.isArray());
457 ASSERT(!fromType.isArray());
458
459 const TBasicType toBasicType = toType.getBasicType();
460 const TBasicType fromBasicType = fromType.getBasicType();
461
462 if (toBasicType != fromBasicType)
463 {
464 if (toBasicType == TBasicType::EbtBool && fromNode.isVector() && needsExplicitBoolCast)
465 {
466 switch (fromBasicType)
467 {
468 case TBasicType::EbtFloat:
469 case TBasicType::EbtDouble:
470 case TBasicType::EbtInt:
471 case TBasicType::EbtUInt:
472 {
473 TIntermSequence *argsSequence = new TIntermSequence();
474 for (uint8_t i = 0; i < fromType.getNominalSize(); i++)
475 {
476 TIntermTyped &fromTypeSwizzle = SubVector(fromNode, i, i + 1);
477 TIntermAggregate *boolConstructor = TIntermAggregate::CreateConstructor(
478 *new TType(toBasicType, 1, 1), new TIntermSequence{&fromTypeSwizzle});
479 argsSequence->push_back(boolConstructor);
480 }
481 return *TIntermAggregate::CreateConstructor(
482 *new TType(toBasicType, fromType.getNominalSize(),
483 fromType.getSecondarySize()),
484 new TIntermSequence{*argsSequence});
485 }
486
487 default:
488 break; // No explicit conversion needed
489 }
490 }
491
492 return *TIntermAggregate::CreateConstructor(toType, new TIntermSequence{&fromNode});
493 }
494 return fromNode;
495 }
496
AsType(SymbolEnv & symbolEnv,const TType & toType,TIntermTyped & fromNode)497 TIntermTyped &sh::AsType(SymbolEnv &symbolEnv, const TType &toType, TIntermTyped &fromNode)
498 {
499 const TType &fromType = fromNode.getType();
500
501 ASSERT(HasScalarBasicType(toType));
502 ASSERT(HasScalarBasicType(fromType));
503 ASSERT(!toType.isArray());
504 ASSERT(!fromType.isArray());
505
506 if (toType == fromType)
507 {
508 return fromNode;
509 }
510 TemplateArg targ(toType);
511 return symbolEnv.callFunctionOverload(Name("as_type", SymbolType::BuiltIn), toType,
512 *new TIntermSequence{&fromNode}, 1, &targ);
513 }
514