• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2019 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RewriteAtomicCounters: Emulate atomic counter buffers with storage buffers.
7 //
8 
9 #include "compiler/translator/tree_ops/RewriteAtomicCounters.h"
10 
11 #include "compiler/translator/Compiler.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 #include "compiler/translator/tree_util/ReplaceVariable.h"
18 
19 namespace sh
20 {
21 namespace
22 {
23 constexpr ImmutableString kAtomicCounterTypeName  = ImmutableString("ANGLE_atomic_uint");
24 constexpr ImmutableString kAtomicCountersVarName  = ImmutableString("atomicCounters");
25 constexpr ImmutableString kAtomicCounterFieldName = ImmutableString("counters");
26 
27 // DeclareAtomicCountersBuffer adds a storage buffer array that's used with atomic counters.
DeclareAtomicCountersBuffers(TIntermBlock * root,TSymbolTable * symbolTable)28 const TVariable *DeclareAtomicCountersBuffers(TIntermBlock *root, TSymbolTable *symbolTable)
29 {
30     // Define `uint counters[];` as the only field in the interface block.
31     TFieldList *fieldList = new TFieldList;
32     TType *counterType    = new TType(EbtUInt);
33     counterType->makeArray(0);
34 
35     TField *countersField =
36         new TField(counterType, kAtomicCounterFieldName, TSourceLoc(), SymbolType::AngleInternal);
37 
38     fieldList->push_back(countersField);
39 
40     TMemoryQualifier coherentMemory = TMemoryQualifier::Create();
41     coherentMemory.coherent         = true;
42 
43     // There are a maximum of 8 atomic counter buffers per IMPLEMENTATION_MAX_ATOMIC_COUNTER_BUFFERS
44     // in libANGLE/Constants.h.
45     constexpr uint32_t kMaxAtomicCounterBuffers = 8;
46 
47     // Define a storage block "ANGLEAtomicCounters" with instance name "atomicCounters".
48     return DeclareInterfaceBlock(
49         root, symbolTable, fieldList, EvqBuffer, coherentMemory, kMaxAtomicCounterBuffers,
50         ImmutableString(vk::kAtomicCountersBlockName), kAtomicCountersVarName);
51 }
52 
CreateUIntConstant(uint32_t value)53 TIntermConstantUnion *CreateUIntConstant(uint32_t value)
54 {
55     TType *constantType = new TType(*StaticType::GetBasic<EbtUInt, 1>());
56     constantType->setQualifier(EvqConst);
57 
58     TConstantUnion *constantValue = new TConstantUnion;
59     constantValue->setUConst(value);
60     return new TIntermConstantUnion(constantValue, *constantType);
61 }
62 
CreateAtomicCounterConstant(TType * atomicCounterType,uint32_t binding,uint32_t offset)63 TIntermTyped *CreateAtomicCounterConstant(TType *atomicCounterType,
64                                           uint32_t binding,
65                                           uint32_t offset)
66 {
67     ASSERT(atomicCounterType->getBasicType() == EbtStruct);
68 
69     TIntermSequence *arguments = new TIntermSequence();
70     arguments->push_back(CreateUIntConstant(binding));
71     arguments->push_back(CreateUIntConstant(offset));
72 
73     return TIntermAggregate::CreateConstructor(*atomicCounterType, arguments);
74 }
75 
CreateAtomicCounterRef(const TVariable * atomicCounters,const TIntermTyped * bindingOffset,const TIntermTyped * bufferOffsets)76 TIntermBinary *CreateAtomicCounterRef(const TVariable *atomicCounters,
77                                       const TIntermTyped *bindingOffset,
78                                       const TIntermTyped *bufferOffsets)
79 {
80     // The atomic counters storage buffer declaration looks as such:
81     //
82     // layout(...) buffer ANGLEAtomicCounters
83     // {
84     //     uint counters[];
85     // } atomicCounters[N];
86     //
87     // Where N is large enough to accommodate atomic counter buffer bindings used in the shader.
88     //
89     // Given an ANGLEAtomicCounter variable (which is a struct of {binding, offset}), we need to
90     // return:
91     //
92     // atomicCounters[binding].counters[offset]
93     //
94     // The offset itself is the provided one plus an offset given through uniforms.
95 
96     TIntermSymbol *atomicCountersRef = new TIntermSymbol(atomicCounters);
97 
98     TIntermConstantUnion *bindingFieldRef  = CreateIndexNode(0);
99     TIntermConstantUnion *offsetFieldRef   = CreateIndexNode(1);
100     TIntermConstantUnion *countersFieldRef = CreateIndexNode(0);
101 
102     // Create references to bindingOffset.binding and bindingOffset.offset.
103     TIntermBinary *binding =
104         new TIntermBinary(EOpIndexDirectStruct, bindingOffset->deepCopy(), bindingFieldRef);
105     TIntermBinary *offset =
106         new TIntermBinary(EOpIndexDirectStruct, bindingOffset->deepCopy(), offsetFieldRef);
107 
108     // Create reference to atomicCounters[bindingOffset.binding]
109     TIntermBinary *countersBlock = new TIntermBinary(EOpIndexDirect, atomicCountersRef, binding);
110 
111     // Create reference to atomicCounters[bindingOffset.binding].counters
112     TIntermBinary *counters =
113         new TIntermBinary(EOpIndexDirectInterfaceBlock, countersBlock, countersFieldRef);
114 
115     // Create bufferOffsets[binding / 4].  Each uint in bufferOffsets contains offsets for 4
116     // bindings.
117     TIntermBinary *bindingDivFour =
118         new TIntermBinary(EOpDiv, binding->deepCopy(), CreateUIntConstant(4));
119     TIntermBinary *bufferOffsetUint =
120         new TIntermBinary(EOpIndexDirect, bufferOffsets->deepCopy(), bindingDivFour);
121 
122     // Create (binding % 4) * 8
123     TIntermBinary *bindingModFour =
124         new TIntermBinary(EOpIMod, binding->deepCopy(), CreateUIntConstant(4));
125     TIntermBinary *bufferOffsetShift =
126         new TIntermBinary(EOpMul, bindingModFour, CreateUIntConstant(8));
127 
128     // Create bufferOffsets[binding / 4] >> ((binding % 4) * 8) & 0xFF
129     TIntermBinary *bufferOffsetShifted =
130         new TIntermBinary(EOpBitShiftRight, bufferOffsetUint, bufferOffsetShift);
131     TIntermBinary *bufferOffset =
132         new TIntermBinary(EOpBitwiseAnd, bufferOffsetShifted, CreateUIntConstant(0xFF));
133 
134     // return atomicCounters[bindingOffset.binding].counters[bindingOffset.offset + bufferOffset]
135     offset = new TIntermBinary(EOpAdd, offset, bufferOffset);
136     return new TIntermBinary(EOpIndexDirect, counters, offset);
137 }
138 
139 // Traverser that:
140 //
141 // 1. Converts the |atomic_uint| types to |{uint,uint}| for binding and offset.
142 // 2. Substitutes the |uniform atomic_uint| declarations with a global declaration that holds the
143 //    binding and offset.
144 // 3. Substitutes |atomicVar[n]| with |buffer[binding].counters[offset + n]|.
145 class RewriteAtomicCountersTraverser : public TIntermTraverser
146 {
147   public:
RewriteAtomicCountersTraverser(TSymbolTable * symbolTable,const TVariable * atomicCounters,const TIntermTyped * acbBufferOffsets)148     RewriteAtomicCountersTraverser(TSymbolTable *symbolTable,
149                                    const TVariable *atomicCounters,
150                                    const TIntermTyped *acbBufferOffsets)
151         : TIntermTraverser(true, true, true, symbolTable),
152           mAtomicCounters(atomicCounters),
153           mAcbBufferOffsets(acbBufferOffsets),
154           mAtomicCounterType(nullptr),
155           mAtomicCounterTypeConst(nullptr),
156           mAtomicCounterTypeDeclaration(nullptr)
157     {}
158 
visitDeclaration(Visit visit,TIntermDeclaration * node)159     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
160     {
161         if (visit != PreVisit)
162         {
163             return true;
164         }
165 
166         const TIntermSequence &sequence = *(node->getSequence());
167 
168         TIntermTyped *variable = sequence.front()->getAsTyped();
169         const TType &type      = variable->getType();
170         bool isAtomicCounter   = type.getQualifier() == EvqUniform && type.isAtomicCounter();
171 
172         if (isAtomicCounter)
173         {
174             // Atomic counters cannot have initializers, so the declaration must necessarily be a
175             // symbol.
176             TIntermSymbol *samplerVariable = variable->getAsSymbolNode();
177             ASSERT(samplerVariable != nullptr);
178 
179             declareAtomicCounter(&samplerVariable->variable(), node);
180             return false;
181         }
182 
183         return true;
184     }
185 
visitFunctionPrototype(TIntermFunctionPrototype * node)186     void visitFunctionPrototype(TIntermFunctionPrototype *node) override
187     {
188         const TFunction *function = node->getFunction();
189         // Go over the parameters and replace the atomic arguments with a uint type.
190         mRetyper.visitFunctionPrototype();
191         for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
192         {
193             const TVariable *param = function->getParam(paramIndex);
194             TVariable *replacement = convertFunctionParameter(node, param);
195             if (replacement)
196             {
197                 mRetyper.replaceFunctionParam(param, replacement);
198             }
199         }
200 
201         TIntermFunctionPrototype *replacementPrototype =
202             mRetyper.convertFunctionPrototype(mSymbolTable, function);
203         if (replacementPrototype)
204         {
205             queueReplacement(replacementPrototype, OriginalNode::IS_DROPPED);
206         }
207     }
208 
visitAggregate(Visit visit,TIntermAggregate * node)209     bool visitAggregate(Visit visit, TIntermAggregate *node) override
210     {
211         if (visit == PreVisit)
212         {
213             mRetyper.preVisitAggregate();
214         }
215 
216         if (visit != PostVisit)
217         {
218             return true;
219         }
220 
221         if (node->getOp() == EOpCallBuiltInFunction)
222         {
223             convertBuiltinFunction(node);
224         }
225         else if (node->getOp() == EOpCallFunctionInAST)
226         {
227             TIntermAggregate *substituteCall = mRetyper.convertASTFunction(node);
228             if (substituteCall)
229             {
230                 queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
231             }
232         }
233         mRetyper.postVisitAggregate();
234 
235         return true;
236     }
237 
visitSymbol(TIntermSymbol * symbol)238     void visitSymbol(TIntermSymbol *symbol) override
239     {
240         const TVariable *symbolVariable = &symbol->variable();
241 
242         if (!symbol->getType().isAtomicCounter())
243         {
244             return;
245         }
246 
247         // The symbol is either referencing a global atomic counter, or is a function parameter.  In
248         // either case, it could be an array.  The are the following possibilities:
249         //
250         //     layout(..) uniform atomic_uint ac;
251         //     layout(..) uniform atomic_uint acArray[N];
252         //
253         //     void func(inout atomic_uint c)
254         //     {
255         //         otherFunc(c);
256         //     }
257         //
258         //     void funcArray(inout atomic_uint cArray[N])
259         //     {
260         //         otherFuncArray(cArray);
261         //         otherFunc(cArray[n]);
262         //     }
263         //
264         //     void funcGlobal()
265         //     {
266         //         func(ac);
267         //         func(acArray[n]);
268         //         funcArray(acArray);
269         //         atomicIncrement(ac);
270         //         atomicIncrement(acArray[n]);
271         //     }
272         //
273         // This should translate to:
274         //
275         //     buffer ANGLEAtomicCounters
276         //     {
277         //         uint counters[];
278         //     } atomicCounters;
279         //
280         //     struct ANGLEAtomicCounter
281         //     {
282         //         uint binding;
283         //         uint offset;
284         //     };
285         //     const ANGLEAtomicCounter ac = {<binding>, <offset>};
286         //     const ANGLEAtomicCounter acArray = {<binding>, <offset>};
287         //
288         //     void func(inout ANGLEAtomicCounter c)
289         //     {
290         //         otherFunc(c);
291         //     }
292         //
293         //     void funcArray(inout uint cArray)
294         //     {
295         //         otherFuncArray(cArray);
296         //         otherFunc({cArray.binding, cArray.offset + n});
297         //     }
298         //
299         //     void funcGlobal()
300         //     {
301         //         func(ac);
302         //         func(acArray+n);
303         //         funcArray(acArray);
304         //         atomicAdd(atomicCounters[ac.binding]counters[ac.offset]);
305         //         atomicAdd(atomicCounters[ac.binding]counters[ac.offset+n]);
306         //     }
307         //
308         // In all cases, the argument transformation is stored in mRetyper.  In the function call's
309         // PostVisit, if it's a builtin, the look up in |atomicCounters.counters| is done as well as
310         // the builtin function change.  Otherwise, the transformed argument is passed on as is.
311         //
312 
313         TIntermTyped *bindingOffset =
314             new TIntermSymbol(mRetyper.getVariableReplacement(symbolVariable));
315         ASSERT(bindingOffset != nullptr);
316 
317         TIntermNode *argument = convertFunctionArgument(symbol, &bindingOffset);
318 
319         if (mRetyper.isInAggregate())
320         {
321             mRetyper.replaceFunctionCallArg(argument, bindingOffset);
322         }
323         else
324         {
325             // If there's a stray ac[i] lying around, just delete it.  This can happen if the shader
326             // uses ac[i].length(), which in RemoveArrayLengthMethod() will result in an ineffective
327             // statement that's just ac[i]; (similarly for a stray ac;, it doesn't have to be
328             // subscripted).  Note that the subscript could have side effects, but the
329             // convertFunctionArgument above has already generated code that includes the subscript
330             // (and therefore its side-effect).
331             TIntermBlock *block = nullptr;
332             for (uint32_t ancestorIndex = 0; block == nullptr; ++ancestorIndex)
333             {
334                 block = getAncestorNode(ancestorIndex)->getAsBlock();
335             }
336 
337             TIntermSequence emptySequence;
338             mMultiReplacements.emplace_back(block, argument, emptySequence);
339         }
340     }
341 
getAtomicCounterTypeDeclaration()342     TIntermDeclaration *getAtomicCounterTypeDeclaration() { return mAtomicCounterTypeDeclaration; }
343 
344   private:
declareAtomicCounter(const TVariable * atomicCounterVar,TIntermDeclaration * node)345     void declareAtomicCounter(const TVariable *atomicCounterVar, TIntermDeclaration *node)
346     {
347         // Create a global variable that contains the binding and offset of this atomic counter
348         // declaration.
349         if (mAtomicCounterType == nullptr)
350         {
351             declareAtomicCounterType();
352         }
353         ASSERT(mAtomicCounterTypeConst);
354 
355         TVariable *bindingOffset = new TVariable(mSymbolTable, atomicCounterVar->name(),
356                                                  mAtomicCounterTypeConst, SymbolType::UserDefined);
357 
358         const TType &atomicCounterType = atomicCounterVar->getType();
359         uint32_t offset                = atomicCounterType.getLayoutQualifier().offset;
360         uint32_t binding               = atomicCounterType.getLayoutQualifier().binding;
361 
362         ASSERT(offset % 4 == 0);
363         TIntermTyped *bindingOffsetInitValue =
364             CreateAtomicCounterConstant(mAtomicCounterTypeConst, binding, offset / 4);
365 
366         TIntermSymbol *bindingOffsetSymbol = new TIntermSymbol(bindingOffset);
367         TIntermBinary *bindingOffsetInit =
368             new TIntermBinary(EOpInitialize, bindingOffsetSymbol, bindingOffsetInitValue);
369 
370         TIntermDeclaration *bindingOffsetDeclaration = new TIntermDeclaration();
371         bindingOffsetDeclaration->appendDeclarator(bindingOffsetInit);
372 
373         // Replace the atomic_uint declaration with the binding/offset declaration.
374         TIntermSequence replacement;
375         replacement.push_back(bindingOffsetDeclaration);
376         mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node, replacement);
377 
378         // Remember the binding/offset variable.
379         mRetyper.replaceGlobalVariable(atomicCounterVar, bindingOffset);
380     }
381 
declareAtomicCounterType()382     void declareAtomicCounterType()
383     {
384         ASSERT(mAtomicCounterType == nullptr);
385 
386         TFieldList *fields = new TFieldList();
387         fields->push_back(new TField(new TType(EbtUInt, EbpUndefined, EvqGlobal, 1, 1),
388                                      ImmutableString("binding"), TSourceLoc(),
389                                      SymbolType::AngleInternal));
390         fields->push_back(new TField(new TType(EbtUInt, EbpUndefined, EvqGlobal, 1, 1),
391                                      ImmutableString("arrayIndex"), TSourceLoc(),
392                                      SymbolType::AngleInternal));
393         TStructure *atomicCounterTypeStruct =
394             new TStructure(mSymbolTable, kAtomicCounterTypeName, fields, SymbolType::AngleInternal);
395         mAtomicCounterType = new TType(atomicCounterTypeStruct, false);
396 
397         mAtomicCounterTypeDeclaration = new TIntermDeclaration;
398         TVariable *emptyVariable      = new TVariable(mSymbolTable, kEmptyImmutableString,
399                                                  mAtomicCounterType, SymbolType::Empty);
400         mAtomicCounterTypeDeclaration->appendDeclarator(new TIntermSymbol(emptyVariable));
401 
402         // Keep a const variant around as well.
403         mAtomicCounterTypeConst = new TType(*mAtomicCounterType);
404         mAtomicCounterTypeConst->setQualifier(EvqConst);
405     }
406 
convertFunctionParameter(TIntermNode * parent,const TVariable * param)407     TVariable *convertFunctionParameter(TIntermNode *parent, const TVariable *param)
408     {
409         if (!param->getType().isAtomicCounter())
410         {
411             return nullptr;
412         }
413         if (mAtomicCounterType == nullptr)
414         {
415             declareAtomicCounterType();
416         }
417 
418         const TType *paramType = &param->getType();
419         TType *newType =
420             paramType->getQualifier() == EvqConst ? mAtomicCounterTypeConst : mAtomicCounterType;
421 
422         TVariable *replacementVar =
423             new TVariable(mSymbolTable, param->name(), newType, SymbolType::UserDefined);
424 
425         return replacementVar;
426     }
427 
convertFunctionArgumentHelper(const TVector<unsigned int> & runningArraySizeProducts,TIntermTyped * flattenedSubscript,uint32_t depth,uint32_t * subscriptCountOut)428     TIntermTyped *convertFunctionArgumentHelper(
429         const TVector<unsigned int> &runningArraySizeProducts,
430         TIntermTyped *flattenedSubscript,
431         uint32_t depth,
432         uint32_t *subscriptCountOut)
433     {
434         std::string prefix(depth, ' ');
435         TIntermNode *parent = getAncestorNode(depth);
436         ASSERT(parent);
437 
438         TIntermBinary *arrayExpression = parent->getAsBinaryNode();
439         if (!arrayExpression)
440         {
441             // If the parent is not an array subscript operation, we have reached the end of the
442             // subscript chain.  Note the depth that's traversed so the corresponding node can be
443             // taken as the function argument.
444             *subscriptCountOut = depth;
445             return flattenedSubscript;
446         }
447 
448         ASSERT(arrayExpression->getOp() == EOpIndexDirect ||
449                arrayExpression->getOp() == EOpIndexIndirect);
450 
451         // Assume i = n - depth.  Get Pi.  See comment in convertFunctionArgument.
452         ASSERT(depth < runningArraySizeProducts.size());
453         uint32_t thisDimensionSize =
454             runningArraySizeProducts[runningArraySizeProducts.size() - 1 - depth];
455 
456         // Get Ii.
457         TIntermTyped *thisDimensionOffset = arrayExpression->getRight();
458 
459         TIntermConstantUnion *subscriptAsConstant = thisDimensionOffset->getAsConstantUnion();
460         const bool subscriptIsZero = subscriptAsConstant && subscriptAsConstant->isZero(0);
461 
462         // If Ii is zero, don't need to add Ii*Pi; that's zero.
463         if (!subscriptIsZero)
464         {
465             thisDimensionOffset = thisDimensionOffset->deepCopy();
466 
467             // If Pi is 1, don't multiply.  Just accumulate Ii.
468             if (thisDimensionSize != 1)
469             {
470                 thisDimensionOffset = new TIntermBinary(EOpMul, thisDimensionOffset,
471                                                         CreateUIntConstant(thisDimensionSize));
472             }
473 
474             // Accumulate with the previous running offset, if any.
475             if (flattenedSubscript)
476             {
477                 flattenedSubscript =
478                     new TIntermBinary(EOpAdd, flattenedSubscript, thisDimensionOffset);
479             }
480             else
481             {
482                 flattenedSubscript = thisDimensionOffset;
483             }
484         }
485 
486         // Note: GLSL only allows 2 nested levels of arrays, so this recursion is bounded.
487         return convertFunctionArgumentHelper(runningArraySizeProducts, flattenedSubscript,
488                                              depth + 1, subscriptCountOut);
489     }
490 
convertFunctionArgument(TIntermNode * symbol,TIntermTyped ** bindingOffset)491     TIntermNode *convertFunctionArgument(TIntermNode *symbol, TIntermTyped **bindingOffset)
492     {
493         // Assume a general case of array declaration with N dimensions:
494         //
495         //     atomic_uint ac[Dn]..[D2][D1];
496         //
497         // Let's define
498         //
499         //     Pn = D(n-1)*...*D2*D1
500         //
501         // In that case, we have:
502         //
503         //     ac[In]         = ac + In*Pn
504         //     ac[In][I(n-1)] = ac + In*Pn + I(n-1)*P(n-1)
505         //     ac[In]...[Ii]  = ac + In*Pn + ... + Ii*Pi
506         //
507         // We have just visited a symbol; ac.  Walking the parent chain, we will visit the
508         // expressions in the above order (ac, ac[In], ac[In][I(n-1)], ...).  We therefore can
509         // simply walk the parent chain and accumulate Ii*Pi to obtain the offset from the base of
510         // ac.
511 
512         TIntermSymbol *argumentAsSymbol = symbol->getAsSymbolNode();
513         ASSERT(argumentAsSymbol);
514 
515         const TSpan<const unsigned int> &arraySizes = argumentAsSymbol->getType().getArraySizes();
516 
517         // Calculate Pi
518         TVector<unsigned int> runningArraySizeProducts;
519         if (!arraySizes.empty())
520         {
521             runningArraySizeProducts.resize(arraySizes.size());
522             uint32_t runningProduct = 1;
523             for (size_t dimension = 0; dimension < arraySizes.size(); ++dimension)
524             {
525                 runningArraySizeProducts[dimension] = runningProduct;
526                 runningProduct *= arraySizes[dimension];
527             }
528         }
529 
530         // Walk the parent chain and accumulate Ii*Pi
531         uint32_t subscriptCount = 0;
532         TIntermTyped *flattenedSubscript =
533             convertFunctionArgumentHelper(runningArraySizeProducts, nullptr, 0, &subscriptCount);
534 
535         // Find the function argument, which is either in the form of ac (i.e. there are no
536         // subscripts, in which case that's the function argument), or ac[In]...[Ii] (in which case
537         // the function argument is the (n-i)th ancestor of ac.
538         //
539         // Note that this is the case because no other operation is allowed on ac other than
540         // subscript.
541         TIntermNode *argument = subscriptCount == 0 ? symbol : getAncestorNode(subscriptCount - 1);
542         ASSERT(argument != nullptr);
543 
544         // If not subscripted, keep the argument as-is.
545         if (flattenedSubscript == nullptr)
546         {
547             return argument;
548         }
549 
550         // Copy the atomic counter binding/offset constant and modify it by adding the array
551         // subscript to its offset field.
552         TVariable *modified              = CreateTempVariable(mSymbolTable, mAtomicCounterType);
553         TIntermDeclaration *modifiedDecl = CreateTempInitDeclarationNode(modified, *bindingOffset);
554 
555         TIntermSymbol *modifiedSymbol    = new TIntermSymbol(modified);
556         TConstantUnion *offsetFieldIndex = new TConstantUnion;
557         offsetFieldIndex->setIConst(1);
558         TIntermConstantUnion *offsetFieldRef =
559             new TIntermConstantUnion(offsetFieldIndex, *StaticType::GetBasic<EbtUInt>());
560         TIntermBinary *offsetField =
561             new TIntermBinary(EOpIndexDirectStruct, modifiedSymbol, offsetFieldRef);
562 
563         TIntermBinary *modifiedOffset =
564             new TIntermBinary(EOpAddAssign, offsetField, flattenedSubscript);
565 
566         TIntermSequence *modifySequence = new TIntermSequence({modifiedDecl, modifiedOffset});
567         insertStatementsInParentBlock(*modifySequence);
568 
569         *bindingOffset = modifiedSymbol->deepCopy();
570 
571         return argument;
572     }
573 
convertBuiltinFunction(TIntermAggregate * node)574     void convertBuiltinFunction(TIntermAggregate *node)
575     {
576         // If the function is |memoryBarrierAtomicCounter|, simply replace it with
577         // |memoryBarrierBuffer|.
578         if (node->getFunction()->name() == "memoryBarrierAtomicCounter")
579         {
580             TIntermTyped *substituteCall = CreateBuiltInFunctionCallNode(
581                 "memoryBarrierBuffer", new TIntermSequence, *mSymbolTable, 310);
582             queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
583             return;
584         }
585 
586         // If it's an |atomicCounter*| function, replace the function with an |atomic*| equivalent.
587         if (!node->getFunction()->isAtomicCounterFunction())
588         {
589             return;
590         }
591 
592         const ImmutableString &functionName = node->getFunction()->name();
593         TIntermSequence *arguments          = node->getSequence();
594 
595         // Note: atomicAdd(0) is used for atomic reads.
596         uint32_t valueChange                = 0;
597         constexpr char kAtomicAddFunction[] = "atomicAdd";
598         bool isDecrement                    = false;
599 
600         if (functionName == "atomicCounterIncrement")
601         {
602             valueChange = 1;
603         }
604         else if (functionName == "atomicCounterDecrement")
605         {
606             // uint values are required to wrap around, so 0xFFFFFFFFu is used as -1.
607             valueChange = std::numeric_limits<uint32_t>::max();
608             static_assert(static_cast<uint32_t>(-1) == std::numeric_limits<uint32_t>::max(),
609                           "uint32_t max is not -1");
610 
611             isDecrement = true;
612         }
613         else
614         {
615             ASSERT(functionName == "atomicCounter");
616         }
617 
618         const TIntermNode *param = (*arguments)[0];
619 
620         TIntermTyped *bindingOffset = mRetyper.getFunctionCallArgReplacement(param);
621 
622         TIntermSequence *substituteArguments = new TIntermSequence;
623         substituteArguments->push_back(
624             CreateAtomicCounterRef(mAtomicCounters, bindingOffset, mAcbBufferOffsets));
625         substituteArguments->push_back(CreateUIntConstant(valueChange));
626 
627         TIntermTyped *substituteCall = CreateBuiltInFunctionCallNode(
628             kAtomicAddFunction, substituteArguments, *mSymbolTable, 310);
629 
630         // Note that atomicCounterDecrement returns the *new* value instead of the prior value,
631         // unlike atomicAdd.  So we need to do a -1 on the result as well.
632         if (isDecrement)
633         {
634             substituteCall = new TIntermBinary(EOpSub, substituteCall, CreateUIntConstant(1));
635         }
636 
637         queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
638     }
639 
640     const TVariable *mAtomicCounters;
641     const TIntermTyped *mAcbBufferOffsets;
642 
643     RetypeOpaqueVariablesHelper mRetyper;
644 
645     TType *mAtomicCounterType;
646     TType *mAtomicCounterTypeConst;
647 
648     // Stored to be put at the top of the shader after the pass.
649     TIntermDeclaration *mAtomicCounterTypeDeclaration;
650 };
651 
652 }  // anonymous namespace
653 
RewriteAtomicCounters(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const TIntermTyped * acbBufferOffsets)654 bool RewriteAtomicCounters(TCompiler *compiler,
655                            TIntermBlock *root,
656                            TSymbolTable *symbolTable,
657                            const TIntermTyped *acbBufferOffsets)
658 {
659     const TVariable *atomicCounters = DeclareAtomicCountersBuffers(root, symbolTable);
660 
661     RewriteAtomicCountersTraverser traverser(symbolTable, atomicCounters, acbBufferOffsets);
662     root->traverse(&traverser);
663     if (!traverser.updateTree(compiler, root))
664     {
665         return false;
666     }
667 
668     TIntermDeclaration *atomicCounterTypeDeclaration = traverser.getAtomicCounterTypeDeclaration();
669     if (atomicCounterTypeDeclaration)
670     {
671         root->getSequence()->insert(root->getSequence()->begin(), atomicCounterTypeDeclaration);
672     }
673 
674     return compiler->validateAST(root);
675 }
676 }  // namespace sh
677