1 //
2 // Copyright 2019 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RewriteAtomicCounters: Emulate atomic counter buffers with storage buffers.
7 //
8
9 #include "compiler/translator/tree_ops/RewriteAtomicCounters.h"
10
11 #include "compiler/translator/Compiler.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 #include "compiler/translator/tree_util/ReplaceVariable.h"
18
19 namespace sh
20 {
21 namespace
22 {
23 constexpr ImmutableString kAtomicCounterTypeName = ImmutableString("ANGLE_atomic_uint");
24 constexpr ImmutableString kAtomicCountersVarName = ImmutableString("atomicCounters");
25 constexpr ImmutableString kAtomicCounterFieldName = ImmutableString("counters");
26
27 // DeclareAtomicCountersBuffer adds a storage buffer array that's used with atomic counters.
DeclareAtomicCountersBuffers(TIntermBlock * root,TSymbolTable * symbolTable)28 const TVariable *DeclareAtomicCountersBuffers(TIntermBlock *root, TSymbolTable *symbolTable)
29 {
30 // Define `uint counters[];` as the only field in the interface block.
31 TFieldList *fieldList = new TFieldList;
32 TType *counterType = new TType(EbtUInt);
33 counterType->makeArray(0);
34
35 TField *countersField =
36 new TField(counterType, kAtomicCounterFieldName, TSourceLoc(), SymbolType::AngleInternal);
37
38 fieldList->push_back(countersField);
39
40 TMemoryQualifier coherentMemory = TMemoryQualifier::Create();
41 coherentMemory.coherent = true;
42
43 // There are a maximum of 8 atomic counter buffers per IMPLEMENTATION_MAX_ATOMIC_COUNTER_BUFFERS
44 // in libANGLE/Constants.h.
45 constexpr uint32_t kMaxAtomicCounterBuffers = 8;
46
47 // Define a storage block "ANGLEAtomicCounters" with instance name "atomicCounters".
48 return DeclareInterfaceBlock(
49 root, symbolTable, fieldList, EvqBuffer, coherentMemory, kMaxAtomicCounterBuffers,
50 ImmutableString(vk::kAtomicCountersBlockName), kAtomicCountersVarName);
51 }
52
CreateUIntConstant(uint32_t value)53 TIntermConstantUnion *CreateUIntConstant(uint32_t value)
54 {
55 TType *constantType = new TType(*StaticType::GetBasic<EbtUInt, 1>());
56 constantType->setQualifier(EvqConst);
57
58 TConstantUnion *constantValue = new TConstantUnion;
59 constantValue->setUConst(value);
60 return new TIntermConstantUnion(constantValue, *constantType);
61 }
62
CreateAtomicCounterConstant(TType * atomicCounterType,uint32_t binding,uint32_t offset)63 TIntermTyped *CreateAtomicCounterConstant(TType *atomicCounterType,
64 uint32_t binding,
65 uint32_t offset)
66 {
67 ASSERT(atomicCounterType->getBasicType() == EbtStruct);
68
69 TIntermSequence *arguments = new TIntermSequence();
70 arguments->push_back(CreateUIntConstant(binding));
71 arguments->push_back(CreateUIntConstant(offset));
72
73 return TIntermAggregate::CreateConstructor(*atomicCounterType, arguments);
74 }
75
CreateAtomicCounterRef(const TVariable * atomicCounters,const TIntermTyped * bindingOffset,const TIntermTyped * bufferOffsets)76 TIntermBinary *CreateAtomicCounterRef(const TVariable *atomicCounters,
77 const TIntermTyped *bindingOffset,
78 const TIntermTyped *bufferOffsets)
79 {
80 // The atomic counters storage buffer declaration looks as such:
81 //
82 // layout(...) buffer ANGLEAtomicCounters
83 // {
84 // uint counters[];
85 // } atomicCounters[N];
86 //
87 // Where N is large enough to accommodate atomic counter buffer bindings used in the shader.
88 //
89 // Given an ANGLEAtomicCounter variable (which is a struct of {binding, offset}), we need to
90 // return:
91 //
92 // atomicCounters[binding].counters[offset]
93 //
94 // The offset itself is the provided one plus an offset given through uniforms.
95
96 TIntermSymbol *atomicCountersRef = new TIntermSymbol(atomicCounters);
97
98 TIntermConstantUnion *bindingFieldRef = CreateIndexNode(0);
99 TIntermConstantUnion *offsetFieldRef = CreateIndexNode(1);
100 TIntermConstantUnion *countersFieldRef = CreateIndexNode(0);
101
102 // Create references to bindingOffset.binding and bindingOffset.offset.
103 TIntermBinary *binding =
104 new TIntermBinary(EOpIndexDirectStruct, bindingOffset->deepCopy(), bindingFieldRef);
105 TIntermBinary *offset =
106 new TIntermBinary(EOpIndexDirectStruct, bindingOffset->deepCopy(), offsetFieldRef);
107
108 // Create reference to atomicCounters[bindingOffset.binding]
109 TIntermBinary *countersBlock = new TIntermBinary(EOpIndexDirect, atomicCountersRef, binding);
110
111 // Create reference to atomicCounters[bindingOffset.binding].counters
112 TIntermBinary *counters =
113 new TIntermBinary(EOpIndexDirectInterfaceBlock, countersBlock, countersFieldRef);
114
115 // Create bufferOffsets[binding / 4]. Each uint in bufferOffsets contains offsets for 4
116 // bindings.
117 TIntermBinary *bindingDivFour =
118 new TIntermBinary(EOpDiv, binding->deepCopy(), CreateUIntConstant(4));
119 TIntermBinary *bufferOffsetUint =
120 new TIntermBinary(EOpIndexDirect, bufferOffsets->deepCopy(), bindingDivFour);
121
122 // Create (binding % 4) * 8
123 TIntermBinary *bindingModFour =
124 new TIntermBinary(EOpIMod, binding->deepCopy(), CreateUIntConstant(4));
125 TIntermBinary *bufferOffsetShift =
126 new TIntermBinary(EOpMul, bindingModFour, CreateUIntConstant(8));
127
128 // Create bufferOffsets[binding / 4] >> ((binding % 4) * 8) & 0xFF
129 TIntermBinary *bufferOffsetShifted =
130 new TIntermBinary(EOpBitShiftRight, bufferOffsetUint, bufferOffsetShift);
131 TIntermBinary *bufferOffset =
132 new TIntermBinary(EOpBitwiseAnd, bufferOffsetShifted, CreateUIntConstant(0xFF));
133
134 // return atomicCounters[bindingOffset.binding].counters[bindingOffset.offset + bufferOffset]
135 offset = new TIntermBinary(EOpAdd, offset, bufferOffset);
136 return new TIntermBinary(EOpIndexDirect, counters, offset);
137 }
138
139 // Traverser that:
140 //
141 // 1. Converts the |atomic_uint| types to |{uint,uint}| for binding and offset.
142 // 2. Substitutes the |uniform atomic_uint| declarations with a global declaration that holds the
143 // binding and offset.
144 // 3. Substitutes |atomicVar[n]| with |buffer[binding].counters[offset + n]|.
145 class RewriteAtomicCountersTraverser : public TIntermTraverser
146 {
147 public:
// |atomicCounters| is the storage block variable the counters are emulated with, and
// |acbBufferOffsets| is the expression that the per-binding buffer offsets are read from.  The
// replacement struct type and its declaration are created lazily, so they start out null.
//
// NOTE(review): the three |true| flags are presumably the pre/in/post visit switches of
// TIntermTraverser -- confirm against its declaration.
RewriteAtomicCountersTraverser(TSymbolTable *symbolTable,
                               const TVariable *atomicCounters,
                               const TIntermTyped *acbBufferOffsets)
    : TIntermTraverser(true, true, true, symbolTable),
      mAtomicCounters(atomicCounters),
      mAcbBufferOffsets(acbBufferOffsets),
      mAtomicCounterType(nullptr),
      mAtomicCounterTypeConst(nullptr),
      mAtomicCounterTypeDeclaration(nullptr)
{}
158
visitDeclaration(Visit visit,TIntermDeclaration * node)159 bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
160 {
161 if (visit != PreVisit)
162 {
163 return true;
164 }
165
166 const TIntermSequence &sequence = *(node->getSequence());
167
168 TIntermTyped *variable = sequence.front()->getAsTyped();
169 const TType &type = variable->getType();
170 bool isAtomicCounter = type.getQualifier() == EvqUniform && type.isAtomicCounter();
171
172 if (isAtomicCounter)
173 {
174 // Atomic counters cannot have initializers, so the declaration must necessarily be a
175 // symbol.
176 TIntermSymbol *samplerVariable = variable->getAsSymbolNode();
177 ASSERT(samplerVariable != nullptr);
178
179 declareAtomicCounter(&samplerVariable->variable(), node);
180 return false;
181 }
182
183 return true;
184 }
185
visitFunctionPrototype(TIntermFunctionPrototype * node)186 void visitFunctionPrototype(TIntermFunctionPrototype *node) override
187 {
188 const TFunction *function = node->getFunction();
189 // Go over the parameters and replace the atomic arguments with a uint type.
190 mRetyper.visitFunctionPrototype();
191 for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
192 {
193 const TVariable *param = function->getParam(paramIndex);
194 TVariable *replacement = convertFunctionParameter(node, param);
195 if (replacement)
196 {
197 mRetyper.replaceFunctionParam(param, replacement);
198 }
199 }
200
201 TIntermFunctionPrototype *replacementPrototype =
202 mRetyper.convertFunctionPrototype(mSymbolTable, function);
203 if (replacementPrototype)
204 {
205 queueReplacement(replacementPrototype, OriginalNode::IS_DROPPED);
206 }
207 }
208
visitAggregate(Visit visit,TIntermAggregate * node)209 bool visitAggregate(Visit visit, TIntermAggregate *node) override
210 {
211 if (visit == PreVisit)
212 {
213 mRetyper.preVisitAggregate();
214 }
215
216 if (visit != PostVisit)
217 {
218 return true;
219 }
220
221 if (node->getOp() == EOpCallBuiltInFunction)
222 {
223 convertBuiltinFunction(node);
224 }
225 else if (node->getOp() == EOpCallFunctionInAST)
226 {
227 TIntermAggregate *substituteCall = mRetyper.convertASTFunction(node);
228 if (substituteCall)
229 {
230 queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
231 }
232 }
233 mRetyper.postVisitAggregate();
234
235 return true;
236 }
237
visitSymbol(TIntermSymbol * symbol)238 void visitSymbol(TIntermSymbol *symbol) override
239 {
240 const TVariable *symbolVariable = &symbol->variable();
241
242 if (!symbol->getType().isAtomicCounter())
243 {
244 return;
245 }
246
247 // The symbol is either referencing a global atomic counter, or is a function parameter. In
248 // either case, it could be an array. The are the following possibilities:
249 //
250 // layout(..) uniform atomic_uint ac;
251 // layout(..) uniform atomic_uint acArray[N];
252 //
253 // void func(inout atomic_uint c)
254 // {
255 // otherFunc(c);
256 // }
257 //
258 // void funcArray(inout atomic_uint cArray[N])
259 // {
260 // otherFuncArray(cArray);
261 // otherFunc(cArray[n]);
262 // }
263 //
264 // void funcGlobal()
265 // {
266 // func(ac);
267 // func(acArray[n]);
268 // funcArray(acArray);
269 // atomicIncrement(ac);
270 // atomicIncrement(acArray[n]);
271 // }
272 //
273 // This should translate to:
274 //
275 // buffer ANGLEAtomicCounters
276 // {
277 // uint counters[];
278 // } atomicCounters;
279 //
280 // struct ANGLEAtomicCounter
281 // {
282 // uint binding;
283 // uint offset;
284 // };
285 // const ANGLEAtomicCounter ac = {<binding>, <offset>};
286 // const ANGLEAtomicCounter acArray = {<binding>, <offset>};
287 //
288 // void func(inout ANGLEAtomicCounter c)
289 // {
290 // otherFunc(c);
291 // }
292 //
293 // void funcArray(inout uint cArray)
294 // {
295 // otherFuncArray(cArray);
296 // otherFunc({cArray.binding, cArray.offset + n});
297 // }
298 //
299 // void funcGlobal()
300 // {
301 // func(ac);
302 // func(acArray+n);
303 // funcArray(acArray);
304 // atomicAdd(atomicCounters[ac.binding]counters[ac.offset]);
305 // atomicAdd(atomicCounters[ac.binding]counters[ac.offset+n]);
306 // }
307 //
308 // In all cases, the argument transformation is stored in mRetyper. In the function call's
309 // PostVisit, if it's a builtin, the look up in |atomicCounters.counters| is done as well as
310 // the builtin function change. Otherwise, the transformed argument is passed on as is.
311 //
312
313 TIntermTyped *bindingOffset =
314 new TIntermSymbol(mRetyper.getVariableReplacement(symbolVariable));
315 ASSERT(bindingOffset != nullptr);
316
317 TIntermNode *argument = convertFunctionArgument(symbol, &bindingOffset);
318
319 if (mRetyper.isInAggregate())
320 {
321 mRetyper.replaceFunctionCallArg(argument, bindingOffset);
322 }
323 else
324 {
325 // If there's a stray ac[i] lying around, just delete it. This can happen if the shader
326 // uses ac[i].length(), which in RemoveArrayLengthMethod() will result in an ineffective
327 // statement that's just ac[i]; (similarly for a stray ac;, it doesn't have to be
328 // subscripted). Note that the subscript could have side effects, but the
329 // convertFunctionArgument above has already generated code that includes the subscript
330 // (and therefore its side-effect).
331 TIntermBlock *block = nullptr;
332 for (uint32_t ancestorIndex = 0; block == nullptr; ++ancestorIndex)
333 {
334 block = getAncestorNode(ancestorIndex)->getAsBlock();
335 }
336
337 TIntermSequence emptySequence;
338 mMultiReplacements.emplace_back(block, argument, emptySequence);
339 }
340 }
341
getAtomicCounterTypeDeclaration()342 TIntermDeclaration *getAtomicCounterTypeDeclaration() { return mAtomicCounterTypeDeclaration; }
343
344 private:
declareAtomicCounter(const TVariable * atomicCounterVar,TIntermDeclaration * node)345 void declareAtomicCounter(const TVariable *atomicCounterVar, TIntermDeclaration *node)
346 {
347 // Create a global variable that contains the binding and offset of this atomic counter
348 // declaration.
349 if (mAtomicCounterType == nullptr)
350 {
351 declareAtomicCounterType();
352 }
353 ASSERT(mAtomicCounterTypeConst);
354
355 TVariable *bindingOffset = new TVariable(mSymbolTable, atomicCounterVar->name(),
356 mAtomicCounterTypeConst, SymbolType::UserDefined);
357
358 const TType &atomicCounterType = atomicCounterVar->getType();
359 uint32_t offset = atomicCounterType.getLayoutQualifier().offset;
360 uint32_t binding = atomicCounterType.getLayoutQualifier().binding;
361
362 ASSERT(offset % 4 == 0);
363 TIntermTyped *bindingOffsetInitValue =
364 CreateAtomicCounterConstant(mAtomicCounterTypeConst, binding, offset / 4);
365
366 TIntermSymbol *bindingOffsetSymbol = new TIntermSymbol(bindingOffset);
367 TIntermBinary *bindingOffsetInit =
368 new TIntermBinary(EOpInitialize, bindingOffsetSymbol, bindingOffsetInitValue);
369
370 TIntermDeclaration *bindingOffsetDeclaration = new TIntermDeclaration();
371 bindingOffsetDeclaration->appendDeclarator(bindingOffsetInit);
372
373 // Replace the atomic_uint declaration with the binding/offset declaration.
374 TIntermSequence replacement;
375 replacement.push_back(bindingOffsetDeclaration);
376 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node, replacement);
377
378 // Remember the binding/offset variable.
379 mRetyper.replaceGlobalVariable(atomicCounterVar, bindingOffset);
380 }
381
declareAtomicCounterType()382 void declareAtomicCounterType()
383 {
384 ASSERT(mAtomicCounterType == nullptr);
385
386 TFieldList *fields = new TFieldList();
387 fields->push_back(new TField(new TType(EbtUInt, EbpUndefined, EvqGlobal, 1, 1),
388 ImmutableString("binding"), TSourceLoc(),
389 SymbolType::AngleInternal));
390 fields->push_back(new TField(new TType(EbtUInt, EbpUndefined, EvqGlobal, 1, 1),
391 ImmutableString("arrayIndex"), TSourceLoc(),
392 SymbolType::AngleInternal));
393 TStructure *atomicCounterTypeStruct =
394 new TStructure(mSymbolTable, kAtomicCounterTypeName, fields, SymbolType::AngleInternal);
395 mAtomicCounterType = new TType(atomicCounterTypeStruct, false);
396
397 mAtomicCounterTypeDeclaration = new TIntermDeclaration;
398 TVariable *emptyVariable = new TVariable(mSymbolTable, kEmptyImmutableString,
399 mAtomicCounterType, SymbolType::Empty);
400 mAtomicCounterTypeDeclaration->appendDeclarator(new TIntermSymbol(emptyVariable));
401
402 // Keep a const variant around as well.
403 mAtomicCounterTypeConst = new TType(*mAtomicCounterType);
404 mAtomicCounterTypeConst->setQualifier(EvqConst);
405 }
406
convertFunctionParameter(TIntermNode * parent,const TVariable * param)407 TVariable *convertFunctionParameter(TIntermNode *parent, const TVariable *param)
408 {
409 if (!param->getType().isAtomicCounter())
410 {
411 return nullptr;
412 }
413 if (mAtomicCounterType == nullptr)
414 {
415 declareAtomicCounterType();
416 }
417
418 const TType *paramType = ¶m->getType();
419 TType *newType =
420 paramType->getQualifier() == EvqConst ? mAtomicCounterTypeConst : mAtomicCounterType;
421
422 TVariable *replacementVar =
423 new TVariable(mSymbolTable, param->name(), newType, SymbolType::UserDefined);
424
425 return replacementVar;
426 }
427
convertFunctionArgumentHelper(const TVector<unsigned int> & runningArraySizeProducts,TIntermTyped * flattenedSubscript,uint32_t depth,uint32_t * subscriptCountOut)428 TIntermTyped *convertFunctionArgumentHelper(
429 const TVector<unsigned int> &runningArraySizeProducts,
430 TIntermTyped *flattenedSubscript,
431 uint32_t depth,
432 uint32_t *subscriptCountOut)
433 {
434 std::string prefix(depth, ' ');
435 TIntermNode *parent = getAncestorNode(depth);
436 ASSERT(parent);
437
438 TIntermBinary *arrayExpression = parent->getAsBinaryNode();
439 if (!arrayExpression)
440 {
441 // If the parent is not an array subscript operation, we have reached the end of the
442 // subscript chain. Note the depth that's traversed so the corresponding node can be
443 // taken as the function argument.
444 *subscriptCountOut = depth;
445 return flattenedSubscript;
446 }
447
448 ASSERT(arrayExpression->getOp() == EOpIndexDirect ||
449 arrayExpression->getOp() == EOpIndexIndirect);
450
451 // Assume i = n - depth. Get Pi. See comment in convertFunctionArgument.
452 ASSERT(depth < runningArraySizeProducts.size());
453 uint32_t thisDimensionSize =
454 runningArraySizeProducts[runningArraySizeProducts.size() - 1 - depth];
455
456 // Get Ii.
457 TIntermTyped *thisDimensionOffset = arrayExpression->getRight();
458
459 TIntermConstantUnion *subscriptAsConstant = thisDimensionOffset->getAsConstantUnion();
460 const bool subscriptIsZero = subscriptAsConstant && subscriptAsConstant->isZero(0);
461
462 // If Ii is zero, don't need to add Ii*Pi; that's zero.
463 if (!subscriptIsZero)
464 {
465 thisDimensionOffset = thisDimensionOffset->deepCopy();
466
467 // If Pi is 1, don't multiply. Just accumulate Ii.
468 if (thisDimensionSize != 1)
469 {
470 thisDimensionOffset = new TIntermBinary(EOpMul, thisDimensionOffset,
471 CreateUIntConstant(thisDimensionSize));
472 }
473
474 // Accumulate with the previous running offset, if any.
475 if (flattenedSubscript)
476 {
477 flattenedSubscript =
478 new TIntermBinary(EOpAdd, flattenedSubscript, thisDimensionOffset);
479 }
480 else
481 {
482 flattenedSubscript = thisDimensionOffset;
483 }
484 }
485
486 // Note: GLSL only allows 2 nested levels of arrays, so this recursion is bounded.
487 return convertFunctionArgumentHelper(runningArraySizeProducts, flattenedSubscript,
488 depth + 1, subscriptCountOut);
489 }
490
convertFunctionArgument(TIntermNode * symbol,TIntermTyped ** bindingOffset)491 TIntermNode *convertFunctionArgument(TIntermNode *symbol, TIntermTyped **bindingOffset)
492 {
493 // Assume a general case of array declaration with N dimensions:
494 //
495 // atomic_uint ac[Dn]..[D2][D1];
496 //
497 // Let's define
498 //
499 // Pn = D(n-1)*...*D2*D1
500 //
501 // In that case, we have:
502 //
503 // ac[In] = ac + In*Pn
504 // ac[In][I(n-1)] = ac + In*Pn + I(n-1)*P(n-1)
505 // ac[In]...[Ii] = ac + In*Pn + ... + Ii*Pi
506 //
507 // We have just visited a symbol; ac. Walking the parent chain, we will visit the
508 // expressions in the above order (ac, ac[In], ac[In][I(n-1)], ...). We therefore can
509 // simply walk the parent chain and accumulate Ii*Pi to obtain the offset from the base of
510 // ac.
511
512 TIntermSymbol *argumentAsSymbol = symbol->getAsSymbolNode();
513 ASSERT(argumentAsSymbol);
514
515 const TSpan<const unsigned int> &arraySizes = argumentAsSymbol->getType().getArraySizes();
516
517 // Calculate Pi
518 TVector<unsigned int> runningArraySizeProducts;
519 if (!arraySizes.empty())
520 {
521 runningArraySizeProducts.resize(arraySizes.size());
522 uint32_t runningProduct = 1;
523 for (size_t dimension = 0; dimension < arraySizes.size(); ++dimension)
524 {
525 runningArraySizeProducts[dimension] = runningProduct;
526 runningProduct *= arraySizes[dimension];
527 }
528 }
529
530 // Walk the parent chain and accumulate Ii*Pi
531 uint32_t subscriptCount = 0;
532 TIntermTyped *flattenedSubscript =
533 convertFunctionArgumentHelper(runningArraySizeProducts, nullptr, 0, &subscriptCount);
534
535 // Find the function argument, which is either in the form of ac (i.e. there are no
536 // subscripts, in which case that's the function argument), or ac[In]...[Ii] (in which case
537 // the function argument is the (n-i)th ancestor of ac.
538 //
539 // Note that this is the case because no other operation is allowed on ac other than
540 // subscript.
541 TIntermNode *argument = subscriptCount == 0 ? symbol : getAncestorNode(subscriptCount - 1);
542 ASSERT(argument != nullptr);
543
544 // If not subscripted, keep the argument as-is.
545 if (flattenedSubscript == nullptr)
546 {
547 return argument;
548 }
549
550 // Copy the atomic counter binding/offset constant and modify it by adding the array
551 // subscript to its offset field.
552 TVariable *modified = CreateTempVariable(mSymbolTable, mAtomicCounterType);
553 TIntermDeclaration *modifiedDecl = CreateTempInitDeclarationNode(modified, *bindingOffset);
554
555 TIntermSymbol *modifiedSymbol = new TIntermSymbol(modified);
556 TConstantUnion *offsetFieldIndex = new TConstantUnion;
557 offsetFieldIndex->setIConst(1);
558 TIntermConstantUnion *offsetFieldRef =
559 new TIntermConstantUnion(offsetFieldIndex, *StaticType::GetBasic<EbtUInt>());
560 TIntermBinary *offsetField =
561 new TIntermBinary(EOpIndexDirectStruct, modifiedSymbol, offsetFieldRef);
562
563 TIntermBinary *modifiedOffset =
564 new TIntermBinary(EOpAddAssign, offsetField, flattenedSubscript);
565
566 TIntermSequence *modifySequence = new TIntermSequence({modifiedDecl, modifiedOffset});
567 insertStatementsInParentBlock(*modifySequence);
568
569 *bindingOffset = modifiedSymbol->deepCopy();
570
571 return argument;
572 }
573
convertBuiltinFunction(TIntermAggregate * node)574 void convertBuiltinFunction(TIntermAggregate *node)
575 {
576 // If the function is |memoryBarrierAtomicCounter|, simply replace it with
577 // |memoryBarrierBuffer|.
578 if (node->getFunction()->name() == "memoryBarrierAtomicCounter")
579 {
580 TIntermTyped *substituteCall = CreateBuiltInFunctionCallNode(
581 "memoryBarrierBuffer", new TIntermSequence, *mSymbolTable, 310);
582 queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
583 return;
584 }
585
586 // If it's an |atomicCounter*| function, replace the function with an |atomic*| equivalent.
587 if (!node->getFunction()->isAtomicCounterFunction())
588 {
589 return;
590 }
591
592 const ImmutableString &functionName = node->getFunction()->name();
593 TIntermSequence *arguments = node->getSequence();
594
595 // Note: atomicAdd(0) is used for atomic reads.
596 uint32_t valueChange = 0;
597 constexpr char kAtomicAddFunction[] = "atomicAdd";
598 bool isDecrement = false;
599
600 if (functionName == "atomicCounterIncrement")
601 {
602 valueChange = 1;
603 }
604 else if (functionName == "atomicCounterDecrement")
605 {
606 // uint values are required to wrap around, so 0xFFFFFFFFu is used as -1.
607 valueChange = std::numeric_limits<uint32_t>::max();
608 static_assert(static_cast<uint32_t>(-1) == std::numeric_limits<uint32_t>::max(),
609 "uint32_t max is not -1");
610
611 isDecrement = true;
612 }
613 else
614 {
615 ASSERT(functionName == "atomicCounter");
616 }
617
618 const TIntermNode *param = (*arguments)[0];
619
620 TIntermTyped *bindingOffset = mRetyper.getFunctionCallArgReplacement(param);
621
622 TIntermSequence *substituteArguments = new TIntermSequence;
623 substituteArguments->push_back(
624 CreateAtomicCounterRef(mAtomicCounters, bindingOffset, mAcbBufferOffsets));
625 substituteArguments->push_back(CreateUIntConstant(valueChange));
626
627 TIntermTyped *substituteCall = CreateBuiltInFunctionCallNode(
628 kAtomicAddFunction, substituteArguments, *mSymbolTable, 310);
629
630 // Note that atomicCounterDecrement returns the *new* value instead of the prior value,
631 // unlike atomicAdd. So we need to do a -1 on the result as well.
632 if (isDecrement)
633 {
634 substituteCall = new TIntermBinary(EOpSub, substituteCall, CreateUIntConstant(1));
635 }
636
637 queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
638 }
639
640 const TVariable *mAtomicCounters;
641 const TIntermTyped *mAcbBufferOffsets;
642
643 RetypeOpaqueVariablesHelper mRetyper;
644
645 TType *mAtomicCounterType;
646 TType *mAtomicCounterTypeConst;
647
648 // Stored to be put at the top of the shader after the pass.
649 TIntermDeclaration *mAtomicCounterTypeDeclaration;
650 };
651
652 } // anonymous namespace
653
RewriteAtomicCounters(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const TIntermTyped * acbBufferOffsets)654 bool RewriteAtomicCounters(TCompiler *compiler,
655 TIntermBlock *root,
656 TSymbolTable *symbolTable,
657 const TIntermTyped *acbBufferOffsets)
658 {
659 const TVariable *atomicCounters = DeclareAtomicCountersBuffers(root, symbolTable);
660
661 RewriteAtomicCountersTraverser traverser(symbolTable, atomicCounters, acbBufferOffsets);
662 root->traverse(&traverser);
663 if (!traverser.updateTree(compiler, root))
664 {
665 return false;
666 }
667
668 TIntermDeclaration *atomicCounterTypeDeclaration = traverser.getAtomicCounterTypeDeclaration();
669 if (atomicCounterTypeDeclaration)
670 {
671 root->getSequence()->insert(root->getSequence()->begin(), atomicCounterTypeDeclaration);
672 }
673
674 return compiler->validateAST(root);
675 }
676 } // namespace sh
677