1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RemoveDynamicIndexing is an AST traverser to remove dynamic indexing of non-SSBO vectors and
7 // matrices, replacing them with calls to functions that choose which component to return or write.
// We don't need to consider dynamic indexing in SSBO since it can be used directly as part of the
// offset of RWByteAddressBuffer.
10 //
11
12 #include "compiler/translator/tree_ops/RemoveDynamicIndexing.h"
13
14 #include "compiler/translator/Diagnostics.h"
15 #include "compiler/translator/InfoSink.h"
16 #include "compiler/translator/StaticType.h"
17 #include "compiler/translator/SymbolTable.h"
18 #include "compiler/translator/tree_util/IntermNodePatternMatcher.h"
19 #include "compiler/translator/tree_util/IntermNode_util.h"
20 #include "compiler/translator/tree_util/IntermTraverse.h"
21
22 namespace sh
23 {
24
25 namespace
26 {
27
28 const TType *kIndexType = StaticType::Get<EbtInt, EbpHigh, EvqIn, 1, 1>();
29
30 constexpr const ImmutableString kBaseName("base");
31 constexpr const ImmutableString kIndexName("index");
32 constexpr const ImmutableString kValueName("value");
33
GetIndexFunctionName(const TType & type,bool write)34 std::string GetIndexFunctionName(const TType &type, bool write)
35 {
36 TInfoSinkBase nameSink;
37 nameSink << "dyn_index_";
38 if (write)
39 {
40 nameSink << "write_";
41 }
42 if (type.isMatrix())
43 {
44 nameSink << "mat" << type.getCols() << "x" << type.getRows();
45 }
46 else
47 {
48 switch (type.getBasicType())
49 {
50 case EbtInt:
51 nameSink << "ivec";
52 break;
53 case EbtBool:
54 nameSink << "bvec";
55 break;
56 case EbtUInt:
57 nameSink << "uvec";
58 break;
59 case EbtFloat:
60 nameSink << "vec";
61 break;
62 default:
63 UNREACHABLE();
64 }
65 nameSink << type.getNominalSize();
66 }
67 return nameSink.str();
68 }
69
CreateIntConstantNode(int i)70 TIntermConstantUnion *CreateIntConstantNode(int i)
71 {
72 TConstantUnion *constant = new TConstantUnion();
73 constant->setIConst(i);
74 return new TIntermConstantUnion(constant, TType(EbtInt, EbpHigh));
75 }
76
EnsureSignedInt(TIntermTyped * node)77 TIntermTyped *EnsureSignedInt(TIntermTyped *node)
78 {
79 if (node->getBasicType() == EbtInt)
80 return node;
81
82 TIntermSequence *arguments = new TIntermSequence();
83 arguments->push_back(node);
84 return TIntermAggregate::CreateConstructor(TType(EbtInt), arguments);
85 }
86
// Returns a newly allocated type for the element selected by indexing |indexedType|. For a
// matrix this is a column vector with getRows() components; for a vector it is a scalar. The
// basic type and precision are taken from |indexedType|.
TType *GetFieldType(const TType &indexedType)
{
    TType *elementType = new TType(indexedType.getBasicType(), indexedType.getPrecision());
    if (indexedType.isMatrix())
    {
        elementType->setPrimarySize(static_cast<unsigned char>(indexedType.getRows()));
    }
    return elementType;
}
100
// Returns a newly allocated copy of |type| to be used as the type of the "base" parameter of a
// generated helper: "inout" for write helpers, "in" for read helpers.
const TType *GetBaseType(const TType &type, bool write)
{
    TType *paramType = new TType(type);
    // Conservatively use highp here, even if the indexed type is not highp. That way the code
    // can't end up using a mediump version of an indexing function for a highp value if both
    // mediump and highp values are indexed in the shader. For HLSL precision doesn't matter, but
    // in principle this code could be used with multiple backends.
    paramType->setPrecision(EbpHigh);
    // Write helpers modify the indexed value in place, so their base must be inout.
    paramType->setQualifier(write ? EvqInOut : EvqIn);
    return paramType;
}
114
115 // Generate a read or write function for one field in a vector/matrix.
116 // Out-of-range indices are clamped. This is consistent with how ANGLE handles out-of-range
117 // indices in other places.
118 // Note that indices can be either int or uint. We create only int versions of the functions,
119 // and convert uint indices to int at the call site.
120 // read function example:
121 // float dyn_index_vec2(in vec2 base, in int index)
122 // {
123 // switch(index)
124 // {
125 // case (0):
126 // return base[0];
127 // case (1):
128 // return base[1];
129 // default:
130 // break;
131 // }
132 // if (index < 0)
133 // return base[0];
134 // return base[1];
135 // }
136 // write function example:
137 // void dyn_index_write_vec2(inout vec2 base, in int index, in float value)
138 // {
139 // switch(index)
140 // {
141 // case (0):
142 // base[0] = value;
143 // return;
144 // case (1):
145 // base[1] = value;
146 // return;
147 // default:
148 // break;
149 // }
150 // if (index < 0)
151 // {
152 // base[0] = value;
153 // return;
154 // }
155 // base[1] = value;
156 // }
157 // Note that else is not used in above functions to avoid the RewriteElseBlocks transformation.
GetIndexFunctionDefinition(const TType & type,bool write,const TFunction & func,TSymbolTable * symbolTable)158 TIntermFunctionDefinition *GetIndexFunctionDefinition(const TType &type,
159 bool write,
160 const TFunction &func,
161 TSymbolTable *symbolTable)
162 {
163 ASSERT(!type.isArray());
164
165 int numCases = 0;
166 if (type.isMatrix())
167 {
168 numCases = type.getCols();
169 }
170 else
171 {
172 numCases = type.getNominalSize();
173 }
174
175 std::string functionName = GetIndexFunctionName(type, write);
176 TIntermFunctionPrototype *prototypeNode = CreateInternalFunctionPrototypeNode(func);
177
178 TIntermSymbol *baseParam = new TIntermSymbol(func.getParam(0));
179 TIntermSymbol *indexParam = new TIntermSymbol(func.getParam(1));
180 TIntermSymbol *valueParam = nullptr;
181 if (write)
182 {
183 valueParam = new TIntermSymbol(func.getParam(2));
184 }
185
186 TIntermBlock *statementList = new TIntermBlock();
187 for (int i = 0; i < numCases; ++i)
188 {
189 TIntermCase *caseNode = new TIntermCase(CreateIntConstantNode(i));
190 statementList->getSequence()->push_back(caseNode);
191
192 TIntermBinary *indexNode =
193 new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(i));
194 if (write)
195 {
196 TIntermBinary *assignNode =
197 new TIntermBinary(EOpAssign, indexNode, valueParam->deepCopy());
198 statementList->getSequence()->push_back(assignNode);
199 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
200 statementList->getSequence()->push_back(returnNode);
201 }
202 else
203 {
204 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, indexNode);
205 statementList->getSequence()->push_back(returnNode);
206 }
207 }
208
209 // Default case
210 TIntermCase *defaultNode = new TIntermCase(nullptr);
211 statementList->getSequence()->push_back(defaultNode);
212 TIntermBranch *breakNode = new TIntermBranch(EOpBreak, nullptr);
213 statementList->getSequence()->push_back(breakNode);
214
215 TIntermSwitch *switchNode = new TIntermSwitch(indexParam->deepCopy(), statementList);
216
217 TIntermBlock *bodyNode = new TIntermBlock();
218 bodyNode->getSequence()->push_back(switchNode);
219
220 TIntermBinary *cond =
221 new TIntermBinary(EOpLessThan, indexParam->deepCopy(), CreateIntConstantNode(0));
222
223 // Two blocks: one accesses (either reads or writes) the first element and returns,
224 // the other accesses the last element.
225 TIntermBlock *useFirstBlock = new TIntermBlock();
226 TIntermBlock *useLastBlock = new TIntermBlock();
227 TIntermBinary *indexFirstNode =
228 new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(0));
229 TIntermBinary *indexLastNode =
230 new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(numCases - 1));
231 if (write)
232 {
233 TIntermBinary *assignFirstNode =
234 new TIntermBinary(EOpAssign, indexFirstNode, valueParam->deepCopy());
235 useFirstBlock->getSequence()->push_back(assignFirstNode);
236 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
237 useFirstBlock->getSequence()->push_back(returnNode);
238
239 TIntermBinary *assignLastNode =
240 new TIntermBinary(EOpAssign, indexLastNode, valueParam->deepCopy());
241 useLastBlock->getSequence()->push_back(assignLastNode);
242 }
243 else
244 {
245 TIntermBranch *returnFirstNode = new TIntermBranch(EOpReturn, indexFirstNode);
246 useFirstBlock->getSequence()->push_back(returnFirstNode);
247
248 TIntermBranch *returnLastNode = new TIntermBranch(EOpReturn, indexLastNode);
249 useLastBlock->getSequence()->push_back(returnLastNode);
250 }
251 TIntermIfElse *ifNode = new TIntermIfElse(cond, useFirstBlock, nullptr);
252 bodyNode->getSequence()->push_back(ifNode);
253 bodyNode->getSequence()->push_back(useLastBlock);
254
255 TIntermFunctionDefinition *indexingFunction =
256 new TIntermFunctionDefinition(prototypeNode, bodyNode);
257 return indexingFunction;
258 }
259
260 class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser
261 {
262 public:
263 RemoveDynamicIndexingTraverser(TSymbolTable *symbolTable,
264 PerformanceDiagnostics *perfDiagnostics);
265
266 bool visitBinary(Visit visit, TIntermBinary *node) override;
267
268 void insertHelperDefinitions(TIntermNode *root);
269
270 void nextIteration();
271
usedTreeInsertion() const272 bool usedTreeInsertion() const { return mUsedTreeInsertion; }
273
274 protected:
275 // Maps of types that are indexed to the indexing function ids used for them. Note that these
276 // can not store multiple variants of the same type with different precisions - only one
277 // precision gets stored.
278 std::map<TType, TFunction *> mIndexedVecAndMatrixTypes;
279 std::map<TType, TFunction *> mWrittenVecAndMatrixTypes;
280
281 bool mUsedTreeInsertion;
282
283 // When true, the traverser will remove side effects from any indexing expression.
284 // This is done so that in code like
285 // V[j++][i]++.
286 // where V is an array of vectors, j++ will only be evaluated once.
287 bool mRemoveIndexSideEffectsInSubtree;
288
289 PerformanceDiagnostics *mPerfDiagnostics;
290 };
291
RemoveDynamicIndexingTraverser(TSymbolTable * symbolTable,PerformanceDiagnostics * perfDiagnostics)292 RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser(
293 TSymbolTable *symbolTable,
294 PerformanceDiagnostics *perfDiagnostics)
295 : TLValueTrackingTraverser(true, false, false, symbolTable),
296 mUsedTreeInsertion(false),
297 mRemoveIndexSideEffectsInSubtree(false),
298 mPerfDiagnostics(perfDiagnostics)
299 {}
300
insertHelperDefinitions(TIntermNode * root)301 void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root)
302 {
303 TIntermBlock *rootBlock = root->getAsBlock();
304 ASSERT(rootBlock != nullptr);
305 TIntermSequence insertions;
306 for (auto &type : mIndexedVecAndMatrixTypes)
307 {
308 insertions.push_back(
309 GetIndexFunctionDefinition(type.first, false, *type.second, mSymbolTable));
310 }
311 for (auto &type : mWrittenVecAndMatrixTypes)
312 {
313 insertions.push_back(
314 GetIndexFunctionDefinition(type.first, true, *type.second, mSymbolTable));
315 }
316 rootBlock->insertChildNodes(0, insertions);
317 }
318
319 // Create a call to dyn_index_*() based on an indirect indexing op node
CreateIndexFunctionCall(TIntermBinary * node,TIntermTyped * index,TFunction * indexingFunction)320 TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node,
321 TIntermTyped *index,
322 TFunction *indexingFunction)
323 {
324 ASSERT(node->getOp() == EOpIndexIndirect);
325 TIntermSequence *arguments = new TIntermSequence();
326 arguments->push_back(node->getLeft());
327 arguments->push_back(index);
328
329 TIntermAggregate *indexingCall =
330 TIntermAggregate::CreateFunctionCall(*indexingFunction, arguments);
331 indexingCall->setLine(node->getLine());
332 return indexingCall;
333 }
334
CreateIndexedWriteFunctionCall(TIntermBinary * node,TVariable * index,TVariable * writtenValue,TFunction * indexedWriteFunction)335 TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node,
336 TVariable *index,
337 TVariable *writtenValue,
338 TFunction *indexedWriteFunction)
339 {
340 ASSERT(node->getOp() == EOpIndexIndirect);
341 TIntermSequence *arguments = new TIntermSequence();
342 // Deep copy the child nodes so that two pointers to the same node don't end up in the tree.
343 arguments->push_back(node->getLeft()->deepCopy());
344 arguments->push_back(CreateTempSymbolNode(index));
345 arguments->push_back(CreateTempSymbolNode(writtenValue));
346
347 TIntermAggregate *indexedWriteCall =
348 TIntermAggregate::CreateFunctionCall(*indexedWriteFunction, arguments);
349 indexedWriteCall->setLine(node->getLine());
350 return indexedWriteCall;
351 }
352
visitBinary(Visit visit,TIntermBinary * node)353 bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *node)
354 {
355 if (mUsedTreeInsertion)
356 return false;
357
358 if (node->getOp() == EOpIndexIndirect)
359 {
360 if (mRemoveIndexSideEffectsInSubtree)
361 {
362 ASSERT(node->getRight()->hasSideEffects());
363 // In case we're just removing index side effects, convert
364 // v_expr[index_expr]
365 // to this:
366 // int s0 = index_expr; v_expr[s0];
367 // Now v_expr[s0] can be safely executed several times without unintended side effects.
368 TIntermDeclaration *indexVariableDeclaration = nullptr;
369 TVariable *indexVariable = DeclareTempVariable(mSymbolTable, node->getRight(),
370 EvqTemporary, &indexVariableDeclaration);
371 insertStatementInParentBlock(indexVariableDeclaration);
372 mUsedTreeInsertion = true;
373
374 // Replace the index with the temp variable
375 TIntermSymbol *tempIndex = CreateTempSymbolNode(indexVariable);
376 queueReplacementWithParent(node, node->getRight(), tempIndex, OriginalNode::IS_DROPPED);
377 }
378 else if (IntermNodePatternMatcher::IsDynamicIndexingOfNonSSBOVectorOrMatrix(node))
379 {
380 mPerfDiagnostics->warning(node->getLine(),
381 "Performance: dynamic indexing of vectors and "
382 "matrices is emulated and can be slow.",
383 "[]");
384 bool write = isLValueRequiredHere();
385
386 #if defined(ANGLE_ENABLE_ASSERTS)
387 // Make sure that IntermNodePatternMatcher is consistent with the slightly differently
388 // implemented checks in this traverser.
389 IntermNodePatternMatcher matcher(
390 IntermNodePatternMatcher::kDynamicIndexingOfVectorOrMatrixInLValue);
391 ASSERT(matcher.match(node, getParentNode(), isLValueRequiredHere()) == write);
392 #endif
393
394 const TType &type = node->getLeft()->getType();
395 ImmutableString indexingFunctionName(GetIndexFunctionName(type, false));
396 TFunction *indexingFunction = nullptr;
397 if (mIndexedVecAndMatrixTypes.find(type) == mIndexedVecAndMatrixTypes.end())
398 {
399 indexingFunction =
400 new TFunction(mSymbolTable, indexingFunctionName, SymbolType::AngleInternal,
401 GetFieldType(type), true);
402 indexingFunction->addParameter(new TVariable(
403 mSymbolTable, kBaseName, GetBaseType(type, false), SymbolType::AngleInternal));
404 indexingFunction->addParameter(
405 new TVariable(mSymbolTable, kIndexName, kIndexType, SymbolType::AngleInternal));
406 mIndexedVecAndMatrixTypes[type] = indexingFunction;
407 }
408 else
409 {
410 indexingFunction = mIndexedVecAndMatrixTypes[type];
411 }
412
413 if (write)
414 {
415 // Convert:
416 // v_expr[index_expr]++;
417 // to this:
418 // int s0 = index_expr; float s1 = dyn_index(v_expr, s0); s1++;
419 // dyn_index_write(v_expr, s0, s1);
420 // This works even if index_expr has some side effects.
421 if (node->getLeft()->hasSideEffects())
422 {
423 // If v_expr has side effects, those need to be removed before proceeding.
424 // Otherwise the side effects of v_expr would be evaluated twice.
425 // The only case where an l-value can have side effects is when it is
426 // indexing. For example, it can be V[j++] where V is an array of vectors.
427 mRemoveIndexSideEffectsInSubtree = true;
428 return true;
429 }
430
431 TIntermBinary *leftBinary = node->getLeft()->getAsBinaryNode();
432 if (leftBinary != nullptr &&
433 IntermNodePatternMatcher::IsDynamicIndexingOfNonSSBOVectorOrMatrix(leftBinary))
434 {
435 // This is a case like:
436 // mat2 m;
437 // m[a][b]++;
438 // Process the child node m[a] first.
439 return true;
440 }
441
442 // TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value
443 // only writes it and doesn't need the previous value. http://anglebug.com/1116
444
445 TFunction *indexedWriteFunction = nullptr;
446 if (mWrittenVecAndMatrixTypes.find(type) == mWrittenVecAndMatrixTypes.end())
447 {
448 ImmutableString functionName(
449 GetIndexFunctionName(node->getLeft()->getType(), true));
450 indexedWriteFunction =
451 new TFunction(mSymbolTable, functionName, SymbolType::AngleInternal,
452 StaticType::GetBasic<EbtVoid>(), false);
453 indexedWriteFunction->addParameter(new TVariable(mSymbolTable, kBaseName,
454 GetBaseType(type, true),
455 SymbolType::AngleInternal));
456 indexedWriteFunction->addParameter(new TVariable(
457 mSymbolTable, kIndexName, kIndexType, SymbolType::AngleInternal));
458 TType *valueType = GetFieldType(type);
459 valueType->setQualifier(EvqIn);
460 indexedWriteFunction->addParameter(new TVariable(
461 mSymbolTable, kValueName, static_cast<const TType *>(valueType),
462 SymbolType::AngleInternal));
463 mWrittenVecAndMatrixTypes[type] = indexedWriteFunction;
464 }
465 else
466 {
467 indexedWriteFunction = mWrittenVecAndMatrixTypes[type];
468 }
469
470 TIntermSequence insertionsBefore;
471 TIntermSequence insertionsAfter;
472
473 // Store the index in a temporary signed int variable.
474 // s0 = index_expr;
475 TIntermTyped *indexInitializer = EnsureSignedInt(node->getRight());
476 TIntermDeclaration *indexVariableDeclaration = nullptr;
477 TVariable *indexVariable = DeclareTempVariable(
478 mSymbolTable, indexInitializer, EvqTemporary, &indexVariableDeclaration);
479 insertionsBefore.push_back(indexVariableDeclaration);
480
481 // s1 = dyn_index(v_expr, s0);
482 TIntermAggregate *indexingCall = CreateIndexFunctionCall(
483 node, CreateTempSymbolNode(indexVariable), indexingFunction);
484 TIntermDeclaration *fieldVariableDeclaration = nullptr;
485 TVariable *fieldVariable = DeclareTempVariable(
486 mSymbolTable, indexingCall, EvqTemporary, &fieldVariableDeclaration);
487 insertionsBefore.push_back(fieldVariableDeclaration);
488
489 // dyn_index_write(v_expr, s0, s1);
490 TIntermAggregate *indexedWriteCall = CreateIndexedWriteFunctionCall(
491 node, indexVariable, fieldVariable, indexedWriteFunction);
492 insertionsAfter.push_back(indexedWriteCall);
493 insertStatementsInParentBlock(insertionsBefore, insertionsAfter);
494
495 // replace the node with s1
496 queueReplacement(CreateTempSymbolNode(fieldVariable), OriginalNode::IS_DROPPED);
497 mUsedTreeInsertion = true;
498 }
499 else
500 {
501 // The indexed value is not being written, so we can simply convert
502 // v_expr[index_expr]
503 // into
504 // dyn_index(v_expr, index_expr)
505 // If the index_expr is unsigned, we'll convert it to signed.
506 ASSERT(!mRemoveIndexSideEffectsInSubtree);
507 TIntermAggregate *indexingCall = CreateIndexFunctionCall(
508 node, EnsureSignedInt(node->getRight()), indexingFunction);
509 queueReplacement(indexingCall, OriginalNode::IS_DROPPED);
510 }
511 }
512 }
513 return !mUsedTreeInsertion;
514 }
515
nextIteration()516 void RemoveDynamicIndexingTraverser::nextIteration()
517 {
518 mUsedTreeInsertion = false;
519 mRemoveIndexSideEffectsInSubtree = false;
520 }
521
522 } // namespace
523
RemoveDynamicIndexing(TIntermNode * root,TSymbolTable * symbolTable,PerformanceDiagnostics * perfDiagnostics)524 void RemoveDynamicIndexing(TIntermNode *root,
525 TSymbolTable *symbolTable,
526 PerformanceDiagnostics *perfDiagnostics)
527 {
528 RemoveDynamicIndexingTraverser traverser(symbolTable, perfDiagnostics);
529 do
530 {
531 traverser.nextIteration();
532 root->traverse(&traverser);
533 traverser.updateTree();
534 } while (traverser.usedTreeInsertion());
535 // TODO(oetuaho@nvidia.com): It might be nicer to add the helper definitions also in the middle
536 // of traversal. Now the tree ends up in an inconsistent state in the middle, since there are
537 // function call nodes with no corresponding definition nodes. This needs special handling in
538 // TIntermLValueTrackingTraverser, and creates intricacies that are not easily apparent from a
539 // superficial reading of the code.
540 traverser.insertHelperDefinitions(root);
541 }
542
543 } // namespace sh
544