• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // MonomorphizeUnsupportedFunctions: Monomorphize functions that are called with
7 // parameters that are incompatible with both Vulkan GLSL and Metal.
8 //
9 
10 #include "compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.h"
11 
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/SymbolTable.h"
14 #include "compiler/translator/tree_util/IntermNode_util.h"
15 #include "compiler/translator/tree_util/IntermTraverse.h"
16 #include "compiler/translator/tree_util/ReplaceVariable.h"
17 
18 namespace sh
19 {
20 namespace
21 {
22 struct Argument
23 {
24     size_t argumentIndex;
25     TIntermTyped *argument;
26 };
27 
28 struct FunctionData
29 {
30     // Whether the original function is used.  If this is false, the function can be removed because
31     // all callers have been modified.
32     bool isOriginalUsed;
33     // The original definition of the function, used to create the monomorphized version.
34     TIntermFunctionDefinition *originalDefinition;
35     // List of monomorphized versions of this function.  They will be added next to the original
36     // version (or replace it).
37     TVector<TIntermFunctionDefinition *> monomorphizedDefinitions;
38 };
39 
40 using FunctionMap = angle::HashMap<const TFunction *, FunctionData>;
41 
42 // Traverse the function definitions and initialize the map.  Allows visitAggregate to have access
43 // to TIntermFunctionDefinition even when the function is only forward declared at that point.
InitializeFunctionMap(TIntermBlock * root,FunctionMap * functionMapOut)44 void InitializeFunctionMap(TIntermBlock *root, FunctionMap *functionMapOut)
45 {
46     TIntermSequence &sequence = *root->getSequence();
47 
48     for (TIntermNode *node : sequence)
49     {
50         TIntermFunctionDefinition *asFuncDef = node->getAsFunctionDefinition();
51         if (asFuncDef != nullptr)
52         {
53             const TFunction *function = asFuncDef->getFunction();
54             ASSERT(function && functionMapOut->find(function) == functionMapOut->end());
55             (*functionMapOut)[function] = FunctionData{false, asFuncDef, {}};
56         }
57     }
58 }
59 
GetBaseUniform(TIntermTyped * node,bool * isSamplerInStructOut)60 const TVariable *GetBaseUniform(TIntermTyped *node, bool *isSamplerInStructOut)
61 {
62     *isSamplerInStructOut = false;
63 
64     while (node->getAsBinaryNode())
65     {
66         TIntermBinary *asBinary = node->getAsBinaryNode();
67 
68         TOperator op = asBinary->getOp();
69 
70         // No opaque uniform can be inside an interface block.
71         if (op == EOpIndexDirectInterfaceBlock)
72         {
73             return nullptr;
74         }
75 
76         if (op == EOpIndexDirectStruct)
77         {
78             *isSamplerInStructOut = true;
79         }
80 
81         node = asBinary->getLeft();
82     }
83 
84     // Only interested in uniform opaque types.  If a function call within another function uses
85     // opaque uniforms in an unsupported way, it will be replaced in a follow up pass after the
86     // calling function is monomorphized.
87     if (node->getType().getQualifier() != EvqUniform)
88     {
89         return nullptr;
90     }
91 
92     ASSERT(IsOpaqueType(node->getType().getBasicType()) ||
93            node->getType().isStructureContainingSamplers());
94 
95     TIntermSymbol *asSymbol = node->getAsSymbolNode();
96     ASSERT(asSymbol);
97 
98     return &asSymbol->variable();
99 }
100 
ExtractSideEffects(TSymbolTable * symbolTable,TIntermTyped * node,TIntermSequence * replacementIndices)101 TIntermTyped *ExtractSideEffects(TSymbolTable *symbolTable,
102                                  TIntermTyped *node,
103                                  TIntermSequence *replacementIndices)
104 {
105     TIntermTyped *withoutSideEffects = node->deepCopy();
106 
107     for (TIntermBinary *asBinary = withoutSideEffects->getAsBinaryNode(); asBinary;
108          asBinary                = asBinary->getLeft()->getAsBinaryNode())
109     {
110         TOperator op        = asBinary->getOp();
111         TIntermTyped *index = asBinary->getRight();
112 
113         if (op == EOpIndexDirectStruct)
114         {
115             break;
116         }
117 
118         // No side effects with constant expressions.
119         if (op == EOpIndexDirect)
120         {
121             ASSERT(index->getAsConstantUnion());
122             continue;
123         }
124 
125         ASSERT(op == EOpIndexIndirect);
126 
127         // If the index is a symbol, there's no side effect, so leave it as-is.
128         if (index->getAsSymbolNode())
129         {
130             continue;
131         }
132 
133         // Otherwise create a temp variable initialized with the index and use that temp variable as
134         // the index.
135         TIntermDeclaration *tempDecl = nullptr;
136         TVariable *tempVar = DeclareTempVariable(symbolTable, index, EvqTemporary, &tempDecl);
137 
138         replacementIndices->push_back(tempDecl);
139         asBinary->replaceChildNode(index, new TIntermSymbol(tempVar));
140     }
141 
142     return withoutSideEffects;
143 }
144 
CreateMonomorphizedFunctionCallArgs(const TIntermSequence & originalCallArguments,const TVector<Argument> & replacedArguments,TIntermSequence * substituteArgsOut)145 void CreateMonomorphizedFunctionCallArgs(const TIntermSequence &originalCallArguments,
146                                          const TVector<Argument> &replacedArguments,
147                                          TIntermSequence *substituteArgsOut)
148 {
149     size_t nextReplacedArg = 0;
150     for (size_t argIndex = 0; argIndex < originalCallArguments.size(); ++argIndex)
151     {
152         if (nextReplacedArg >= replacedArguments.size() ||
153             argIndex != replacedArguments[nextReplacedArg].argumentIndex)
154         {
155             // Not replaced, keep argument as is.
156             substituteArgsOut->push_back(originalCallArguments[argIndex]);
157         }
158         else
159         {
160             TIntermTyped *argument = replacedArguments[nextReplacedArg].argument;
161 
162             // Iterate over indices of the argument and create a new arg for every non-const
163             // index.  Note that the index itself may be an expression, and it may require further
164             // substitution in the next pass.
165             while (argument->getAsBinaryNode())
166             {
167                 TIntermBinary *asBinary = argument->getAsBinaryNode();
168                 if (asBinary->getOp() == EOpIndexIndirect)
169                 {
170                     TIntermTyped *index = asBinary->getRight();
171                     substituteArgsOut->push_back(index->deepCopy());
172                 }
173                 argument = asBinary->getLeft();
174             }
175 
176             ++nextReplacedArg;
177         }
178     }
179 }
180 
MonomorphizeFunction(TSymbolTable * symbolTable,const TFunction * original,TVector<Argument> * replacedArguments,VariableReplacementMap * argumentMapOut)181 const TFunction *MonomorphizeFunction(TSymbolTable *symbolTable,
182                                       const TFunction *original,
183                                       TVector<Argument> *replacedArguments,
184                                       VariableReplacementMap *argumentMapOut)
185 {
186     TFunction *substituteFunction =
187         new TFunction(symbolTable, kEmptyImmutableString, SymbolType::AngleInternal,
188                       &original->getReturnType(), original->isKnownToNotHaveSideEffects());
189 
190     size_t nextReplacedArg = 0;
191     for (size_t paramIndex = 0; paramIndex < original->getParamCount(); ++paramIndex)
192     {
193         const TVariable *originalParam = original->getParam(paramIndex);
194 
195         if (nextReplacedArg >= replacedArguments->size() ||
196             paramIndex != (*replacedArguments)[nextReplacedArg].argumentIndex)
197         {
198             TVariable *substituteArgument =
199                 new TVariable(symbolTable, originalParam->name(), &originalParam->getType(),
200                               originalParam->symbolType());
201             // Not replaced, add an identical parameter.
202             substituteFunction->addParameter(substituteArgument);
203             (*argumentMapOut)[originalParam] = new TIntermSymbol(substituteArgument);
204         }
205         else
206         {
207             TIntermTyped *substituteArgument = (*replacedArguments)[nextReplacedArg].argument;
208             (*argumentMapOut)[originalParam] = substituteArgument;
209 
210             // Iterate over indices of the argument and create a new parameter for every non-const
211             // index (which may be an expression).  Replace the symbol in the argument with a
212             // variable of the index type.  This is later used to replace the parameter in the
213             // function body.
214             while (substituteArgument->getAsBinaryNode())
215             {
216                 TIntermBinary *asBinary = substituteArgument->getAsBinaryNode();
217                 if (asBinary->getOp() == EOpIndexIndirect)
218                 {
219                     TIntermTyped *index = asBinary->getRight();
220                     TType *indexType    = new TType(index->getType());
221                     indexType->setQualifier(EvqParamIn);
222 
223                     TVariable *param = new TVariable(symbolTable, kEmptyImmutableString, indexType,
224                                                      SymbolType::AngleInternal);
225                     substituteFunction->addParameter(param);
226 
227                     // The argument now uses the function parameters as indices.
228                     asBinary->replaceChildNode(asBinary->getRight(), new TIntermSymbol(param));
229                 }
230                 substituteArgument = asBinary->getLeft();
231             }
232 
233             ++nextReplacedArg;
234         }
235     }
236 
237     return substituteFunction;
238 }
239 
240 class MonomorphizeTraverser final : public TIntermTraverser
241 {
242   public:
MonomorphizeTraverser(TCompiler * compiler,TSymbolTable * symbolTable,const ShCompileOptions & compileOptions,UnsupportedFunctionArgsBitSet unsupportedFunctionArgs,FunctionMap * functionMap)243     explicit MonomorphizeTraverser(TCompiler *compiler,
244                                    TSymbolTable *symbolTable,
245                                    const ShCompileOptions &compileOptions,
246                                    UnsupportedFunctionArgsBitSet unsupportedFunctionArgs,
247                                    FunctionMap *functionMap)
248         : TIntermTraverser(true, false, false, symbolTable),
249           mCompiler(compiler),
250           mCompileOptions(compileOptions),
251           mUnsupportedFunctionArgs(unsupportedFunctionArgs),
252           mFunctionMap(functionMap)
253     {}
254 
visitAggregate(Visit visit,TIntermAggregate * node)255     bool visitAggregate(Visit visit, TIntermAggregate *node) override
256     {
257         if (node->getOp() != EOpCallFunctionInAST)
258         {
259             return true;
260         }
261 
262         const TFunction *function = node->getFunction();
263         ASSERT(function && mFunctionMap->find(function) != mFunctionMap->end());
264 
265         FunctionData &data = (*mFunctionMap)[function];
266 
267         TIntermFunctionDefinition *monomorphized =
268             processFunctionCall(node, data.originalDefinition, &data.isOriginalUsed);
269         if (monomorphized)
270         {
271             data.monomorphizedDefinitions.push_back(monomorphized);
272         }
273 
274         return true;
275     }
276 
getAnyMonomorphized() const277     bool getAnyMonomorphized() const { return mAnyMonomorphized; }
278 
279   private:
isUnsupportedArgument(TIntermTyped * callArgument,const TVariable * funcArgument) const280     bool isUnsupportedArgument(TIntermTyped *callArgument, const TVariable *funcArgument) const
281     {
282         // Only interested in opaque uniforms and structs that contain samplers.
283         const bool isOpaqueType = IsOpaqueType(funcArgument->getType().getBasicType());
284         const bool isStructContainingSamplers =
285             funcArgument->getType().isStructureContainingSamplers();
286         if (!isOpaqueType && !isStructContainingSamplers)
287         {
288             return false;
289         }
290 
291         // If not uniform (the variable was itself a function parameter), don't process it in
292         // this pass, as we don't know which actual uniform it corresponds to.
293         bool isSamplerInStruct   = false;
294         const TVariable *uniform = GetBaseUniform(callArgument, &isSamplerInStruct);
295         if (uniform == nullptr)
296         {
297             return false;
298         }
299 
300         const TType &type = uniform->getType();
301 
302         if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::StructContainingSamplers])
303         {
304             // Monomorphize if the parameter is a structure that contains samplers (so in
305             // RewriteStructSamplers we don't need to rewrite the functions to accept multiple
306             // parameters split from the struct).
307             if (isStructContainingSamplers)
308             {
309                 return true;
310             }
311         }
312 
313         if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::ArrayOfArrayOfSamplerOrImage])
314         {
315             // Monomorphize if:
316             //
317             // - The opaque uniform is a sampler in a struct (which can create an array-of-array
318             //   situation), and the function expects an array of samplers, or
319             //
320             // - The opaque uniform is an array of array of sampler or image, and it's partially
321             //   subscripted (i.e. the function itself expects an array)
322             //
323             const bool isParameterArrayOfOpaqueType = funcArgument->getType().isArray();
324             const bool isArrayOfArrayOfSamplerOrImage =
325                 (type.isSampler() || type.isImage()) && type.isArrayOfArrays();
326             if (isSamplerInStruct && isParameterArrayOfOpaqueType)
327             {
328                 return true;
329             }
330             if (isArrayOfArrayOfSamplerOrImage && isParameterArrayOfOpaqueType)
331             {
332                 return true;
333             }
334         }
335 
336         if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::AtomicCounter])
337         {
338             if (type.isAtomicCounter())
339             {
340                 return true;
341             }
342         }
343 
344         if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::SamplerCubeEmulation])
345         {
346             // Monomorphize if the opaque uniform is a samplerCube and ES2's cube sampling emulation
347             // is requested.
348             if (type.isSamplerCube() && mCompileOptions.emulateSeamfulCubeMapSampling)
349             {
350                 return true;
351             }
352         }
353 
354         if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::Image])
355         {
356             if (type.isImage())
357             {
358                 return true;
359             }
360         }
361 
362         if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::PixelLocalStorage])
363         {
364             if (type.isPixelLocal())
365             {
366                 return true;
367             }
368         }
369 
370         return false;
371     }
372 
processFunctionCall(TIntermAggregate * functionCall,TIntermFunctionDefinition * originalDefinition,bool * isOriginalUsedOut)373     TIntermFunctionDefinition *processFunctionCall(TIntermAggregate *functionCall,
374                                                    TIntermFunctionDefinition *originalDefinition,
375                                                    bool *isOriginalUsedOut)
376     {
377         const TFunction *function            = functionCall->getFunction();
378         const TIntermSequence &callArguments = *functionCall->getSequence();
379 
380         TVector<Argument> replacedArguments;
381         TIntermSequence replacementIndices;
382 
383         // Go through function call arguments, and see if any is used in an unsupported way.
384         for (size_t argIndex = 0; argIndex < callArguments.size(); ++argIndex)
385         {
386             TIntermTyped *callArgument    = callArguments[argIndex]->getAsTyped();
387             const TVariable *funcArgument = function->getParam(argIndex);
388             if (isUnsupportedArgument(callArgument, funcArgument))
389             {
390                 // Copy the argument and extract the side effects.
391                 TIntermTyped *argument =
392                     ExtractSideEffects(mSymbolTable, callArgument, &replacementIndices);
393 
394                 replacedArguments.push_back({argIndex, argument});
395             }
396         }
397 
398         if (replacedArguments.empty())
399         {
400             *isOriginalUsedOut = true;
401             return nullptr;
402         }
403 
404         mAnyMonomorphized = true;
405 
406         insertStatementsInParentBlock(replacementIndices);
407 
408         // Create the arguments for the substitute function call.  Done before monomorphizing the
409         // function, which transforms the arguments to what needs to be replaced in the function
410         // body.
411         TIntermSequence newCallArgs;
412         CreateMonomorphizedFunctionCallArgs(callArguments, replacedArguments, &newCallArgs);
413 
414         // Duplicate the function and substitute the replaced arguments with only the non-const
415         // indices.  Additionally, substitute the non-const indices of arguments with the new
416         // function parameters.
417         VariableReplacementMap argumentMap;
418         const TFunction *monomorphized =
419             MonomorphizeFunction(mSymbolTable, function, &replacedArguments, &argumentMap);
420 
421         // Replace this function call with a call to the new one.
422         queueReplacement(TIntermAggregate::CreateFunctionCall(*monomorphized, &newCallArgs),
423                          OriginalNode::IS_DROPPED);
424 
425         // Create a new function definition, with the body of the old function but with the replaced
426         // parameters substituted with the calling expressions.
427         TIntermFunctionPrototype *substitutePrototype = new TIntermFunctionPrototype(monomorphized);
428         TIntermBlock *substituteBlock                 = originalDefinition->getBody()->deepCopy();
429         GetDeclaratorReplacements(mSymbolTable, substituteBlock, &argumentMap);
430         bool valid = ReplaceVariables(mCompiler, substituteBlock, argumentMap);
431         ASSERT(valid);
432 
433         return new TIntermFunctionDefinition(substitutePrototype, substituteBlock);
434     }
435 
436     TCompiler *mCompiler;
437     const ShCompileOptions &mCompileOptions;
438     UnsupportedFunctionArgsBitSet mUnsupportedFunctionArgs;
439     bool mAnyMonomorphized = false;
440 
441     // Map of original to monomorphized functions.
442     FunctionMap *mFunctionMap;
443 };
444 
445 class UpdateFunctionsDefinitionsTraverser final : public TIntermTraverser
446 {
447   public:
UpdateFunctionsDefinitionsTraverser(TSymbolTable * symbolTable,const FunctionMap & functionMap)448     explicit UpdateFunctionsDefinitionsTraverser(TSymbolTable *symbolTable,
449                                                  const FunctionMap &functionMap)
450         : TIntermTraverser(true, false, false, symbolTable), mFunctionMap(functionMap)
451     {}
452 
visitFunctionPrototype(TIntermFunctionPrototype * node)453     void visitFunctionPrototype(TIntermFunctionPrototype *node) override
454     {
455         const bool isInFunctionDefinition = getParentNode()->getAsFunctionDefinition() != nullptr;
456         if (isInFunctionDefinition)
457         {
458             return;
459         }
460 
461         // Add to and possibly replace the function prototype with replacement prototypes.
462         const TFunction *function = node->getFunction();
463         ASSERT(function && mFunctionMap.find(function) != mFunctionMap.end());
464 
465         const FunctionData &data = mFunctionMap.at(function);
466 
467         // If nothing to do, leave it be.
468         if (data.monomorphizedDefinitions.empty())
469         {
470             ASSERT(data.isOriginalUsed);
471             return;
472         }
473 
474         // Replace the prototype with itself (if function is still used) as well as any
475         // monomorphized versions.
476         TIntermSequence replacement;
477         if (data.isOriginalUsed)
478         {
479             replacement.push_back(node);
480         }
481         for (TIntermFunctionDefinition *monomorphizedDefinition : data.monomorphizedDefinitions)
482         {
483             replacement.push_back(new TIntermFunctionPrototype(
484                 monomorphizedDefinition->getFunctionPrototype()->getFunction()));
485         }
486         mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
487                                         std::move(replacement));
488     }
489 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)490     bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
491     {
492         // Add to and possibly replace the function definition with replacement definitions.
493         const TFunction *function = node->getFunction();
494         ASSERT(function && mFunctionMap.find(function) != mFunctionMap.end());
495 
496         const FunctionData &data = mFunctionMap.at(function);
497 
498         // If nothing to do, leave it be.
499         if (data.monomorphizedDefinitions.empty())
500         {
501             ASSERT(data.isOriginalUsed || function->name() == "main");
502             return false;
503         }
504 
505         // Replace the definition with itself (if function is still used) as well as any
506         // monomorphized versions.
507         TIntermSequence replacement;
508         if (data.isOriginalUsed)
509         {
510             replacement.push_back(node);
511         }
512         for (TIntermFunctionDefinition *monomorphizedDefinition : data.monomorphizedDefinitions)
513         {
514             replacement.push_back(monomorphizedDefinition);
515         }
516         mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
517                                         std::move(replacement));
518 
519         return false;
520     }
521 
522   private:
523     const FunctionMap &mFunctionMap;
524 };
525 
SortDeclarations(TIntermBlock * root)526 void SortDeclarations(TIntermBlock *root)
527 {
528     TIntermSequence *original = root->getSequence();
529 
530     TIntermSequence replacement;
531     TIntermSequence functionDefs;
532 
533     // Accumulate non-function-definition declarations in |replacement| and function definitions in
534     // |functionDefs|.
535     for (TIntermNode *node : *original)
536     {
537         if (node->getAsFunctionDefinition() || node->getAsFunctionPrototypeNode())
538         {
539             functionDefs.push_back(node);
540         }
541         else
542         {
543             replacement.push_back(node);
544         }
545     }
546 
547     // Append function definitions to |replacement|.
548     replacement.insert(replacement.end(), functionDefs.begin(), functionDefs.end());
549 
550     // Replace root's sequence with |replacement|.
551     root->replaceAllChildren(replacement);
552 }
553 
MonomorphizeUnsupportedFunctionsImpl(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const ShCompileOptions & compileOptions,UnsupportedFunctionArgsBitSet unsupportedFunctionArgs)554 bool MonomorphizeUnsupportedFunctionsImpl(TCompiler *compiler,
555                                           TIntermBlock *root,
556                                           TSymbolTable *symbolTable,
557                                           const ShCompileOptions &compileOptions,
558                                           UnsupportedFunctionArgsBitSet unsupportedFunctionArgs)
559 {
560     // First, sort out the declarations such that all non-function declarations are placed before
561     // function definitions.  This way when the function is replaced with one that references said
562     // declarations (i.e. uniforms), the uniform declaration is already present above it.
563     SortDeclarations(root);
564 
565     while (true)
566     {
567         FunctionMap functionMap;
568         InitializeFunctionMap(root, &functionMap);
569 
570         MonomorphizeTraverser monomorphizer(compiler, symbolTable, compileOptions,
571                                             unsupportedFunctionArgs, &functionMap);
572         root->traverse(&monomorphizer);
573 
574         if (!monomorphizer.getAnyMonomorphized())
575         {
576             break;
577         }
578 
579         if (!monomorphizer.updateTree(compiler, root))
580         {
581             return false;
582         }
583 
584         UpdateFunctionsDefinitionsTraverser functionUpdater(symbolTable, functionMap);
585         root->traverse(&functionUpdater);
586 
587         if (!functionUpdater.updateTree(compiler, root))
588         {
589             return false;
590         }
591     }
592 
593     return true;
594 }
595 }  // anonymous namespace
596 
MonomorphizeUnsupportedFunctions(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const ShCompileOptions & compileOptions,UnsupportedFunctionArgsBitSet unsupportedFunctionArgs)597 bool MonomorphizeUnsupportedFunctions(TCompiler *compiler,
598                                       TIntermBlock *root,
599                                       TSymbolTable *symbolTable,
600                                       const ShCompileOptions &compileOptions,
601                                       UnsupportedFunctionArgsBitSet unsupportedFunctionArgs)
602 {
603     // This function actually applies multiple transformation, and the AST may not be valid until
604     // the transformations are entirely done.  Some validation is momentarily disabled.
605     bool enableValidateFunctionCall = compiler->disableValidateFunctionCall();
606 
607     bool result = MonomorphizeUnsupportedFunctionsImpl(compiler, root, symbolTable, compileOptions,
608                                                        unsupportedFunctionArgs);
609 
610     compiler->restoreValidateFunctionCall(enableValidateFunctionCall);
611     return result && compiler->validateAST(root);
612 }
613 }  // namespace sh
614