• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RemoveDynamicIndexing is an AST traverser to remove dynamic indexing of non-SSBO vectors and
7 // matrices, replacing them with calls to functions that choose which component to return or write.
// We don't need to consider dynamic indexing in SSBOs, since the index can be used directly as
// part of the offset into the RWByteAddressBuffer.
10 //
11 
12 #include "compiler/translator/tree_ops/RemoveDynamicIndexing.h"
13 
14 #include "compiler/translator/Diagnostics.h"
15 #include "compiler/translator/InfoSink.h"
16 #include "compiler/translator/StaticType.h"
17 #include "compiler/translator/SymbolTable.h"
18 #include "compiler/translator/tree_util/IntermNodePatternMatcher.h"
19 #include "compiler/translator/tree_util/IntermNode_util.h"
20 #include "compiler/translator/tree_util/IntermTraverse.h"
21 
22 namespace sh
23 {
24 
25 namespace
26 {
27 
28 const TType *kIndexType = StaticType::Get<EbtInt, EbpHigh, EvqIn, 1, 1>();
29 
30 constexpr const ImmutableString kBaseName("base");
31 constexpr const ImmutableString kIndexName("index");
32 constexpr const ImmutableString kValueName("value");
33 
GetIndexFunctionName(const TType & type,bool write)34 std::string GetIndexFunctionName(const TType &type, bool write)
35 {
36     TInfoSinkBase nameSink;
37     nameSink << "dyn_index_";
38     if (write)
39     {
40         nameSink << "write_";
41     }
42     if (type.isMatrix())
43     {
44         nameSink << "mat" << type.getCols() << "x" << type.getRows();
45     }
46     else
47     {
48         switch (type.getBasicType())
49         {
50             case EbtInt:
51                 nameSink << "ivec";
52                 break;
53             case EbtBool:
54                 nameSink << "bvec";
55                 break;
56             case EbtUInt:
57                 nameSink << "uvec";
58                 break;
59             case EbtFloat:
60                 nameSink << "vec";
61                 break;
62             default:
63                 UNREACHABLE();
64         }
65         nameSink << type.getNominalSize();
66     }
67     return nameSink.str();
68 }
69 
CreateIntConstantNode(int i)70 TIntermConstantUnion *CreateIntConstantNode(int i)
71 {
72     TConstantUnion *constant = new TConstantUnion();
73     constant->setIConst(i);
74     return new TIntermConstantUnion(constant, TType(EbtInt, EbpHigh));
75 }
76 
EnsureSignedInt(TIntermTyped * node)77 TIntermTyped *EnsureSignedInt(TIntermTyped *node)
78 {
79     if (node->getBasicType() == EbtInt)
80         return node;
81 
82     TIntermSequence *arguments = new TIntermSequence();
83     arguments->push_back(node);
84     return TIntermAggregate::CreateConstructor(TType(EbtInt), arguments);
85 }
86 
// Returns a newly allocated type describing what indexing |indexedType| yields. The result keeps
// the basic type and precision of |indexedType|; when |indexedType| is a matrix its primary size
// is additionally set to the matrix row count, otherwise the default size is left in place.
TType *GetFieldType(const TType &indexedType)
{
    TType *elementType = new TType(indexedType.getBasicType(), indexedType.getPrecision());
    if (indexedType.isMatrix())
    {
        elementType->setPrimarySize(static_cast<unsigned char>(indexedType.getRows()));
    }
    return elementType;
}
100 
// Returns a newly allocated copy of |type| to be used for the "base" parameter of an indexing
// helper: precision is forced to EbpHigh and the qualifier is EvqInOut for the write helper or
// EvqIn for the read helper.
const TType *GetBaseType(const TType &type, bool write)
{
    TType *paramType = new TType(type);
    // Always use highp, even when the indexed type is not highp. This way a mediump variant of an
    // indexing function can never end up being used for a highp value when the shader indexes
    // both mediump and highp values. Precision is irrelevant for HLSL, but in principle this code
    // could serve multiple backends.
    paramType->setPrecision(EbpHigh);
    paramType->setQualifier(write ? EvqInOut : EvqIn);
    return paramType;
}
114 
115 // Generate a read or write function for one field in a vector/matrix.
116 // Out-of-range indices are clamped. This is consistent with how ANGLE handles out-of-range
117 // indices in other places.
118 // Note that indices can be either int or uint. We create only int versions of the functions,
119 // and convert uint indices to int at the call site.
120 // read function example:
121 // float dyn_index_vec2(in vec2 base, in int index)
122 // {
123 //    switch(index)
124 //    {
125 //      case (0):
126 //        return base[0];
127 //      case (1):
128 //        return base[1];
129 //      default:
130 //        break;
131 //    }
132 //    if (index < 0)
133 //      return base[0];
134 //    return base[1];
135 // }
136 // write function example:
137 // void dyn_index_write_vec2(inout vec2 base, in int index, in float value)
138 // {
139 //    switch(index)
140 //    {
141 //      case (0):
142 //        base[0] = value;
143 //        return;
144 //      case (1):
145 //        base[1] = value;
146 //        return;
147 //      default:
148 //        break;
149 //    }
150 //    if (index < 0)
151 //    {
152 //      base[0] = value;
153 //      return;
154 //    }
155 //    base[1] = value;
156 // }
157 // Note that else is not used in above functions to avoid the RewriteElseBlocks transformation.
GetIndexFunctionDefinition(const TType & type,bool write,const TFunction & func,TSymbolTable * symbolTable)158 TIntermFunctionDefinition *GetIndexFunctionDefinition(const TType &type,
159                                                       bool write,
160                                                       const TFunction &func,
161                                                       TSymbolTable *symbolTable)
162 {
163     ASSERT(!type.isArray());
164 
165     int numCases = 0;
166     if (type.isMatrix())
167     {
168         numCases = type.getCols();
169     }
170     else
171     {
172         numCases = type.getNominalSize();
173     }
174 
175     std::string functionName                = GetIndexFunctionName(type, write);
176     TIntermFunctionPrototype *prototypeNode = CreateInternalFunctionPrototypeNode(func);
177 
178     TIntermSymbol *baseParam  = new TIntermSymbol(func.getParam(0));
179     TIntermSymbol *indexParam = new TIntermSymbol(func.getParam(1));
180     TIntermSymbol *valueParam = nullptr;
181     if (write)
182     {
183         valueParam = new TIntermSymbol(func.getParam(2));
184     }
185 
186     TIntermBlock *statementList = new TIntermBlock();
187     for (int i = 0; i < numCases; ++i)
188     {
189         TIntermCase *caseNode = new TIntermCase(CreateIntConstantNode(i));
190         statementList->getSequence()->push_back(caseNode);
191 
192         TIntermBinary *indexNode =
193             new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(i));
194         if (write)
195         {
196             TIntermBinary *assignNode =
197                 new TIntermBinary(EOpAssign, indexNode, valueParam->deepCopy());
198             statementList->getSequence()->push_back(assignNode);
199             TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
200             statementList->getSequence()->push_back(returnNode);
201         }
202         else
203         {
204             TIntermBranch *returnNode = new TIntermBranch(EOpReturn, indexNode);
205             statementList->getSequence()->push_back(returnNode);
206         }
207     }
208 
209     // Default case
210     TIntermCase *defaultNode = new TIntermCase(nullptr);
211     statementList->getSequence()->push_back(defaultNode);
212     TIntermBranch *breakNode = new TIntermBranch(EOpBreak, nullptr);
213     statementList->getSequence()->push_back(breakNode);
214 
215     TIntermSwitch *switchNode = new TIntermSwitch(indexParam->deepCopy(), statementList);
216 
217     TIntermBlock *bodyNode = new TIntermBlock();
218     bodyNode->getSequence()->push_back(switchNode);
219 
220     TIntermBinary *cond =
221         new TIntermBinary(EOpLessThan, indexParam->deepCopy(), CreateIntConstantNode(0));
222 
223     // Two blocks: one accesses (either reads or writes) the first element and returns,
224     // the other accesses the last element.
225     TIntermBlock *useFirstBlock = new TIntermBlock();
226     TIntermBlock *useLastBlock  = new TIntermBlock();
227     TIntermBinary *indexFirstNode =
228         new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(0));
229     TIntermBinary *indexLastNode =
230         new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(numCases - 1));
231     if (write)
232     {
233         TIntermBinary *assignFirstNode =
234             new TIntermBinary(EOpAssign, indexFirstNode, valueParam->deepCopy());
235         useFirstBlock->getSequence()->push_back(assignFirstNode);
236         TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
237         useFirstBlock->getSequence()->push_back(returnNode);
238 
239         TIntermBinary *assignLastNode =
240             new TIntermBinary(EOpAssign, indexLastNode, valueParam->deepCopy());
241         useLastBlock->getSequence()->push_back(assignLastNode);
242     }
243     else
244     {
245         TIntermBranch *returnFirstNode = new TIntermBranch(EOpReturn, indexFirstNode);
246         useFirstBlock->getSequence()->push_back(returnFirstNode);
247 
248         TIntermBranch *returnLastNode = new TIntermBranch(EOpReturn, indexLastNode);
249         useLastBlock->getSequence()->push_back(returnLastNode);
250     }
251     TIntermIfElse *ifNode = new TIntermIfElse(cond, useFirstBlock, nullptr);
252     bodyNode->getSequence()->push_back(ifNode);
253     bodyNode->getSequence()->push_back(useLastBlock);
254 
255     TIntermFunctionDefinition *indexingFunction =
256         new TIntermFunctionDefinition(prototypeNode, bodyNode);
257     return indexingFunction;
258 }
259 
260 class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser
261 {
262   public:
263     RemoveDynamicIndexingTraverser(TSymbolTable *symbolTable,
264                                    PerformanceDiagnostics *perfDiagnostics);
265 
266     bool visitBinary(Visit visit, TIntermBinary *node) override;
267 
268     void insertHelperDefinitions(TIntermNode *root);
269 
270     void nextIteration();
271 
usedTreeInsertion() const272     bool usedTreeInsertion() const { return mUsedTreeInsertion; }
273 
274   protected:
275     // Maps of types that are indexed to the indexing function ids used for them. Note that these
276     // can not store multiple variants of the same type with different precisions - only one
277     // precision gets stored.
278     std::map<TType, TFunction *> mIndexedVecAndMatrixTypes;
279     std::map<TType, TFunction *> mWrittenVecAndMatrixTypes;
280 
281     bool mUsedTreeInsertion;
282 
283     // When true, the traverser will remove side effects from any indexing expression.
284     // This is done so that in code like
285     //   V[j++][i]++.
286     // where V is an array of vectors, j++ will only be evaluated once.
287     bool mRemoveIndexSideEffectsInSubtree;
288 
289     PerformanceDiagnostics *mPerfDiagnostics;
290 };
291 
RemoveDynamicIndexingTraverser(TSymbolTable * symbolTable,PerformanceDiagnostics * perfDiagnostics)292 RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser(
293     TSymbolTable *symbolTable,
294     PerformanceDiagnostics *perfDiagnostics)
295     : TLValueTrackingTraverser(true, false, false, symbolTable),
296       mUsedTreeInsertion(false),
297       mRemoveIndexSideEffectsInSubtree(false),
298       mPerfDiagnostics(perfDiagnostics)
299 {}
300 
insertHelperDefinitions(TIntermNode * root)301 void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root)
302 {
303     TIntermBlock *rootBlock = root->getAsBlock();
304     ASSERT(rootBlock != nullptr);
305     TIntermSequence insertions;
306     for (auto &type : mIndexedVecAndMatrixTypes)
307     {
308         insertions.push_back(
309             GetIndexFunctionDefinition(type.first, false, *type.second, mSymbolTable));
310     }
311     for (auto &type : mWrittenVecAndMatrixTypes)
312     {
313         insertions.push_back(
314             GetIndexFunctionDefinition(type.first, true, *type.second, mSymbolTable));
315     }
316     rootBlock->insertChildNodes(0, insertions);
317 }
318 
319 // Create a call to dyn_index_*() based on an indirect indexing op node
CreateIndexFunctionCall(TIntermBinary * node,TIntermTyped * index,TFunction * indexingFunction)320 TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node,
321                                           TIntermTyped *index,
322                                           TFunction *indexingFunction)
323 {
324     ASSERT(node->getOp() == EOpIndexIndirect);
325     TIntermSequence *arguments = new TIntermSequence();
326     arguments->push_back(node->getLeft());
327     arguments->push_back(index);
328 
329     TIntermAggregate *indexingCall =
330         TIntermAggregate::CreateFunctionCall(*indexingFunction, arguments);
331     indexingCall->setLine(node->getLine());
332     return indexingCall;
333 }
334 
CreateIndexedWriteFunctionCall(TIntermBinary * node,TVariable * index,TVariable * writtenValue,TFunction * indexedWriteFunction)335 TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node,
336                                                  TVariable *index,
337                                                  TVariable *writtenValue,
338                                                  TFunction *indexedWriteFunction)
339 {
340     ASSERT(node->getOp() == EOpIndexIndirect);
341     TIntermSequence *arguments = new TIntermSequence();
342     // Deep copy the child nodes so that two pointers to the same node don't end up in the tree.
343     arguments->push_back(node->getLeft()->deepCopy());
344     arguments->push_back(CreateTempSymbolNode(index));
345     arguments->push_back(CreateTempSymbolNode(writtenValue));
346 
347     TIntermAggregate *indexedWriteCall =
348         TIntermAggregate::CreateFunctionCall(*indexedWriteFunction, arguments);
349     indexedWriteCall->setLine(node->getLine());
350     return indexedWriteCall;
351 }
352 
visitBinary(Visit visit,TIntermBinary * node)353 bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *node)
354 {
355     if (mUsedTreeInsertion)
356         return false;
357 
358     if (node->getOp() == EOpIndexIndirect)
359     {
360         if (mRemoveIndexSideEffectsInSubtree)
361         {
362             ASSERT(node->getRight()->hasSideEffects());
363             // In case we're just removing index side effects, convert
364             //   v_expr[index_expr]
365             // to this:
366             //   int s0 = index_expr; v_expr[s0];
367             // Now v_expr[s0] can be safely executed several times without unintended side effects.
368             TIntermDeclaration *indexVariableDeclaration = nullptr;
369             TVariable *indexVariable = DeclareTempVariable(mSymbolTable, node->getRight(),
370                                                            EvqTemporary, &indexVariableDeclaration);
371             insertStatementInParentBlock(indexVariableDeclaration);
372             mUsedTreeInsertion = true;
373 
374             // Replace the index with the temp variable
375             TIntermSymbol *tempIndex = CreateTempSymbolNode(indexVariable);
376             queueReplacementWithParent(node, node->getRight(), tempIndex, OriginalNode::IS_DROPPED);
377         }
378         else if (IntermNodePatternMatcher::IsDynamicIndexingOfNonSSBOVectorOrMatrix(node))
379         {
380             mPerfDiagnostics->warning(node->getLine(),
381                                       "Performance: dynamic indexing of vectors and "
382                                       "matrices is emulated and can be slow.",
383                                       "[]");
384             bool write = isLValueRequiredHere();
385 
386 #if defined(ANGLE_ENABLE_ASSERTS)
387             // Make sure that IntermNodePatternMatcher is consistent with the slightly differently
388             // implemented checks in this traverser.
389             IntermNodePatternMatcher matcher(
390                 IntermNodePatternMatcher::kDynamicIndexingOfVectorOrMatrixInLValue);
391             ASSERT(matcher.match(node, getParentNode(), isLValueRequiredHere()) == write);
392 #endif
393 
394             const TType &type = node->getLeft()->getType();
395             ImmutableString indexingFunctionName(GetIndexFunctionName(type, false));
396             TFunction *indexingFunction = nullptr;
397             if (mIndexedVecAndMatrixTypes.find(type) == mIndexedVecAndMatrixTypes.end())
398             {
399                 indexingFunction =
400                     new TFunction(mSymbolTable, indexingFunctionName, SymbolType::AngleInternal,
401                                   GetFieldType(type), true);
402                 indexingFunction->addParameter(new TVariable(
403                     mSymbolTable, kBaseName, GetBaseType(type, false), SymbolType::AngleInternal));
404                 indexingFunction->addParameter(
405                     new TVariable(mSymbolTable, kIndexName, kIndexType, SymbolType::AngleInternal));
406                 mIndexedVecAndMatrixTypes[type] = indexingFunction;
407             }
408             else
409             {
410                 indexingFunction = mIndexedVecAndMatrixTypes[type];
411             }
412 
413             if (write)
414             {
415                 // Convert:
416                 //   v_expr[index_expr]++;
417                 // to this:
418                 //   int s0 = index_expr; float s1 = dyn_index(v_expr, s0); s1++;
419                 //   dyn_index_write(v_expr, s0, s1);
420                 // This works even if index_expr has some side effects.
421                 if (node->getLeft()->hasSideEffects())
422                 {
423                     // If v_expr has side effects, those need to be removed before proceeding.
424                     // Otherwise the side effects of v_expr would be evaluated twice.
425                     // The only case where an l-value can have side effects is when it is
426                     // indexing. For example, it can be V[j++] where V is an array of vectors.
427                     mRemoveIndexSideEffectsInSubtree = true;
428                     return true;
429                 }
430 
431                 TIntermBinary *leftBinary = node->getLeft()->getAsBinaryNode();
432                 if (leftBinary != nullptr &&
433                     IntermNodePatternMatcher::IsDynamicIndexingOfNonSSBOVectorOrMatrix(leftBinary))
434                 {
435                     // This is a case like:
436                     // mat2 m;
437                     // m[a][b]++;
438                     // Process the child node m[a] first.
439                     return true;
440                 }
441 
442                 // TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value
443                 // only writes it and doesn't need the previous value. http://anglebug.com/1116
444 
445                 TFunction *indexedWriteFunction = nullptr;
446                 if (mWrittenVecAndMatrixTypes.find(type) == mWrittenVecAndMatrixTypes.end())
447                 {
448                     ImmutableString functionName(
449                         GetIndexFunctionName(node->getLeft()->getType(), true));
450                     indexedWriteFunction =
451                         new TFunction(mSymbolTable, functionName, SymbolType::AngleInternal,
452                                       StaticType::GetBasic<EbtVoid>(), false);
453                     indexedWriteFunction->addParameter(new TVariable(mSymbolTable, kBaseName,
454                                                                      GetBaseType(type, true),
455                                                                      SymbolType::AngleInternal));
456                     indexedWriteFunction->addParameter(new TVariable(
457                         mSymbolTable, kIndexName, kIndexType, SymbolType::AngleInternal));
458                     TType *valueType = GetFieldType(type);
459                     valueType->setQualifier(EvqIn);
460                     indexedWriteFunction->addParameter(new TVariable(
461                         mSymbolTable, kValueName, static_cast<const TType *>(valueType),
462                         SymbolType::AngleInternal));
463                     mWrittenVecAndMatrixTypes[type] = indexedWriteFunction;
464                 }
465                 else
466                 {
467                     indexedWriteFunction = mWrittenVecAndMatrixTypes[type];
468                 }
469 
470                 TIntermSequence insertionsBefore;
471                 TIntermSequence insertionsAfter;
472 
473                 // Store the index in a temporary signed int variable.
474                 // s0 = index_expr;
475                 TIntermTyped *indexInitializer               = EnsureSignedInt(node->getRight());
476                 TIntermDeclaration *indexVariableDeclaration = nullptr;
477                 TVariable *indexVariable                     = DeclareTempVariable(
478                     mSymbolTable, indexInitializer, EvqTemporary, &indexVariableDeclaration);
479                 insertionsBefore.push_back(indexVariableDeclaration);
480 
481                 // s1 = dyn_index(v_expr, s0);
482                 TIntermAggregate *indexingCall = CreateIndexFunctionCall(
483                     node, CreateTempSymbolNode(indexVariable), indexingFunction);
484                 TIntermDeclaration *fieldVariableDeclaration = nullptr;
485                 TVariable *fieldVariable                     = DeclareTempVariable(
486                     mSymbolTable, indexingCall, EvqTemporary, &fieldVariableDeclaration);
487                 insertionsBefore.push_back(fieldVariableDeclaration);
488 
489                 // dyn_index_write(v_expr, s0, s1);
490                 TIntermAggregate *indexedWriteCall = CreateIndexedWriteFunctionCall(
491                     node, indexVariable, fieldVariable, indexedWriteFunction);
492                 insertionsAfter.push_back(indexedWriteCall);
493                 insertStatementsInParentBlock(insertionsBefore, insertionsAfter);
494 
495                 // replace the node with s1
496                 queueReplacement(CreateTempSymbolNode(fieldVariable), OriginalNode::IS_DROPPED);
497                 mUsedTreeInsertion = true;
498             }
499             else
500             {
501                 // The indexed value is not being written, so we can simply convert
502                 //   v_expr[index_expr]
503                 // into
504                 //   dyn_index(v_expr, index_expr)
505                 // If the index_expr is unsigned, we'll convert it to signed.
506                 ASSERT(!mRemoveIndexSideEffectsInSubtree);
507                 TIntermAggregate *indexingCall = CreateIndexFunctionCall(
508                     node, EnsureSignedInt(node->getRight()), indexingFunction);
509                 queueReplacement(indexingCall, OriginalNode::IS_DROPPED);
510             }
511         }
512     }
513     return !mUsedTreeInsertion;
514 }
515 
nextIteration()516 void RemoveDynamicIndexingTraverser::nextIteration()
517 {
518     mUsedTreeInsertion               = false;
519     mRemoveIndexSideEffectsInSubtree = false;
520 }
521 
522 }  // namespace
523 
RemoveDynamicIndexing(TIntermNode * root,TSymbolTable * symbolTable,PerformanceDiagnostics * perfDiagnostics)524 void RemoveDynamicIndexing(TIntermNode *root,
525                            TSymbolTable *symbolTable,
526                            PerformanceDiagnostics *perfDiagnostics)
527 {
528     RemoveDynamicIndexingTraverser traverser(symbolTable, perfDiagnostics);
529     do
530     {
531         traverser.nextIteration();
532         root->traverse(&traverser);
533         traverser.updateTree();
534     } while (traverser.usedTreeInsertion());
535     // TODO(oetuaho@nvidia.com): It might be nicer to add the helper definitions also in the middle
536     // of traversal. Now the tree ends up in an inconsistent state in the middle, since there are
537     // function call nodes with no corresponding definition nodes. This needs special handling in
538     // TIntermLValueTrackingTraverser, and creates intricacies that are not easily apparent from a
539     // superficial reading of the code.
540     traverser.insertHelperDefinitions(root);
541 }
542 
543 }  // namespace sh
544