• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // MonomorphizeUnsupportedFunctions: Monomorphize functions that are called with
7 // parameters that are incompatible with both Vulkan GLSL and Metal.
8 //
9 
10 #include "compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.h"
11 
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/SymbolTable.h"
14 #include "compiler/translator/tree_util/IntermNode_util.h"
15 #include "compiler/translator/tree_util/IntermTraverse.h"
16 #include "compiler/translator/tree_util/ReplaceVariable.h"
17 
18 namespace sh
19 {
20 namespace
21 {
22 struct Argument
23 {
24     size_t argumentIndex;
25     TIntermTyped *argument;
26 };
27 
28 struct FunctionData
29 {
30     // Whether the original function is used.  If this is false, the function can be removed because
31     // all callers have been modified.
32     bool isOriginalUsed;
33     // The original definition of the function, used to create the monomorphized version.
34     TIntermFunctionDefinition *originalDefinition;
35     // List of monomorphized versions of this function.  They will be added next to the original
36     // version (or replace it).
37     TVector<TIntermFunctionDefinition *> monomorphizedDefinitions;
38 };
39 
40 using FunctionMap = angle::HashMap<const TFunction *, FunctionData>;
41 
42 // Traverse the function definitions and initialize the map.  Allows visitAggregate to have access
43 // to TIntermFunctionDefinition even when the function is only forward declared at that point.
InitializeFunctionMap(TIntermBlock * root,FunctionMap * functionMapOut)44 void InitializeFunctionMap(TIntermBlock *root, FunctionMap *functionMapOut)
45 {
46     TIntermSequence &sequence = *root->getSequence();
47 
48     for (TIntermNode *node : sequence)
49     {
50         TIntermFunctionDefinition *asFuncDef = node->getAsFunctionDefinition();
51         if (asFuncDef != nullptr)
52         {
53             const TFunction *function = asFuncDef->getFunction();
54             ASSERT(function && functionMapOut->find(function) == functionMapOut->end());
55             (*functionMapOut)[function] = FunctionData{false, asFuncDef, {}};
56         }
57     }
58 }
59 
GetBaseUniform(TIntermTyped * node,bool * isSamplerInStructOut)60 const TVariable *GetBaseUniform(TIntermTyped *node, bool *isSamplerInStructOut)
61 {
62     *isSamplerInStructOut = false;
63 
64     while (node->getAsBinaryNode())
65     {
66         TIntermBinary *asBinary = node->getAsBinaryNode();
67 
68         TOperator op = asBinary->getOp();
69 
70         // No opaque uniform can be inside an interface block.
71         if (op == EOpIndexDirectInterfaceBlock)
72         {
73             return nullptr;
74         }
75 
76         if (op == EOpIndexDirectStruct)
77         {
78             *isSamplerInStructOut = true;
79         }
80 
81         node = asBinary->getLeft();
82     }
83 
84     // Only interested in uniform opaque types.  If a function call within another function uses
85     // opaque uniforms in an unsupported way, it will be replaced in a follow up pass after the
86     // calling function is monomorphized.
87     if (node->getType().getQualifier() != EvqUniform)
88     {
89         return nullptr;
90     }
91 
92     ASSERT(IsOpaqueType(node->getType().getBasicType()) ||
93            node->getType().isStructureContainingSamplers());
94 
95     TIntermSymbol *asSymbol = node->getAsSymbolNode();
96     ASSERT(asSymbol);
97 
98     return &asSymbol->variable();
99 }
100 
ExtractSideEffects(TSymbolTable * symbolTable,TIntermTyped * node,TIntermSequence * replacementIndices)101 TIntermTyped *ExtractSideEffects(TSymbolTable *symbolTable,
102                                  TIntermTyped *node,
103                                  TIntermSequence *replacementIndices)
104 {
105     TIntermTyped *withoutSideEffects = node->deepCopy();
106 
107     for (TIntermBinary *asBinary = withoutSideEffects->getAsBinaryNode(); asBinary;
108          asBinary                = asBinary->getLeft()->getAsBinaryNode())
109     {
110         TOperator op        = asBinary->getOp();
111         TIntermTyped *index = asBinary->getRight();
112 
113         if (op == EOpIndexDirectStruct)
114         {
115             break;
116         }
117 
118         // No side effects with constant expressions.
119         if (op == EOpIndexDirect)
120         {
121             ASSERT(index->getAsConstantUnion());
122             continue;
123         }
124 
125         ASSERT(op == EOpIndexIndirect);
126 
127         // If the index is a symbol, there's no side effect, so leave it as-is.
128         if (index->getAsSymbolNode())
129         {
130             continue;
131         }
132 
133         // Otherwise create a temp variable initialized with the index and use that temp variable as
134         // the index.
135         TIntermDeclaration *tempDecl = nullptr;
136         TVariable *tempVar = DeclareTempVariable(symbolTable, index, EvqTemporary, &tempDecl);
137 
138         replacementIndices->push_back(tempDecl);
139         asBinary->replaceChildNode(index, new TIntermSymbol(tempVar));
140     }
141 
142     return withoutSideEffects;
143 }
144 
CreateMonomorphizedFunctionCallArgs(const TIntermSequence & originalCallArguments,const TVector<Argument> & replacedArguments,TIntermSequence * substituteArgsOut)145 void CreateMonomorphizedFunctionCallArgs(const TIntermSequence &originalCallArguments,
146                                          const TVector<Argument> &replacedArguments,
147                                          TIntermSequence *substituteArgsOut)
148 {
149     size_t nextReplacedArg = 0;
150     for (size_t argIndex = 0; argIndex < originalCallArguments.size(); ++argIndex)
151     {
152         if (nextReplacedArg >= replacedArguments.size() ||
153             argIndex != replacedArguments[nextReplacedArg].argumentIndex)
154         {
155             // Not replaced, keep argument as is.
156             substituteArgsOut->push_back(originalCallArguments[argIndex]);
157         }
158         else
159         {
160             TIntermTyped *argument = replacedArguments[nextReplacedArg].argument;
161 
162             // Iterate over indices of the argument and create a new arg for every non-const
163             // index.  Note that the index itself may be an expression, and it may require further
164             // substitution in the next pass.
165             while (argument->getAsBinaryNode())
166             {
167                 TIntermBinary *asBinary = argument->getAsBinaryNode();
168                 if (asBinary->getOp() == EOpIndexIndirect)
169                 {
170                     TIntermTyped *index = asBinary->getRight();
171                     substituteArgsOut->push_back(index->deepCopy());
172                 }
173                 argument = asBinary->getLeft();
174             }
175 
176             ++nextReplacedArg;
177         }
178     }
179 }
180 
MonomorphizeFunction(TSymbolTable * symbolTable,const TFunction * original,TVector<Argument> * replacedArguments,VariableReplacementMap * argumentMapOut)181 const TFunction *MonomorphizeFunction(TSymbolTable *symbolTable,
182                                       const TFunction *original,
183                                       TVector<Argument> *replacedArguments,
184                                       VariableReplacementMap *argumentMapOut)
185 {
186     TFunction *substituteFunction =
187         new TFunction(symbolTable, kEmptyImmutableString, SymbolType::AngleInternal,
188                       &original->getReturnType(), original->isKnownToNotHaveSideEffects());
189 
190     size_t nextReplacedArg = 0;
191     for (size_t paramIndex = 0; paramIndex < original->getParamCount(); ++paramIndex)
192     {
193         const TVariable *originalParam = original->getParam(paramIndex);
194 
195         if (nextReplacedArg >= replacedArguments->size() ||
196             paramIndex != (*replacedArguments)[nextReplacedArg].argumentIndex)
197         {
198             TVariable *substituteArgument =
199                 new TVariable(symbolTable, originalParam->name(), &originalParam->getType(),
200                               originalParam->symbolType());
201             // Not replaced, add an identical parameter.
202             substituteFunction->addParameter(substituteArgument);
203             (*argumentMapOut)[originalParam] = new TIntermSymbol(substituteArgument);
204         }
205         else
206         {
207             TIntermTyped *substituteArgument = (*replacedArguments)[nextReplacedArg].argument;
208             (*argumentMapOut)[originalParam] = substituteArgument;
209 
210             // Iterate over indices of the argument and create a new parameter for every non-const
211             // index (which may be an expression).  Replace the symbol in the argument with a
212             // variable of the index type.  This is later used to replace the parameter in the
213             // function body.
214             while (substituteArgument->getAsBinaryNode())
215             {
216                 TIntermBinary *asBinary = substituteArgument->getAsBinaryNode();
217                 if (asBinary->getOp() == EOpIndexIndirect)
218                 {
219                     TIntermTyped *index = asBinary->getRight();
220                     TType *indexType    = new TType(index->getType());
221                     indexType->setQualifier(EvqParamIn);
222 
223                     TVariable *param = new TVariable(symbolTable, kEmptyImmutableString, indexType,
224                                                      SymbolType::AngleInternal);
225                     substituteFunction->addParameter(param);
226 
227                     // The argument now uses the function parameters as indices.
228                     asBinary->replaceChildNode(asBinary->getRight(), new TIntermSymbol(param));
229                 }
230                 substituteArgument = asBinary->getLeft();
231             }
232 
233             ++nextReplacedArg;
234         }
235     }
236 
237     return substituteFunction;
238 }
239 
240 class MonomorphizeTraverser final : public TIntermTraverser
241 {
242   public:
MonomorphizeTraverser(TCompiler * compiler,TSymbolTable * symbolTable,ShCompileOptions compileOptions,FunctionMap * functionMap)243     explicit MonomorphizeTraverser(TCompiler *compiler,
244                                    TSymbolTable *symbolTable,
245                                    ShCompileOptions compileOptions,
246                                    FunctionMap *functionMap)
247         : TIntermTraverser(true, false, false, symbolTable),
248           mCompiler(compiler),
249           mCompileOptions(compileOptions),
250           mFunctionMap(functionMap)
251     {}
252 
visitAggregate(Visit visit,TIntermAggregate * node)253     bool visitAggregate(Visit visit, TIntermAggregate *node) override
254     {
255         if (node->getOp() != EOpCallFunctionInAST)
256         {
257             return true;
258         }
259 
260         const TFunction *function = node->getFunction();
261         ASSERT(function && mFunctionMap->find(function) != mFunctionMap->end());
262 
263         FunctionData &data = (*mFunctionMap)[function];
264 
265         TIntermFunctionDefinition *monomorphized =
266             processFunctionCall(node, data.originalDefinition, &data.isOriginalUsed);
267         if (monomorphized)
268         {
269             data.monomorphizedDefinitions.push_back(monomorphized);
270         }
271 
272         return true;
273     }
274 
getAnyMonomorphized() const275     bool getAnyMonomorphized() const { return mAnyMonomorphized; }
276 
277   private:
processFunctionCall(TIntermAggregate * functionCall,TIntermFunctionDefinition * originalDefinition,bool * isOriginalUsedOut)278     TIntermFunctionDefinition *processFunctionCall(TIntermAggregate *functionCall,
279                                                    TIntermFunctionDefinition *originalDefinition,
280                                                    bool *isOriginalUsedOut)
281     {
282         const TFunction *function            = functionCall->getFunction();
283         const TIntermSequence &callArguments = *functionCall->getSequence();
284 
285         TVector<Argument> replacedArguments;
286         TIntermSequence replacementIndices;
287 
288         // Go through function call arguments, and see if any is used in an unsupported way.
289         for (size_t argIndex = 0; argIndex < callArguments.size(); ++argIndex)
290         {
291             TIntermTyped *callArgument    = callArguments[argIndex]->getAsTyped();
292             const TVariable *funcArgument = function->getParam(argIndex);
293 
294             // Only interested in opaque uniforms and structs that contain samplers.
295             const bool isOpaqueType = IsOpaqueType(funcArgument->getType().getBasicType());
296             const bool isStructContainingSamplers =
297                 funcArgument->getType().isStructureContainingSamplers();
298             if (!isOpaqueType && !isStructContainingSamplers)
299             {
300                 continue;
301             }
302 
303             // If not uniform (the variable was itself a function parameter), don't process it in
304             // this pass, as we don't know which actual uniform it corresponds to.
305             bool isSamplerInStruct   = false;
306             const TVariable *uniform = GetBaseUniform(callArgument, &isSamplerInStruct);
307             if (uniform == nullptr)
308             {
309                 continue;
310             }
311 
312             // Conditions for monomorphization:
313             //
314             // - If the parameter is a structure that contains samplers (so in RewriteStructSamplers
315             //   we don't need to rewrite the functions to accept multiple parameters split from the
316             //   struct), or
317             // - If the opaque uniform is a sampler in a struct (which can create an array-of-array
318             //   situation), and the function expects an array of samplers, or
319             // - If the opaque uniform is an array of array of sampler or image, and it's partially
320             //   subscripted (i.e. the function itself expects an array), or
321             // - The opaque uniform is an atomic counter
322             // - The opaque uniform is a samplerCube and ES2's cube sampling emulation is requested.
323             // - The opaque uniform is an image* with r32f format.
324             //
325             const TType &type = uniform->getType();
326             const bool isArrayOfArrayOfSamplerOrImage =
327                 (type.isSampler() || type.isImage()) && type.isArrayOfArrays();
328             const bool isParameterArrayOfOpaqueType = funcArgument->getType().isArray();
329             const bool isAtomicCounter              = type.isAtomicCounter();
330             const bool isSamplerCubeEmulation =
331                 type.isSamplerCube() &&
332                 (mCompileOptions & SH_EMULATE_SEAMFUL_CUBE_MAP_SAMPLING) != 0;
333             const bool isR32fImage =
334                 type.isImage() && type.getLayoutQualifier().imageInternalFormat == EiifR32F;
335 
336             if (!(isStructContainingSamplers ||
337                   (isSamplerInStruct && isParameterArrayOfOpaqueType) ||
338                   (isArrayOfArrayOfSamplerOrImage && isParameterArrayOfOpaqueType) ||
339                   isAtomicCounter || isSamplerCubeEmulation || isR32fImage))
340             {
341                 continue;
342             }
343 
344             // Copy the argument and extract the side effects.
345             TIntermTyped *argument =
346                 ExtractSideEffects(mSymbolTable, callArgument, &replacementIndices);
347 
348             replacedArguments.push_back({argIndex, argument});
349         }
350 
351         if (replacedArguments.empty())
352         {
353             *isOriginalUsedOut = true;
354             return nullptr;
355         }
356 
357         mAnyMonomorphized = true;
358 
359         insertStatementsInParentBlock(replacementIndices);
360 
361         // Create the arguments for the substitute function call.  Done before monomorphizing the
362         // function, which transforms the arguments to what needs to be replaced in the function
363         // body.
364         TIntermSequence newCallArgs;
365         CreateMonomorphizedFunctionCallArgs(callArguments, replacedArguments, &newCallArgs);
366 
367         // Duplicate the function and substitute the replaced arguments with only the non-const
368         // indices.  Additionally, substitute the non-const indices of arguments with the new
369         // function parameters.
370         VariableReplacementMap argumentMap;
371         const TFunction *monomorphized =
372             MonomorphizeFunction(mSymbolTable, function, &replacedArguments, &argumentMap);
373 
374         // Replace this function call with a call to the new one.
375         queueReplacement(TIntermAggregate::CreateFunctionCall(*monomorphized, &newCallArgs),
376                          OriginalNode::IS_DROPPED);
377 
378         // Create a new function definition, with the body of the old function but with the replaced
379         // parameters substituted with the calling expressions.
380         TIntermFunctionPrototype *substitutePrototype = new TIntermFunctionPrototype(monomorphized);
381         TIntermBlock *substituteBlock                 = originalDefinition->getBody()->deepCopy();
382         GetDeclaratorReplacements(mSymbolTable, substituteBlock, &argumentMap);
383         bool valid = ReplaceVariables(mCompiler, substituteBlock, argumentMap);
384         ASSERT(valid);
385 
386         return new TIntermFunctionDefinition(substitutePrototype, substituteBlock);
387     }
388 
389     TCompiler *mCompiler;
390     ShCompileOptions mCompileOptions;
391     bool mAnyMonomorphized = false;
392 
393     // Map of original to monomorphized functions.
394     FunctionMap *mFunctionMap;
395 };
396 
397 class UpdateFunctionsDefinitionsTraverser final : public TIntermTraverser
398 {
399   public:
UpdateFunctionsDefinitionsTraverser(TSymbolTable * symbolTable,const FunctionMap & functionMap)400     explicit UpdateFunctionsDefinitionsTraverser(TSymbolTable *symbolTable,
401                                                  const FunctionMap &functionMap)
402         : TIntermTraverser(true, false, false, symbolTable), mFunctionMap(functionMap)
403     {}
404 
visitFunctionPrototype(TIntermFunctionPrototype * node)405     void visitFunctionPrototype(TIntermFunctionPrototype *node) override
406     {
407         const bool isInFunctionDefinition = getParentNode()->getAsFunctionDefinition() != nullptr;
408         if (isInFunctionDefinition)
409         {
410             return;
411         }
412 
413         // Add to and possibly replace the function prototype with replacement prototypes.
414         const TFunction *function = node->getFunction();
415         ASSERT(function && mFunctionMap.find(function) != mFunctionMap.end());
416 
417         const FunctionData &data = mFunctionMap.at(function);
418 
419         // If nothing to do, leave it be.
420         if (data.monomorphizedDefinitions.empty())
421         {
422             ASSERT(data.isOriginalUsed);
423             return;
424         }
425 
426         // Replace the prototype with itself (if function is still used) as well as any
427         // monomorphized versions.
428         TIntermSequence replacement;
429         if (data.isOriginalUsed)
430         {
431             replacement.push_back(node);
432         }
433         for (TIntermFunctionDefinition *monomorphizedDefinition : data.monomorphizedDefinitions)
434         {
435             replacement.push_back(new TIntermFunctionPrototype(
436                 monomorphizedDefinition->getFunctionPrototype()->getFunction()));
437         }
438         mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
439                                         std::move(replacement));
440     }
441 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)442     bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
443     {
444         // Add to and possibly replace the function definition with replacement definitions.
445         const TFunction *function = node->getFunction();
446         ASSERT(function && mFunctionMap.find(function) != mFunctionMap.end());
447 
448         const FunctionData &data = mFunctionMap.at(function);
449 
450         // If nothing to do, leave it be.
451         if (data.monomorphizedDefinitions.empty())
452         {
453             ASSERT(data.isOriginalUsed || function->name() == "main");
454             return false;
455         }
456 
457         // Replace the definition with itself (if function is still used) as well as any
458         // monomorphized versions.
459         TIntermSequence replacement;
460         if (data.isOriginalUsed)
461         {
462             replacement.push_back(node);
463         }
464         for (TIntermFunctionDefinition *monomorphizedDefinition : data.monomorphizedDefinitions)
465         {
466             replacement.push_back(monomorphizedDefinition);
467         }
468         mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
469                                         std::move(replacement));
470 
471         return false;
472     }
473 
474   private:
475     const FunctionMap &mFunctionMap;
476 };
477 
SortDeclarations(TIntermBlock * root)478 void SortDeclarations(TIntermBlock *root)
479 {
480     TIntermSequence *original = root->getSequence();
481 
482     TIntermSequence replacement;
483     TIntermSequence functionDefs;
484 
485     // Accumulate non-function-definition declarations in |replacement| and function definitions in
486     // |functionDefs|.
487     for (TIntermNode *node : *original)
488     {
489         if (node->getAsFunctionDefinition() || node->getAsFunctionPrototypeNode())
490         {
491             functionDefs.push_back(node);
492         }
493         else
494         {
495             replacement.push_back(node);
496         }
497     }
498 
499     // Append function definitions to |replacement|.
500     replacement.insert(replacement.end(), functionDefs.begin(), functionDefs.end());
501 
502     // Replace root's sequence with |replacement|.
503     root->replaceAllChildren(replacement);
504 }
505 
MonomorphizeUnsupportedFunctionsImpl(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,ShCompileOptions compileOptions)506 bool MonomorphizeUnsupportedFunctionsImpl(TCompiler *compiler,
507                                           TIntermBlock *root,
508                                           TSymbolTable *symbolTable,
509                                           ShCompileOptions compileOptions)
510 {
511     // First, sort out the declarations such that all non-function declarations are placed before
512     // function definitions.  This way when the function is replaced with one that references said
513     // declarations (i.e. uniforms), the uniform declaration is already present above it.
514     SortDeclarations(root);
515 
516     while (true)
517     {
518         FunctionMap functionMap;
519         InitializeFunctionMap(root, &functionMap);
520 
521         MonomorphizeTraverser monomorphizer(compiler, symbolTable, compileOptions, &functionMap);
522         root->traverse(&monomorphizer);
523 
524         if (!monomorphizer.getAnyMonomorphized())
525         {
526             break;
527         }
528 
529         if (!monomorphizer.updateTree(compiler, root))
530         {
531             return false;
532         }
533 
534         UpdateFunctionsDefinitionsTraverser functionUpdater(symbolTable, functionMap);
535         root->traverse(&functionUpdater);
536 
537         if (!functionUpdater.updateTree(compiler, root))
538         {
539             return false;
540         }
541     }
542 
543     return true;
544 }
545 }  // anonymous namespace
546 
MonomorphizeUnsupportedFunctions(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,ShCompileOptions compileOptions)547 bool MonomorphizeUnsupportedFunctions(TCompiler *compiler,
548                                       TIntermBlock *root,
549                                       TSymbolTable *symbolTable,
550                                       ShCompileOptions compileOptions)
551 {
552     // This function actually applies multiple transformation, and the AST may not be valid until
553     // the transformations are entirely done.  Some validation is momentarily disabled.
554     bool enableValidateFunctionCall = compiler->disableValidateFunctionCall();
555 
556     bool result = MonomorphizeUnsupportedFunctionsImpl(compiler, root, symbolTable, compileOptions);
557 
558     compiler->restoreValidateFunctionCall(enableValidateFunctionCall);
559     return result && compiler->validateAST(root);
560 }
561 }  // namespace sh
562