1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RewriteR32fImages: Change images qualified with r32f to use r32ui instead.
7 //
8
9 #include "compiler/translator/tree_ops/vulkan/RewriteR32fImages.h"
10
11 #include "compiler/translator/Compiler.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 #include "compiler/translator/tree_util/ReplaceVariable.h"
18
19 namespace sh
20 {
21 namespace
22 {
IsR32fImage(const TType & type)23 bool IsR32fImage(const TType &type)
24 {
25 return type.getQualifier() == EvqUniform && type.isImage() &&
26 type.getLayoutQualifier().imageInternalFormat == EiifR32F;
27 }
28
// Map from an original r32f image variable to the r32ui image variable that replaces it.
using ImageMap = angle::HashMap<const TVariable *, const TVariable *>;

// Forward declaration: RewriteExpressionTraverser (below) calls this function, which is defined
// after it.  Returns the rewritten call, or nullptr if |node| does not need rewriting.
TIntermTyped *RewriteBuiltinFunctionCall(TCompiler *compiler,
                                         TSymbolTable *symbolTable,
                                         TIntermAggregate *node,
                                         const ImageMap &imageMap);
35
36 // Given an expression, this traverser calculates a new expression where builtin function calls to
37 // r32f images are replaced with ones to the mapped r32ui image. In particular, this is run on the
38 // right node of EOpIndexIndirect binary nodes, so that the expression in the index gets a chance to
39 // go through this transformation.
40 class RewriteExpressionTraverser final : public TIntermTraverser
41 {
42 public:
RewriteExpressionTraverser(TCompiler * compiler,TSymbolTable * symbolTable,const ImageMap & imageMap)43 explicit RewriteExpressionTraverser(TCompiler *compiler,
44 TSymbolTable *symbolTable,
45 const ImageMap &imageMap)
46 : TIntermTraverser(true, false, false, symbolTable),
47 mCompiler(compiler),
48 mImageMap(imageMap)
49 {}
50
visitAggregate(Visit visit,TIntermAggregate * node)51 bool visitAggregate(Visit visit, TIntermAggregate *node) override
52 {
53 TIntermTyped *rewritten =
54 RewriteBuiltinFunctionCall(mCompiler, mSymbolTable, node, mImageMap);
55 if (rewritten == nullptr)
56 {
57 return true;
58 }
59
60 queueReplacement(rewritten, OriginalNode::IS_DROPPED);
61
62 // Don't iterate as the expression is rewritten.
63 return false;
64 }
65
66 private:
67 TCompiler *mCompiler;
68
69 const ImageMap &mImageMap;
70 };
71
72 // Rewrite the index of an EOpIndexIndirect expression as well as any arguments to the builtin
73 // function call.
RewriteExpression(TCompiler * compiler,TSymbolTable * symbolTable,TIntermTyped * expression,const ImageMap & imageMap)74 TIntermTyped *RewriteExpression(TCompiler *compiler,
75 TSymbolTable *symbolTable,
76 TIntermTyped *expression,
77 const ImageMap &imageMap)
78 {
79 // Create a fake block to insert the node in. The root itself may need changing.
80 TIntermBlock block;
81 block.appendStatement(expression);
82
83 RewriteExpressionTraverser traverser(compiler, symbolTable, imageMap);
84 block.traverse(&traverser);
85
86 bool valid = traverser.updateTree(compiler, &block);
87 ASSERT(valid);
88
89 TIntermTyped *rewritten = block.getChildNode(0)->getAsTyped();
90
91 return rewritten;
92 }
93
94 // Given a builtin function call such as the following:
95 //
96 // imageLoad(expression, ...);
97 //
98 // expression is in the form of:
99 //
100 // - image uniform
101 // - image uniform array indexed with EOpIndexDirect or EOpIndexIndirect. Note that
102 // RewriteArrayOfArrayOfOpaqueUniforms has already ensured that the image array is
103 // single-dimension.
104 //
105 // The latter case (with EOpIndexIndirect) is not valid GLSL (up to GL_EXT_gpu_shader5), but if it
106 // were, the index itself could have contained an image builtin function call, so is recursively
107 // processed (in case supported in future). Additionally, the other builtin function arguments may
108 // need processing too.
109 //
110 // This function creates a similar expression where the image uniforms (of type r32f) are replaced
111 // with those of r32ui type.
112 //
RewriteBuiltinFunctionCall(TCompiler * compiler,TSymbolTable * symbolTable,TIntermAggregate * node,const ImageMap & imageMap)113 TIntermTyped *RewriteBuiltinFunctionCall(TCompiler *compiler,
114 TSymbolTable *symbolTable,
115 TIntermAggregate *node,
116 const ImageMap &imageMap)
117 {
118 if (!BuiltInGroup::IsBuiltIn(node->getOp()))
119 {
120 // AST functions don't require modification as r32f image function parameters are removed by
121 // MonomorphizeUnsupportedFunctionsInVulkanGLSL.
122 return nullptr;
123 }
124
125 // If it's an |image*| function, replace the function with an equivalent that uses an r32ui
126 // image.
127 if (!node->getFunction()->isImageFunction())
128 {
129 return nullptr;
130 }
131
132 TIntermSequence *arguments = node->getSequence();
133
134 TIntermTyped *imageExpression = (*arguments)[0]->getAsTyped();
135 ASSERT(imageExpression);
136
137 // Find the image uniform that's being indexed, if indexed.
138 TIntermBinary *asBinary = imageExpression->getAsBinaryNode();
139 TIntermSymbol *imageUniform = imageExpression->getAsSymbolNode();
140
141 if (asBinary)
142 {
143 ASSERT(asBinary->getOp() == EOpIndexDirect || asBinary->getOp() == EOpIndexIndirect);
144 imageUniform = asBinary->getLeft()->getAsSymbolNode();
145 }
146
147 ASSERT(imageUniform);
148 if (!IsR32fImage(imageUniform->getType()))
149 {
150 return nullptr;
151 }
152
153 ASSERT(imageMap.find(&imageUniform->variable()) != imageMap.end());
154 const TVariable *replacementImage = imageMap.at(&imageUniform->variable());
155
156 // Build the expression again, with the image uniform replaced. If index is dynamic,
157 // recursively process it.
158 TIntermTyped *replacementExpression = new TIntermSymbol(replacementImage);
159
160 // Index it, if indexed.
161 if (asBinary != nullptr)
162 {
163 TIntermTyped *index = asBinary->getRight();
164
165 switch (asBinary->getOp())
166 {
167 case EOpIndexDirect:
168 break;
169 case EOpIndexIndirect:
170 {
171 // Run RewriteExpressionTraverser on the index node. This case is currently
172 // impossible with known extensions.
173 UNREACHABLE();
174 index = RewriteExpression(compiler, symbolTable, index, imageMap);
175 break;
176 }
177 default:
178 UNREACHABLE();
179 break;
180 }
181
182 replacementExpression = new TIntermBinary(asBinary->getOp(), replacementExpression, index);
183 }
184
185 TIntermSequence substituteArguments;
186 substituteArguments.push_back(replacementExpression);
187
188 for (size_t argIndex = 1; argIndex < arguments->size(); ++argIndex)
189 {
190 TIntermTyped *arg = (*arguments)[argIndex]->getAsTyped();
191
192 // Run RewriteExpressionTraverser on the argument. It may itself be an expression with an
193 // r32f image that needs to be rewritten.
194 arg = RewriteExpression(compiler, symbolTable, arg, imageMap);
195 substituteArguments.push_back(arg);
196 }
197
198 const ImmutableString &functionName = node->getFunction()->name();
199 bool isImageAtomicExchange = functionName == "imageAtomicExchange";
200 bool isImageLoad = false;
201
202 if (functionName == "imageStore" || isImageAtomicExchange)
203 {
204 // The last parameter is float data, which should be changed to floatBitsToUint(data).
205 TIntermTyped *data = substituteArguments.back()->getAsTyped();
206 substituteArguments.back() =
207 CreateBuiltInUnaryFunctionCallNode("floatBitsToUint", data, *symbolTable, 300);
208 }
209 else if (functionName == "imageLoad")
210 {
211 isImageLoad = true;
212 }
213 else
214 {
215 // imageSize does not have any other arguments.
216 ASSERT(functionName == "imageSize");
217 ASSERT(arguments->size() == 1);
218 }
219
220 TIntermTyped *replacementCall =
221 CreateBuiltInFunctionCallNode(functionName.data(), &substituteArguments, *symbolTable, 310);
222
223 // If imageLoad or imageAtomicExchange, the result is now uint, which should be converted with
224 // uintBitsToFloat. With imageLoad, the alpha channel should always read 1.0 regardless.
225 if (isImageLoad || isImageAtomicExchange)
226 {
227 if (isImageLoad)
228 {
229 // imageLoad().rgb
230 replacementCall = new TIntermSwizzle(replacementCall, {0, 1, 2});
231 }
232
233 // uintBitsToFloat(imageLoad().rgb), or uintBitsToFloat(imageAtomicExchange())
234 replacementCall = CreateBuiltInUnaryFunctionCallNode("uintBitsToFloat", replacementCall,
235 *symbolTable, 300);
236
237 if (isImageLoad)
238 {
239 // vec4(uintBitsToFloat(imageLoad().rgb), 1.0)
240 const TType &vec4Type = *StaticType::GetBasic<EbtFloat, 4>();
241 TIntermSequence constructorArgs = {replacementCall, CreateFloatNode(1.0f)};
242 replacementCall = TIntermAggregate::CreateConstructor(vec4Type, &constructorArgs);
243 }
244 }
245
246 return replacementCall;
247 }
248
249 // Traverser that:
250 //
251 // 1. Converts the layout(r32f, ...) ... image* name; declarations to use the r32ui format
252 // 2. Converts |imageLoad| and |imageStore| functions to use |uintBitsToFloat| and |floatBitsToUint|
253 // respectively.
254 // 3. Converts |imageAtomicExchange| to use |floatBitsToUint| and |uintBitsToFloat|.
255 class RewriteR32fImagesTraverser : public TIntermTraverser
256 {
257 public:
RewriteR32fImagesTraverser(TCompiler * compiler,TSymbolTable * symbolTable)258 RewriteR32fImagesTraverser(TCompiler *compiler, TSymbolTable *symbolTable)
259 : TIntermTraverser(true, false, false, symbolTable), mCompiler(compiler)
260 {}
261
visitDeclaration(Visit visit,TIntermDeclaration * node)262 bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
263 {
264 if (visit != PreVisit)
265 {
266 return true;
267 }
268
269 const TIntermSequence &sequence = *(node->getSequence());
270
271 TIntermTyped *declVariable = sequence.front()->getAsTyped();
272 const TType &type = declVariable->getType();
273
274 if (!IsR32fImage(type))
275 {
276 return true;
277 }
278
279 TIntermSymbol *oldSymbol = declVariable->getAsSymbolNode();
280 ASSERT(oldSymbol != nullptr);
281
282 const TVariable &oldVariable = oldSymbol->variable();
283
284 TType *newType = new TType(type);
285 TLayoutQualifier layoutQualifier = type.getLayoutQualifier();
286 layoutQualifier.imageInternalFormat = EiifR32UI;
287 newType->setLayoutQualifier(layoutQualifier);
288
289 switch (type.getBasicType())
290 {
291 case EbtImage2D:
292 newType->setBasicType(EbtUImage2D);
293 break;
294 case EbtImage3D:
295 newType->setBasicType(EbtUImage3D);
296 break;
297 case EbtImage2DArray:
298 newType->setBasicType(EbtUImage2DArray);
299 break;
300 case EbtImageCube:
301 newType->setBasicType(EbtUImageCube);
302 break;
303 case EbtImage1D:
304 newType->setBasicType(EbtUImage1D);
305 break;
306 case EbtImage1DArray:
307 newType->setBasicType(EbtUImage1DArray);
308 break;
309 case EbtImage2DMS:
310 newType->setBasicType(EbtUImage2DMS);
311 break;
312 case EbtImage2DMSArray:
313 newType->setBasicType(EbtUImage2DMSArray);
314 break;
315 case EbtImageCubeArray:
316 newType->setBasicType(EbtUImageCubeArray);
317 break;
318 case EbtImageRect:
319 newType->setBasicType(EbtUImageRect);
320 break;
321 case EbtImageBuffer:
322 newType->setBasicType(EbtUImageBuffer);
323 break;
324 default:
325 UNREACHABLE();
326 }
327
328 TVariable *newVariable =
329 new TVariable(oldVariable.uniqueId(), oldVariable.name(), oldVariable.symbolType(),
330 oldVariable.extensions(), newType);
331
332 mImageMap[&oldVariable] = newVariable;
333
334 TIntermDeclaration *newDecl = new TIntermDeclaration();
335 newDecl->appendDeclarator(new TIntermSymbol(newVariable));
336
337 queueReplacement(newDecl, OriginalNode::IS_DROPPED);
338
339 return false;
340 }
341
342 // Same implementation as in RewriteExpressionTraverser. That traverser cannot replace root.
visitAggregate(Visit visit,TIntermAggregate * node)343 bool visitAggregate(Visit visit, TIntermAggregate *node) override
344 {
345 TIntermTyped *rewritten =
346 RewriteBuiltinFunctionCall(mCompiler, mSymbolTable, node, mImageMap);
347 if (rewritten == nullptr)
348 {
349 return true;
350 }
351
352 queueReplacement(rewritten, OriginalNode::IS_DROPPED);
353
354 return false;
355 }
356
visitSymbol(TIntermSymbol * symbol)357 void visitSymbol(TIntermSymbol *symbol) override
358 {
359 // Cannot encounter the image symbol directly. It can only be used with built-in functions,
360 // and therefore it's handled by visitAggregate.
361 ASSERT(!IsR32fImage(symbol->getType()));
362 }
363
364 private:
365 TCompiler *mCompiler;
366
367 // Map from r32f image to r32ui image
368 ImageMap mImageMap;
369 };
370
371 } // anonymous namespace
372
RewriteR32fImages(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)373 bool RewriteR32fImages(TCompiler *compiler, TIntermBlock *root, TSymbolTable *symbolTable)
374 {
375 RewriteR32fImagesTraverser traverser(compiler, symbolTable);
376 root->traverse(&traverser);
377 return traverser.updateTree(compiler, root);
378 }
379 } // namespace sh
380