• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2024 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // SeparateStructFromFunctionDeclarations: Move struct declarations out of function return types,
7 // e.g. `struct S { float x; } f();` becomes `struct S { float x; }; S f();`.
8 //
9 
10 #include "compiler/translator/tree_ops/SeparateStructFromFunctionDeclarations.h"
11 #include "compiler/translator/Compiler.h"
12 #include "compiler/translator/IntermRebuild.h"
13 #include "compiler/translator/SymbolTable.h"
14 
15 namespace sh
16 {
17 namespace
18 {
19 class SeparateStructFromFunctionDeclarationsTraverser : public TIntermRebuild
20 {
21   public:
SeparateStructFromFunctionDeclarationsTraverser(TCompiler & compiler)22     explicit SeparateStructFromFunctionDeclarationsTraverser(TCompiler &compiler)
23         : TIntermRebuild(compiler, true, true)
24     {}
25 
visitFunctionPrototypePre(TIntermFunctionPrototype & node)26     PreResult visitFunctionPrototypePre(TIntermFunctionPrototype &node) override
27     {
28         const TFunction *function = node.getFunction();
29         if (mFunctionsToReplace.count(function) > 0)
30         {
31             TIntermFunctionPrototype *newFuncProto =
32                 new TIntermFunctionPrototype(mFunctionsToReplace[function]);
33             return newFuncProto;
34         }
35         else if (node.getType().isStructSpecifier())
36         {
37             const TType &oldType        = node.getType();
38             const TStructure *structure = oldType.getStruct();
39             // Name unnamed inline structs
40             if (structure->symbolType() == SymbolType::Empty)
41             {
42                 structure = new TStructure(&mSymbolTable, kEmptyImmutableString,
43                                            &structure->fields(), SymbolType::AngleInternal);
44             }
45 
46             TVariable *structVar = new TVariable(&mSymbolTable, ImmutableString(""),
47                                                  new TType(structure, true), SymbolType::Empty);
48             ASSERT(!mStructDeclarations.empty());
49             mStructDeclarations.back().push_back(new TIntermDeclaration({structVar}));
50 
51             TType *returnType = new TType(structure, false);
52             if (oldType.isArray())
53             {
54                 returnType->makeArrays(oldType.getArraySizes());
55             }
56             returnType->setQualifier(oldType.getQualifier());
57 
58             const TFunction *oldFunc = function;
59             ASSERT(oldFunc->symbolType() == SymbolType::UserDefined);
60 
61             const TFunction *newFunc     = cloneFunctionAndChangeReturnType(oldFunc, returnType);
62             mFunctionsToReplace[oldFunc] = newFunc;
63 
64             return new TIntermFunctionPrototype(newFunc);
65         }
66 
67         return node;
68     }
69 
visitAggregatePre(TIntermAggregate & node)70     PreResult visitAggregatePre(TIntermAggregate &node) override
71     {
72         const TFunction *function = node.getFunction();
73         if (mFunctionsToReplace.count(function) > 0)
74         {
75             TIntermAggregate *replacementNode = TIntermAggregate::CreateFunctionCall(
76                 *mFunctionsToReplace[function], node.getSequence());
77 
78             return PreResult(replacementNode, VisitBits::Children);
79         }
80 
81         return node;
82     }
83 
visitBlockPre(TIntermBlock & node)84     PreResult visitBlockPre(TIntermBlock &node) override
85     {
86         mStructDeclarations.push_back({});
87         return node;
88     }
89 
visitBlockPost(TIntermBlock & node)90     PostResult visitBlockPost(TIntermBlock &node) override
91     {
92         ASSERT(!mStructDeclarations.empty());
93 
94         std::vector<TIntermDeclaration *> declarations = mStructDeclarations.back();
95         mStructDeclarations.pop_back();
96 
97         if (!declarations.empty())
98         {
99             TIntermBlock *blockWithStructDeclarations = new TIntermBlock();
100             if (node.isTreeRoot())
101             {
102                 blockWithStructDeclarations->setIsTreeRoot();
103             }
104 
105             for (TIntermDeclaration *structDecl : declarations)
106             {
107                 blockWithStructDeclarations->appendStatement(structDecl);
108             }
109 
110             for (TIntermNode *statement : *node.getSequence())
111             {
112                 blockWithStructDeclarations->appendStatement(statement);
113             }
114 
115             return blockWithStructDeclarations;
116         }
117 
118         return node;
119     }
120 
121   private:
cloneFunctionAndChangeReturnType(const TFunction * oldFunc,const TType * newReturnType)122     const TFunction *cloneFunctionAndChangeReturnType(const TFunction *oldFunc,
123                                                       const TType *newReturnType)
124 
125     {
126         ASSERT(oldFunc->symbolType() == SymbolType::UserDefined);
127 
128         TFunction *newFunc = new TFunction(&mSymbolTable, oldFunc->name(), oldFunc->symbolType(),
129                                            newReturnType, oldFunc->isKnownToNotHaveSideEffects());
130 
131         if (oldFunc->isDefined())
132         {
133             newFunc->setDefined();
134         }
135 
136         if (oldFunc->hasPrototypeDeclaration())
137         {
138             newFunc->setHasPrototypeDeclaration();
139         }
140 
141         const size_t paramCount = oldFunc->getParamCount();
142         for (size_t i = 0; i < paramCount; ++i)
143         {
144             const TVariable *var = oldFunc->getParam(i);
145             newFunc->addParameter(var);
146         }
147 
148         return newFunc;
149     }
150 
151     using FunctionReplacement = angle::HashMap<const TFunction *, const TFunction *>;
152     FunctionReplacement mFunctionsToReplace;
153 
154     // Stack of struct declarations to insert per block
155     std::vector<std::vector<TIntermDeclaration *>> mStructDeclarations;
156 };
157 
158 }  // anonymous namespace
159 
SeparateStructFromFunctionDeclarations(TCompiler & compiler,TIntermBlock & root)160 bool SeparateStructFromFunctionDeclarations(TCompiler &compiler, TIntermBlock &root)
161 {
162     SeparateStructFromFunctionDeclarationsTraverser separateStructDecls(compiler);
163     return separateStructDecls.rebuildRoot(root);
164 }
165 }  // namespace sh
166