1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // MonomorphizeUnsupportedFunctions: Monomorphize functions that are called with
7 // parameters that are incompatible with both Vulkan GLSL and Metal.
8 //
9
10 #include "compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.h"
11
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/SymbolTable.h"
14 #include "compiler/translator/tree_util/IntermNode_util.h"
15 #include "compiler/translator/tree_util/IntermTraverse.h"
16 #include "compiler/translator/tree_util/ReplaceVariable.h"
17
18 namespace sh
19 {
20 namespace
21 {
22 struct Argument
23 {
24 size_t argumentIndex;
25 TIntermTyped *argument;
26 };
27
28 struct FunctionData
29 {
30 // Whether the original function is used. If this is false, the function can be removed because
31 // all callers have been modified.
32 bool isOriginalUsed;
33 // The original definition of the function, used to create the monomorphized version.
34 TIntermFunctionDefinition *originalDefinition;
35 // List of monomorphized versions of this function. They will be added next to the original
36 // version (or replace it).
37 TVector<TIntermFunctionDefinition *> monomorphizedDefinitions;
38 };
39
40 using FunctionMap = angle::HashMap<const TFunction *, FunctionData>;
41
42 // Traverse the function definitions and initialize the map. Allows visitAggregate to have access
43 // to TIntermFunctionDefinition even when the function is only forward declared at that point.
InitializeFunctionMap(TIntermBlock * root,FunctionMap * functionMapOut)44 void InitializeFunctionMap(TIntermBlock *root, FunctionMap *functionMapOut)
45 {
46 TIntermSequence &sequence = *root->getSequence();
47
48 for (TIntermNode *node : sequence)
49 {
50 TIntermFunctionDefinition *asFuncDef = node->getAsFunctionDefinition();
51 if (asFuncDef != nullptr)
52 {
53 const TFunction *function = asFuncDef->getFunction();
54 ASSERT(function && functionMapOut->find(function) == functionMapOut->end());
55 (*functionMapOut)[function] = FunctionData{false, asFuncDef, {}};
56 }
57 }
58 }
59
GetBaseUniform(TIntermTyped * node,bool * isSamplerInStructOut)60 const TVariable *GetBaseUniform(TIntermTyped *node, bool *isSamplerInStructOut)
61 {
62 *isSamplerInStructOut = false;
63
64 while (node->getAsBinaryNode())
65 {
66 TIntermBinary *asBinary = node->getAsBinaryNode();
67
68 TOperator op = asBinary->getOp();
69
70 // No opaque uniform can be inside an interface block.
71 if (op == EOpIndexDirectInterfaceBlock)
72 {
73 return nullptr;
74 }
75
76 if (op == EOpIndexDirectStruct)
77 {
78 *isSamplerInStructOut = true;
79 }
80
81 node = asBinary->getLeft();
82 }
83
84 // Only interested in uniform opaque types. If a function call within another function uses
85 // opaque uniforms in an unsupported way, it will be replaced in a follow up pass after the
86 // calling function is monomorphized.
87 if (node->getType().getQualifier() != EvqUniform)
88 {
89 return nullptr;
90 }
91
92 ASSERT(IsOpaqueType(node->getType().getBasicType()) ||
93 node->getType().isStructureContainingSamplers());
94
95 TIntermSymbol *asSymbol = node->getAsSymbolNode();
96 ASSERT(asSymbol);
97
98 return &asSymbol->variable();
99 }
100
ExtractSideEffects(TSymbolTable * symbolTable,TIntermTyped * node,TIntermSequence * replacementIndices)101 TIntermTyped *ExtractSideEffects(TSymbolTable *symbolTable,
102 TIntermTyped *node,
103 TIntermSequence *replacementIndices)
104 {
105 TIntermTyped *withoutSideEffects = node->deepCopy();
106
107 for (TIntermBinary *asBinary = withoutSideEffects->getAsBinaryNode(); asBinary;
108 asBinary = asBinary->getLeft()->getAsBinaryNode())
109 {
110 TOperator op = asBinary->getOp();
111 TIntermTyped *index = asBinary->getRight();
112
113 if (op == EOpIndexDirectStruct)
114 {
115 break;
116 }
117
118 // No side effects with constant expressions.
119 if (op == EOpIndexDirect)
120 {
121 ASSERT(index->getAsConstantUnion());
122 continue;
123 }
124
125 ASSERT(op == EOpIndexIndirect);
126
127 // If the index is a symbol, there's no side effect, so leave it as-is.
128 if (index->getAsSymbolNode())
129 {
130 continue;
131 }
132
133 // Otherwise create a temp variable initialized with the index and use that temp variable as
134 // the index.
135 TIntermDeclaration *tempDecl = nullptr;
136 TVariable *tempVar = DeclareTempVariable(symbolTable, index, EvqTemporary, &tempDecl);
137
138 replacementIndices->push_back(tempDecl);
139 asBinary->replaceChildNode(index, new TIntermSymbol(tempVar));
140 }
141
142 return withoutSideEffects;
143 }
144
CreateMonomorphizedFunctionCallArgs(const TIntermSequence & originalCallArguments,const TVector<Argument> & replacedArguments,TIntermSequence * substituteArgsOut)145 void CreateMonomorphizedFunctionCallArgs(const TIntermSequence &originalCallArguments,
146 const TVector<Argument> &replacedArguments,
147 TIntermSequence *substituteArgsOut)
148 {
149 size_t nextReplacedArg = 0;
150 for (size_t argIndex = 0; argIndex < originalCallArguments.size(); ++argIndex)
151 {
152 if (nextReplacedArg >= replacedArguments.size() ||
153 argIndex != replacedArguments[nextReplacedArg].argumentIndex)
154 {
155 // Not replaced, keep argument as is.
156 substituteArgsOut->push_back(originalCallArguments[argIndex]);
157 }
158 else
159 {
160 TIntermTyped *argument = replacedArguments[nextReplacedArg].argument;
161
162 // Iterate over indices of the argument and create a new arg for every non-const
163 // index. Note that the index itself may be an expression, and it may require further
164 // substitution in the next pass.
165 while (argument->getAsBinaryNode())
166 {
167 TIntermBinary *asBinary = argument->getAsBinaryNode();
168 if (asBinary->getOp() == EOpIndexIndirect)
169 {
170 TIntermTyped *index = asBinary->getRight();
171 substituteArgsOut->push_back(index->deepCopy());
172 }
173 argument = asBinary->getLeft();
174 }
175
176 ++nextReplacedArg;
177 }
178 }
179 }
180
MonomorphizeFunction(TSymbolTable * symbolTable,const TFunction * original,TVector<Argument> * replacedArguments,VariableReplacementMap * argumentMapOut)181 const TFunction *MonomorphizeFunction(TSymbolTable *symbolTable,
182 const TFunction *original,
183 TVector<Argument> *replacedArguments,
184 VariableReplacementMap *argumentMapOut)
185 {
186 TFunction *substituteFunction =
187 new TFunction(symbolTable, kEmptyImmutableString, SymbolType::AngleInternal,
188 &original->getReturnType(), original->isKnownToNotHaveSideEffects());
189
190 size_t nextReplacedArg = 0;
191 for (size_t paramIndex = 0; paramIndex < original->getParamCount(); ++paramIndex)
192 {
193 const TVariable *originalParam = original->getParam(paramIndex);
194
195 if (nextReplacedArg >= replacedArguments->size() ||
196 paramIndex != (*replacedArguments)[nextReplacedArg].argumentIndex)
197 {
198 TVariable *substituteArgument =
199 new TVariable(symbolTable, originalParam->name(), &originalParam->getType(),
200 originalParam->symbolType());
201 // Not replaced, add an identical parameter.
202 substituteFunction->addParameter(substituteArgument);
203 (*argumentMapOut)[originalParam] = new TIntermSymbol(substituteArgument);
204 }
205 else
206 {
207 TIntermTyped *substituteArgument = (*replacedArguments)[nextReplacedArg].argument;
208 (*argumentMapOut)[originalParam] = substituteArgument;
209
210 // Iterate over indices of the argument and create a new parameter for every non-const
211 // index (which may be an expression). Replace the symbol in the argument with a
212 // variable of the index type. This is later used to replace the parameter in the
213 // function body.
214 while (substituteArgument->getAsBinaryNode())
215 {
216 TIntermBinary *asBinary = substituteArgument->getAsBinaryNode();
217 if (asBinary->getOp() == EOpIndexIndirect)
218 {
219 TIntermTyped *index = asBinary->getRight();
220 TType *indexType = new TType(index->getType());
221 indexType->setQualifier(EvqParamIn);
222
223 TVariable *param = new TVariable(symbolTable, kEmptyImmutableString, indexType,
224 SymbolType::AngleInternal);
225 substituteFunction->addParameter(param);
226
227 // The argument now uses the function parameters as indices.
228 asBinary->replaceChildNode(asBinary->getRight(), new TIntermSymbol(param));
229 }
230 substituteArgument = asBinary->getLeft();
231 }
232
233 ++nextReplacedArg;
234 }
235 }
236
237 return substituteFunction;
238 }
239
240 class MonomorphizeTraverser final : public TIntermTraverser
241 {
242 public:
MonomorphizeTraverser(TCompiler * compiler,TSymbolTable * symbolTable,const ShCompileOptions & compileOptions,UnsupportedFunctionArgsBitSet unsupportedFunctionArgs,FunctionMap * functionMap)243 explicit MonomorphizeTraverser(TCompiler *compiler,
244 TSymbolTable *symbolTable,
245 const ShCompileOptions &compileOptions,
246 UnsupportedFunctionArgsBitSet unsupportedFunctionArgs,
247 FunctionMap *functionMap)
248 : TIntermTraverser(true, false, false, symbolTable),
249 mCompiler(compiler),
250 mCompileOptions(compileOptions),
251 mUnsupportedFunctionArgs(unsupportedFunctionArgs),
252 mFunctionMap(functionMap)
253 {}
254
visitAggregate(Visit visit,TIntermAggregate * node)255 bool visitAggregate(Visit visit, TIntermAggregate *node) override
256 {
257 if (node->getOp() != EOpCallFunctionInAST)
258 {
259 return true;
260 }
261
262 const TFunction *function = node->getFunction();
263 ASSERT(function && mFunctionMap->find(function) != mFunctionMap->end());
264
265 FunctionData &data = (*mFunctionMap)[function];
266
267 TIntermFunctionDefinition *monomorphized =
268 processFunctionCall(node, data.originalDefinition, &data.isOriginalUsed);
269 if (monomorphized)
270 {
271 data.monomorphizedDefinitions.push_back(monomorphized);
272 }
273
274 return true;
275 }
276
getAnyMonomorphized() const277 bool getAnyMonomorphized() const { return mAnyMonomorphized; }
278
279 private:
isUnsupportedArgument(TIntermTyped * callArgument,const TVariable * funcArgument) const280 bool isUnsupportedArgument(TIntermTyped *callArgument, const TVariable *funcArgument) const
281 {
282 // Only interested in opaque uniforms and structs that contain samplers.
283 const bool isOpaqueType = IsOpaqueType(funcArgument->getType().getBasicType());
284 const bool isStructContainingSamplers =
285 funcArgument->getType().isStructureContainingSamplers();
286 if (!isOpaqueType && !isStructContainingSamplers)
287 {
288 return false;
289 }
290
291 // If not uniform (the variable was itself a function parameter), don't process it in
292 // this pass, as we don't know which actual uniform it corresponds to.
293 bool isSamplerInStruct = false;
294 const TVariable *uniform = GetBaseUniform(callArgument, &isSamplerInStruct);
295 if (uniform == nullptr)
296 {
297 return false;
298 }
299
300 const TType &type = uniform->getType();
301
302 if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::StructContainingSamplers])
303 {
304 // Monomorphize if the parameter is a structure that contains samplers (so in
305 // RewriteStructSamplers we don't need to rewrite the functions to accept multiple
306 // parameters split from the struct).
307 if (isStructContainingSamplers)
308 {
309 return true;
310 }
311 }
312
313 if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::ArrayOfArrayOfSamplerOrImage])
314 {
315 // Monomorphize if:
316 //
317 // - The opaque uniform is a sampler in a struct (which can create an array-of-array
318 // situation), and the function expects an array of samplers, or
319 //
320 // - The opaque uniform is an array of array of sampler or image, and it's partially
321 // subscripted (i.e. the function itself expects an array)
322 //
323 const bool isParameterArrayOfOpaqueType = funcArgument->getType().isArray();
324 const bool isArrayOfArrayOfSamplerOrImage =
325 (type.isSampler() || type.isImage()) && type.isArrayOfArrays();
326 if (isSamplerInStruct && isParameterArrayOfOpaqueType)
327 {
328 return true;
329 }
330 if (isArrayOfArrayOfSamplerOrImage && isParameterArrayOfOpaqueType)
331 {
332 return true;
333 }
334 }
335
336 if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::AtomicCounter])
337 {
338 if (type.isAtomicCounter())
339 {
340 return true;
341 }
342 }
343
344 if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::SamplerCubeEmulation])
345 {
346 // Monomorphize if the opaque uniform is a samplerCube and ES2's cube sampling emulation
347 // is requested.
348 if (type.isSamplerCube() && mCompileOptions.emulateSeamfulCubeMapSampling)
349 {
350 return true;
351 }
352 }
353
354 if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::Image])
355 {
356 if (type.isImage())
357 {
358 return true;
359 }
360 }
361
362 if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::PixelLocalStorage])
363 {
364 if (type.isPixelLocal())
365 {
366 return true;
367 }
368 }
369
370 return false;
371 }
372
processFunctionCall(TIntermAggregate * functionCall,TIntermFunctionDefinition * originalDefinition,bool * isOriginalUsedOut)373 TIntermFunctionDefinition *processFunctionCall(TIntermAggregate *functionCall,
374 TIntermFunctionDefinition *originalDefinition,
375 bool *isOriginalUsedOut)
376 {
377 const TFunction *function = functionCall->getFunction();
378 const TIntermSequence &callArguments = *functionCall->getSequence();
379
380 TVector<Argument> replacedArguments;
381 TIntermSequence replacementIndices;
382
383 // Go through function call arguments, and see if any is used in an unsupported way.
384 for (size_t argIndex = 0; argIndex < callArguments.size(); ++argIndex)
385 {
386 TIntermTyped *callArgument = callArguments[argIndex]->getAsTyped();
387 const TVariable *funcArgument = function->getParam(argIndex);
388 if (isUnsupportedArgument(callArgument, funcArgument))
389 {
390 // Copy the argument and extract the side effects.
391 TIntermTyped *argument =
392 ExtractSideEffects(mSymbolTable, callArgument, &replacementIndices);
393
394 replacedArguments.push_back({argIndex, argument});
395 }
396 }
397
398 if (replacedArguments.empty())
399 {
400 *isOriginalUsedOut = true;
401 return nullptr;
402 }
403
404 mAnyMonomorphized = true;
405
406 insertStatementsInParentBlock(replacementIndices);
407
408 // Create the arguments for the substitute function call. Done before monomorphizing the
409 // function, which transforms the arguments to what needs to be replaced in the function
410 // body.
411 TIntermSequence newCallArgs;
412 CreateMonomorphizedFunctionCallArgs(callArguments, replacedArguments, &newCallArgs);
413
414 // Duplicate the function and substitute the replaced arguments with only the non-const
415 // indices. Additionally, substitute the non-const indices of arguments with the new
416 // function parameters.
417 VariableReplacementMap argumentMap;
418 const TFunction *monomorphized =
419 MonomorphizeFunction(mSymbolTable, function, &replacedArguments, &argumentMap);
420
421 // Replace this function call with a call to the new one.
422 queueReplacement(TIntermAggregate::CreateFunctionCall(*monomorphized, &newCallArgs),
423 OriginalNode::IS_DROPPED);
424
425 // Create a new function definition, with the body of the old function but with the replaced
426 // parameters substituted with the calling expressions.
427 TIntermFunctionPrototype *substitutePrototype = new TIntermFunctionPrototype(monomorphized);
428 TIntermBlock *substituteBlock = originalDefinition->getBody()->deepCopy();
429 GetDeclaratorReplacements(mSymbolTable, substituteBlock, &argumentMap);
430 bool valid = ReplaceVariables(mCompiler, substituteBlock, argumentMap);
431 ASSERT(valid);
432
433 return new TIntermFunctionDefinition(substitutePrototype, substituteBlock);
434 }
435
436 TCompiler *mCompiler;
437 const ShCompileOptions &mCompileOptions;
438 UnsupportedFunctionArgsBitSet mUnsupportedFunctionArgs;
439 bool mAnyMonomorphized = false;
440
441 // Map of original to monomorphized functions.
442 FunctionMap *mFunctionMap;
443 };
444
445 class UpdateFunctionsDefinitionsTraverser final : public TIntermTraverser
446 {
447 public:
UpdateFunctionsDefinitionsTraverser(TSymbolTable * symbolTable,const FunctionMap & functionMap)448 explicit UpdateFunctionsDefinitionsTraverser(TSymbolTable *symbolTable,
449 const FunctionMap &functionMap)
450 : TIntermTraverser(true, false, false, symbolTable), mFunctionMap(functionMap)
451 {}
452
visitFunctionPrototype(TIntermFunctionPrototype * node)453 void visitFunctionPrototype(TIntermFunctionPrototype *node) override
454 {
455 const bool isInFunctionDefinition = getParentNode()->getAsFunctionDefinition() != nullptr;
456 if (isInFunctionDefinition)
457 {
458 return;
459 }
460
461 // Add to and possibly replace the function prototype with replacement prototypes.
462 const TFunction *function = node->getFunction();
463 ASSERT(function && mFunctionMap.find(function) != mFunctionMap.end());
464
465 const FunctionData &data = mFunctionMap.at(function);
466
467 // If nothing to do, leave it be.
468 if (data.monomorphizedDefinitions.empty())
469 {
470 ASSERT(data.isOriginalUsed);
471 return;
472 }
473
474 // Replace the prototype with itself (if function is still used) as well as any
475 // monomorphized versions.
476 TIntermSequence replacement;
477 if (data.isOriginalUsed)
478 {
479 replacement.push_back(node);
480 }
481 for (TIntermFunctionDefinition *monomorphizedDefinition : data.monomorphizedDefinitions)
482 {
483 replacement.push_back(new TIntermFunctionPrototype(
484 monomorphizedDefinition->getFunctionPrototype()->getFunction()));
485 }
486 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
487 std::move(replacement));
488 }
489
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)490 bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
491 {
492 // Add to and possibly replace the function definition with replacement definitions.
493 const TFunction *function = node->getFunction();
494 ASSERT(function && mFunctionMap.find(function) != mFunctionMap.end());
495
496 const FunctionData &data = mFunctionMap.at(function);
497
498 // If nothing to do, leave it be.
499 if (data.monomorphizedDefinitions.empty())
500 {
501 ASSERT(data.isOriginalUsed || function->name() == "main");
502 return false;
503 }
504
505 // Replace the definition with itself (if function is still used) as well as any
506 // monomorphized versions.
507 TIntermSequence replacement;
508 if (data.isOriginalUsed)
509 {
510 replacement.push_back(node);
511 }
512 for (TIntermFunctionDefinition *monomorphizedDefinition : data.monomorphizedDefinitions)
513 {
514 replacement.push_back(monomorphizedDefinition);
515 }
516 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
517 std::move(replacement));
518
519 return false;
520 }
521
522 private:
523 const FunctionMap &mFunctionMap;
524 };
525
SortDeclarations(TIntermBlock * root)526 void SortDeclarations(TIntermBlock *root)
527 {
528 TIntermSequence *original = root->getSequence();
529
530 TIntermSequence replacement;
531 TIntermSequence functionDefs;
532
533 // Accumulate non-function-definition declarations in |replacement| and function definitions in
534 // |functionDefs|.
535 for (TIntermNode *node : *original)
536 {
537 if (node->getAsFunctionDefinition() || node->getAsFunctionPrototypeNode())
538 {
539 functionDefs.push_back(node);
540 }
541 else
542 {
543 replacement.push_back(node);
544 }
545 }
546
547 // Append function definitions to |replacement|.
548 replacement.insert(replacement.end(), functionDefs.begin(), functionDefs.end());
549
550 // Replace root's sequence with |replacement|.
551 root->replaceAllChildren(replacement);
552 }
553
MonomorphizeUnsupportedFunctionsImpl(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const ShCompileOptions & compileOptions,UnsupportedFunctionArgsBitSet unsupportedFunctionArgs)554 bool MonomorphizeUnsupportedFunctionsImpl(TCompiler *compiler,
555 TIntermBlock *root,
556 TSymbolTable *symbolTable,
557 const ShCompileOptions &compileOptions,
558 UnsupportedFunctionArgsBitSet unsupportedFunctionArgs)
559 {
560 // First, sort out the declarations such that all non-function declarations are placed before
561 // function definitions. This way when the function is replaced with one that references said
562 // declarations (i.e. uniforms), the uniform declaration is already present above it.
563 SortDeclarations(root);
564
565 while (true)
566 {
567 FunctionMap functionMap;
568 InitializeFunctionMap(root, &functionMap);
569
570 MonomorphizeTraverser monomorphizer(compiler, symbolTable, compileOptions,
571 unsupportedFunctionArgs, &functionMap);
572 root->traverse(&monomorphizer);
573
574 if (!monomorphizer.getAnyMonomorphized())
575 {
576 break;
577 }
578
579 if (!monomorphizer.updateTree(compiler, root))
580 {
581 return false;
582 }
583
584 UpdateFunctionsDefinitionsTraverser functionUpdater(symbolTable, functionMap);
585 root->traverse(&functionUpdater);
586
587 if (!functionUpdater.updateTree(compiler, root))
588 {
589 return false;
590 }
591 }
592
593 return true;
594 }
595 } // anonymous namespace
596
MonomorphizeUnsupportedFunctions(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const ShCompileOptions & compileOptions,UnsupportedFunctionArgsBitSet unsupportedFunctionArgs)597 bool MonomorphizeUnsupportedFunctions(TCompiler *compiler,
598 TIntermBlock *root,
599 TSymbolTable *symbolTable,
600 const ShCompileOptions &compileOptions,
601 UnsupportedFunctionArgsBitSet unsupportedFunctionArgs)
602 {
603 // This function actually applies multiple transformation, and the AST may not be valid until
604 // the transformations are entirely done. Some validation is momentarily disabled.
605 bool enableValidateFunctionCall = compiler->disableValidateFunctionCall();
606
607 bool result = MonomorphizeUnsupportedFunctionsImpl(compiler, root, symbolTable, compileOptions,
608 unsupportedFunctionArgs);
609
610 compiler->restoreValidateFunctionCall(enableValidateFunctionCall);
611 return result && compiler->validateAST(root);
612 }
613 } // namespace sh
614