1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include "compiler/translator/tree_ops/InitializeVariables.h"
8
9 #include "angle_gl.h"
10 #include "common/debug.h"
11 #include "common/hash_containers.h"
12 #include "compiler/translator/Compiler.h"
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/FindMain.h"
16 #include "compiler/translator/tree_util/FindSymbolNode.h"
17 #include "compiler/translator/tree_util/IntermNode_util.h"
18 #include "compiler/translator/tree_util/IntermTraverse.h"
19 #include "compiler/translator/util.h"
20
21 namespace sh
22 {
23
24 namespace
25 {
26
27 void AddArrayZeroInitSequence(const TIntermTyped *initializedNode,
28 bool canUseLoopsToInitialize,
29 bool highPrecisionSupported,
30 TIntermSequence *initSequenceOut,
31 TSymbolTable *symbolTable);
32
33 void AddStructZeroInitSequence(const TIntermTyped *initializedNode,
34 bool canUseLoopsToInitialize,
35 bool highPrecisionSupported,
36 TIntermSequence *initSequenceOut,
37 TSymbolTable *symbolTable);
38
CreateZeroInitAssignment(const TIntermTyped * initializedNode)39 TIntermBinary *CreateZeroInitAssignment(const TIntermTyped *initializedNode)
40 {
41 TIntermTyped *zero = CreateZeroNode(initializedNode->getType());
42 return new TIntermBinary(EOpAssign, initializedNode->deepCopy(), zero);
43 }
44
AddZeroInitSequence(const TIntermTyped * initializedNode,bool canUseLoopsToInitialize,bool highPrecisionSupported,TIntermSequence * initSequenceOut,TSymbolTable * symbolTable)45 void AddZeroInitSequence(const TIntermTyped *initializedNode,
46 bool canUseLoopsToInitialize,
47 bool highPrecisionSupported,
48 TIntermSequence *initSequenceOut,
49 TSymbolTable *symbolTable)
50 {
51 if (initializedNode->isArray())
52 {
53 AddArrayZeroInitSequence(initializedNode, canUseLoopsToInitialize, highPrecisionSupported,
54 initSequenceOut, symbolTable);
55 }
56 else if (initializedNode->getType().isStructureContainingArrays() ||
57 initializedNode->getType().isNamelessStruct())
58 {
59 AddStructZeroInitSequence(initializedNode, canUseLoopsToInitialize, highPrecisionSupported,
60 initSequenceOut, symbolTable);
61 }
62 else if (initializedNode->getType().isInterfaceBlock())
63 {
64 const TType &type = initializedNode->getType();
65 const TInterfaceBlock &interfaceBlock = *type.getInterfaceBlock();
66 const TFieldList &fieldList = interfaceBlock.fields();
67 for (size_t fieldIndex = 0; fieldIndex < fieldList.size(); ++fieldIndex)
68 {
69 const TField &field = *fieldList[fieldIndex];
70 TIntermTyped *fieldIndexRef = CreateIndexNode(static_cast<int>(fieldIndex));
71 TIntermTyped *fieldReference =
72 new TIntermBinary(TOperator::EOpIndexDirectInterfaceBlock,
73 initializedNode->deepCopy(), fieldIndexRef);
74 TIntermTyped *fieldZero = CreateZeroNode(*field.type());
75 TIntermTyped *assignment =
76 new TIntermBinary(TOperator::EOpAssign, fieldReference, fieldZero);
77 initSequenceOut->push_back(assignment);
78 }
79 }
80 else
81 {
82 initSequenceOut->push_back(CreateZeroInitAssignment(initializedNode));
83 }
84 }
85
AddStructZeroInitSequence(const TIntermTyped * initializedNode,bool canUseLoopsToInitialize,bool highPrecisionSupported,TIntermSequence * initSequenceOut,TSymbolTable * symbolTable)86 void AddStructZeroInitSequence(const TIntermTyped *initializedNode,
87 bool canUseLoopsToInitialize,
88 bool highPrecisionSupported,
89 TIntermSequence *initSequenceOut,
90 TSymbolTable *symbolTable)
91 {
92 ASSERT(initializedNode->getBasicType() == EbtStruct);
93 const TStructure *structType = initializedNode->getType().getStruct();
94 for (int i = 0; i < static_cast<int>(structType->fields().size()); ++i)
95 {
96 TIntermBinary *element = new TIntermBinary(EOpIndexDirectStruct,
97 initializedNode->deepCopy(), CreateIndexNode(i));
98 // Structs can't be defined inside structs, so the type of a struct field can't be a
99 // nameless struct.
100 ASSERT(!element->getType().isNamelessStruct());
101 AddZeroInitSequence(element, canUseLoopsToInitialize, highPrecisionSupported,
102 initSequenceOut, symbolTable);
103 }
104 }
105
AddArrayZeroInitStatementList(const TIntermTyped * initializedNode,bool canUseLoopsToInitialize,bool highPrecisionSupported,TIntermSequence * initSequenceOut,TSymbolTable * symbolTable)106 void AddArrayZeroInitStatementList(const TIntermTyped *initializedNode,
107 bool canUseLoopsToInitialize,
108 bool highPrecisionSupported,
109 TIntermSequence *initSequenceOut,
110 TSymbolTable *symbolTable)
111 {
112 for (unsigned int i = 0; i < initializedNode->getOutermostArraySize(); ++i)
113 {
114 TIntermBinary *element =
115 new TIntermBinary(EOpIndexDirect, initializedNode->deepCopy(), CreateIndexNode(i));
116 AddZeroInitSequence(element, canUseLoopsToInitialize, highPrecisionSupported,
117 initSequenceOut, symbolTable);
118 }
119 }
120
AddArrayZeroInitForLoop(const TIntermTyped * initializedNode,bool highPrecisionSupported,TIntermSequence * initSequenceOut,TSymbolTable * symbolTable)121 void AddArrayZeroInitForLoop(const TIntermTyped *initializedNode,
122 bool highPrecisionSupported,
123 TIntermSequence *initSequenceOut,
124 TSymbolTable *symbolTable)
125 {
126 ASSERT(initializedNode->isArray());
127 const TType *mediumpIndexType = StaticType::Get<EbtInt, EbpMedium, EvqTemporary, 1, 1>();
128 const TType *highpIndexType = StaticType::Get<EbtInt, EbpHigh, EvqTemporary, 1, 1>();
129 TVariable *indexVariable =
130 CreateTempVariable(symbolTable, highPrecisionSupported ? highpIndexType : mediumpIndexType);
131
132 TIntermSymbol *indexSymbolNode = CreateTempSymbolNode(indexVariable);
133 TIntermDeclaration *indexInit =
134 CreateTempInitDeclarationNode(indexVariable, CreateZeroNode(indexVariable->getType()));
135 TIntermConstantUnion *arraySizeNode = CreateIndexNode(initializedNode->getOutermostArraySize());
136 TIntermBinary *indexSmallerThanSize =
137 new TIntermBinary(EOpLessThan, indexSymbolNode->deepCopy(), arraySizeNode);
138 TIntermUnary *indexIncrement =
139 new TIntermUnary(EOpPreIncrement, indexSymbolNode->deepCopy(), nullptr);
140
141 TIntermBlock *forLoopBody = new TIntermBlock();
142 TIntermSequence *forLoopBodySeq = forLoopBody->getSequence();
143
144 TIntermBinary *element = new TIntermBinary(EOpIndexIndirect, initializedNode->deepCopy(),
145 indexSymbolNode->deepCopy());
146 AddZeroInitSequence(element, true, highPrecisionSupported, forLoopBodySeq, symbolTable);
147
148 TIntermLoop *forLoop =
149 new TIntermLoop(ELoopFor, indexInit, indexSmallerThanSize, indexIncrement, forLoopBody);
150 initSequenceOut->push_back(forLoop);
151 }
152
AddArrayZeroInitSequence(const TIntermTyped * initializedNode,bool canUseLoopsToInitialize,bool highPrecisionSupported,TIntermSequence * initSequenceOut,TSymbolTable * symbolTable)153 void AddArrayZeroInitSequence(const TIntermTyped *initializedNode,
154 bool canUseLoopsToInitialize,
155 bool highPrecisionSupported,
156 TIntermSequence *initSequenceOut,
157 TSymbolTable *symbolTable)
158 {
159 // The array elements are assigned one by one to keep the AST compatible with ESSL 1.00 which
160 // doesn't have array assignment. We'll do this either with a for loop or just a list of
161 // statements assigning to each array index. Note that it is important to have the array init in
162 // the right order to workaround http://crbug.com/709317
163 bool isSmallArray = initializedNode->getOutermostArraySize() <= 1u ||
164 (initializedNode->getBasicType() != EbtStruct &&
165 !initializedNode->getType().isArrayOfArrays() &&
166 initializedNode->getOutermostArraySize() <= 3u);
167 if (initializedNode->getQualifier() == EvqFragData ||
168 initializedNode->getQualifier() == EvqFragmentOut || isSmallArray ||
169 !canUseLoopsToInitialize)
170 {
171 // Fragment outputs should not be indexed by non-constant indices.
172 // Also it doesn't make sense to use loops to initialize very small arrays.
173 AddArrayZeroInitStatementList(initializedNode, canUseLoopsToInitialize,
174 highPrecisionSupported, initSequenceOut, symbolTable);
175 }
176 else
177 {
178 AddArrayZeroInitForLoop(initializedNode, highPrecisionSupported, initSequenceOut,
179 symbolTable);
180 }
181 }
182
InsertInitCode(TCompiler * compiler,TIntermBlock * root,const InitVariableList & variables,TSymbolTable * symbolTable,int shaderVersion,const TExtensionBehavior & extensionBehavior,bool canUseLoopsToInitialize,bool highPrecisionSupported)183 void InsertInitCode(TCompiler *compiler,
184 TIntermBlock *root,
185 const InitVariableList &variables,
186 TSymbolTable *symbolTable,
187 int shaderVersion,
188 const TExtensionBehavior &extensionBehavior,
189 bool canUseLoopsToInitialize,
190 bool highPrecisionSupported)
191 {
192 TIntermSequence *mainBody = FindMainBody(root)->getSequence();
193 for (const TVariable *var : variables)
194 {
195 TIntermTyped *initializedSymbol = nullptr;
196
197 if (var->symbolType() == SymbolType::Empty)
198 {
199 // Must be a nameless interface block.
200 ASSERT(var->getType().getInterfaceBlock() != nullptr);
201 ASSERT(!var->getType().getInterfaceBlock()->name().empty());
202
203 const TInterfaceBlock *block = var->getType().getInterfaceBlock();
204 for (const TField *field : block->fields())
205 {
206 initializedSymbol = ReferenceGlobalVariable(field->name(), *symbolTable);
207
208 TIntermSequence initCode;
209 CreateInitCode(initializedSymbol, canUseLoopsToInitialize, highPrecisionSupported,
210 &initCode, symbolTable);
211 mainBody->insert(mainBody->begin(), initCode.begin(), initCode.end());
212 }
213
214 // All done with the interface block
215 continue;
216 }
217
218 const TQualifier qualifier = var->getType().getQualifier();
219
220 initializedSymbol = new TIntermSymbol(var);
221 if (qualifier == EvqFragData &&
222 !IsExtensionEnabled(extensionBehavior, TExtension::EXT_draw_buffers))
223 {
224 // If GL_EXT_draw_buffers is disabled, only the 0th index of gl_FragData can be
225 // written to.
226 initializedSymbol =
227 new TIntermBinary(EOpIndexDirect, initializedSymbol, CreateIndexNode(0));
228 }
229
230 TIntermSequence initCode;
231 CreateInitCode(initializedSymbol, canUseLoopsToInitialize, highPrecisionSupported,
232 &initCode, symbolTable);
233 mainBody->insert(mainBody->begin(), initCode.begin(), initCode.end());
234 }
235 }
236
CloneFunctionHeader(TSymbolTable * symbolTable,const TFunction * function)237 TFunction *CloneFunctionHeader(TSymbolTable *symbolTable, const TFunction *function)
238 {
239 TFunction *newFunction =
240 new TFunction(symbolTable, function->name(), function->symbolType(),
241 &function->getReturnType(), function->isKnownToNotHaveSideEffects());
242
243 if (function->isDefined())
244 {
245 newFunction->setDefined();
246 }
247 if (function->hasPrototypeDeclaration())
248 {
249 newFunction->setHasPrototypeDeclaration();
250 }
251 return newFunction;
252 }
253
254 class InitializeLocalsTraverser final : public TIntermTraverser
255 {
256 public:
InitializeLocalsTraverser(int shaderVersion,TSymbolTable * symbolTable,bool canUseLoopsToInitialize,bool highPrecisionSupported)257 InitializeLocalsTraverser(int shaderVersion,
258 TSymbolTable *symbolTable,
259 bool canUseLoopsToInitialize,
260 bool highPrecisionSupported)
261 : TIntermTraverser(true, false, false, symbolTable),
262 mShaderVersion(shaderVersion),
263 mCanUseLoopsToInitialize(canUseLoopsToInitialize),
264 mHighPrecisionSupported(highPrecisionSupported)
265 {}
266
collectUnnamedOutFunctions(TIntermBlock & root)267 void collectUnnamedOutFunctions(TIntermBlock &root)
268 {
269 TIntermSequence &sequence = *root.getSequence();
270 const size_t count = sequence.size();
271 for (size_t i = 0; i < count; ++i)
272 {
273 const TIntermFunctionDefinition *functionDefinition =
274 sequence[i]->getAsFunctionDefinition();
275 if (!functionDefinition)
276 {
277 continue;
278 }
279 const TFunction *function = functionDefinition->getFunction();
280 TFunction *newFunction = nullptr;
281 for (size_t p = 0; p < function->getParamCount(); ++p)
282 {
283 const TVariable *param = function->getParam(p);
284 const TType &type = param->getType();
285 if (param->symbolType() == SymbolType::Empty)
286 {
287 if (!newFunction)
288 {
289 newFunction = CloneFunctionHeader(mSymbolTable, function);
290 mFunctionsToReplace[function] = newFunction;
291 for (size_t z = 0; z < p; ++z)
292 {
293 newFunction->addParameter(function->getParam(z));
294 }
295 }
296 param = new TVariable(mSymbolTable, kEmptyImmutableString, &type,
297 SymbolType::AngleInternal, param->extensions());
298 }
299 if (newFunction)
300 {
301 newFunction->addParameter(param);
302 }
303 }
304 }
305 }
306
307 protected:
visitDeclaration(Visit visit,TIntermDeclaration * node)308 bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
309 {
310 for (TIntermNode *declarator : *node->getSequence())
311 {
312 if (!mInGlobalScope && !declarator->getAsBinaryNode())
313 {
314 TIntermSymbol *symbol = declarator->getAsSymbolNode();
315 ASSERT(symbol);
316 if (symbol->variable().symbolType() == SymbolType::Empty)
317 {
318 continue;
319 }
320
321 // Arrays may need to be initialized one element at a time, since ESSL 1.00 does not
322 // support array constructors or assigning arrays.
323 bool arrayConstructorUnavailable =
324 (symbol->isArray() || symbol->getType().isStructureContainingArrays()) &&
325 mShaderVersion == 100;
326 // Nameless struct constructors can't be referred to, so they also need to be
327 // initialized one element at a time.
328 // TODO(oetuaho): Check if it makes sense to initialize using a loop, even if we
329 // could use an initializer. It could at least reduce code size for very large
330 // arrays, but could hurt runtime performance.
331 if (arrayConstructorUnavailable || symbol->getType().isNamelessStruct())
332 {
333 // SimplifyLoopConditions should have been run so the parent node of this node
334 // should not be a loop.
335 ASSERT(getParentNode()->getAsLoopNode() == nullptr);
336 // SeparateDeclarations should have already been run, so we don't need to worry
337 // about further declarators in this declaration depending on the effects of
338 // this declarator.
339 ASSERT(node->getSequence()->size() == 1);
340 TIntermSequence initCode;
341 CreateInitCode(symbol, mCanUseLoopsToInitialize, mHighPrecisionSupported,
342 &initCode, mSymbolTable);
343 insertStatementsInParentBlock(TIntermSequence(), initCode);
344 }
345 else
346 {
347 TIntermBinary *init =
348 new TIntermBinary(EOpInitialize, symbol, CreateZeroNode(symbol->getType()));
349 queueReplacementWithParent(node, symbol, init, OriginalNode::BECOMES_CHILD);
350 }
351 }
352 }
353 // Must recurse in the cases which had initializers, because the initializiers might
354 // call the function that was rewritten.
355 return true;
356 }
357
visitFunctionPrototype(TIntermFunctionPrototype * node)358 void visitFunctionPrototype(TIntermFunctionPrototype *node) override
359 {
360 if (getParentNode()->getAsFunctionDefinition() != nullptr)
361 {
362 return;
363 }
364 auto it = mFunctionsToReplace.find(node->getFunction());
365 if (it != mFunctionsToReplace.end())
366 {
367 queueReplacement(new TIntermFunctionPrototype(it->second), OriginalNode::IS_DROPPED);
368 }
369 }
370
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)371 bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
372 {
373 // Initialize output function arguments as well, the parameter passed in at call time may be
374 // clobbered if the function doesn't fully write to the argument.
375
376 TIntermSequence initCode;
377
378 const TFunction *function = node->getFunction();
379 auto it = mFunctionsToReplace.find(function);
380 if (it != mFunctionsToReplace.end())
381 {
382 function = it->second;
383 TIntermFunctionPrototype *newPrototypeNode = new TIntermFunctionPrototype(function);
384 TIntermFunctionDefinition *newNode =
385 new TIntermFunctionDefinition(newPrototypeNode, node->getBody());
386 queueReplacement(newNode, OriginalNode::IS_DROPPED);
387 }
388
389 for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
390 {
391 const TVariable *paramVariable = function->getParam(paramIndex);
392 const TType ¶mType = paramVariable->getType();
393
394 if (paramType.getQualifier() != EvqParamOut)
395 {
396 continue;
397 }
398
399 CreateInitCode(new TIntermSymbol(paramVariable), mCanUseLoopsToInitialize,
400 mHighPrecisionSupported, &initCode, mSymbolTable);
401 }
402
403 if (!initCode.empty())
404 {
405 TIntermSequence *body = node->getBody()->getSequence();
406 body->insert(body->begin(), initCode.begin(), initCode.end());
407 }
408
409 return true;
410 }
411
visitAggregate(Visit visit,TIntermAggregate * node)412 bool visitAggregate(Visit visit, TIntermAggregate *node) override
413 {
414 const TFunction *function = node->getFunction();
415 if (function != nullptr)
416 {
417 auto it = mFunctionsToReplace.find(function);
418 if (it != mFunctionsToReplace.end())
419 {
420 const TFunction *target = it->second;
421 TIntermAggregate *newNode =
422 TIntermAggregate::CreateFunctionCall(*target, node->getSequence());
423 queueReplacement(newNode, OriginalNode::IS_DROPPED);
424 }
425 }
426 return true;
427 }
428
429 private:
430 int mShaderVersion;
431 bool mCanUseLoopsToInitialize;
432 bool mHighPrecisionSupported;
433 angle::HashMap<const TFunction *, TFunction *> mFunctionsToReplace;
434 };
435
436 } // namespace
437
CreateInitCode(const TIntermTyped * initializedSymbol,bool canUseLoopsToInitialize,bool highPrecisionSupported,TIntermSequence * initCode,TSymbolTable * symbolTable)438 void CreateInitCode(const TIntermTyped *initializedSymbol,
439 bool canUseLoopsToInitialize,
440 bool highPrecisionSupported,
441 TIntermSequence *initCode,
442 TSymbolTable *symbolTable)
443 {
444 AddZeroInitSequence(initializedSymbol, canUseLoopsToInitialize, highPrecisionSupported,
445 initCode, symbolTable);
446 }
447
InitializeUninitializedLocals(TCompiler * compiler,TIntermBlock * root,int shaderVersion,bool canUseLoopsToInitialize,bool highPrecisionSupported,TSymbolTable * symbolTable)448 bool InitializeUninitializedLocals(TCompiler *compiler,
449 TIntermBlock *root,
450 int shaderVersion,
451 bool canUseLoopsToInitialize,
452 bool highPrecisionSupported,
453 TSymbolTable *symbolTable)
454 {
455 InitializeLocalsTraverser traverser(shaderVersion, symbolTable, canUseLoopsToInitialize,
456 highPrecisionSupported);
457 traverser.collectUnnamedOutFunctions(*root);
458 root->traverse(&traverser);
459 return traverser.updateTree(compiler, root);
460 }
461
InitializeVariables(TCompiler * compiler,TIntermBlock * root,const InitVariableList & vars,TSymbolTable * symbolTable,int shaderVersion,const TExtensionBehavior & extensionBehavior,bool canUseLoopsToInitialize,bool highPrecisionSupported)462 bool InitializeVariables(TCompiler *compiler,
463 TIntermBlock *root,
464 const InitVariableList &vars,
465 TSymbolTable *symbolTable,
466 int shaderVersion,
467 const TExtensionBehavior &extensionBehavior,
468 bool canUseLoopsToInitialize,
469 bool highPrecisionSupported)
470 {
471 InsertInitCode(compiler, root, vars, symbolTable, shaderVersion, extensionBehavior,
472 canUseLoopsToInitialize, highPrecisionSupported);
473
474 return compiler->validateAST(root);
475 }
476
477 } // namespace sh
478