• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/tree_ops/msl/RewriteOutArgs.h"
8 #include "compiler/translator/msl/IntermRebuild.h"
9 
10 using namespace sh;
11 
12 namespace
13 {
14 
// A multiset intended for a small number of distinct elements. It is stored as
// a vector of (element, count) entries and searched linearly with operator==.
// Insertion order of distinct elements is preserved.
template <typename T>
class SmallMultiSet
{
  public:
    struct Entry
    {
        T elem;
        size_t count;  // Number of times |elem| has been inserted (>= 1).
    };

    // Returns the entry for |x|, or nullptr if |x| has not been inserted since
    // the last clear(). The pointer may be invalidated by insert() or clear().
    const Entry *find(const T &x) const { return findIn(mEntries, x); }

    // Returns how many times |x| has been inserted since the last clear().
    size_t multiplicity(const T &x) const
    {
        const Entry *entry = find(x);
        return entry ? entry->count : 0;
    }

    // Inserts one occurrence of |x> and returns its (updated) entry. The
    // reference may be invalidated by a later insert() or clear().
    const Entry &insert(const T &x)
    {
        if (Entry *entry = findIn(mEntries, x))
        {
            ++entry->count;
            return *entry;
        }
        mEntries.push_back({x, 1});
        return mEntries.back();
    }

    void clear() { mEntries.clear(); }

    bool empty() const { return mEntries.empty(); }

    // Number of distinct elements currently stored.
    size_t uniqueSize() const { return mEntries.size(); }

  private:
    // Linear lookup shared by the const and non-const paths. It returns
    // `const Entry *` for a const container and `Entry *` otherwise, so no
    // const_cast is needed.
    template <typename Entries>
    static auto findIn(Entries &entries, const T &x) -> decltype(entries.data())
    {
        for (auto &entry : entries)
        {
            if (x == entry.elem)
            {
                return &entry;
            }
        }
        return nullptr;
    }

    std::vector<Entry> mEntries;
};
70 
GetVariable(TIntermNode & node)71 const TVariable *GetVariable(TIntermNode &node)
72 {
73     TIntermTyped *tyNode = node.getAsTyped();
74     ASSERT(tyNode);
75     if (TIntermSymbol *symbol = tyNode->getAsSymbolNode())
76     {
77         return &symbol->variable();
78     }
79     return nullptr;
80 }
81 
82 class Rewriter : public TIntermRebuild
83 {
84     SmallMultiSet<const TVariable *> mVarBuffer;  // reusable buffer
85     SymbolEnv &mSymbolEnv;
86 
87   public:
~Rewriter()88     ~Rewriter() override { ASSERT(mVarBuffer.empty()); }
89 
Rewriter(TCompiler & compiler,SymbolEnv & symbolEnv)90     Rewriter(TCompiler &compiler, SymbolEnv &symbolEnv)
91         : TIntermRebuild(compiler, false, true), mSymbolEnv(symbolEnv)
92     {}
93 
argAlreadyProcessed(TIntermTyped * arg)94     static bool argAlreadyProcessed(TIntermTyped *arg)
95     {
96         if (arg->getAsAggregate())
97         {
98             const TFunction *func = arg->getAsAggregate()->getFunction();
99             // These two builtins already generate references, and the
100             // ANGLE_inout and ANGLE_out overloads in ProgramPrelude are both
101             // unnecessary and incompatible.
102             if (func && func->symbolType() == SymbolType::AngleInternal &&
103                 (func->name() == "swizzle_ref" || func->name() == "elem_ref"))
104             {
105                 return true;
106             }
107         }
108         return false;
109     }
110 
visitAggregatePost(TIntermAggregate & aggregateNode)111     PostResult visitAggregatePost(TIntermAggregate &aggregateNode) override
112     {
113         ASSERT(mVarBuffer.empty());
114 
115         const TFunction *func = aggregateNode.getFunction();
116         if (!func)
117         {
118             return aggregateNode;
119         }
120 
121         TIntermSequence &args = *aggregateNode.getSequence();
122         size_t argCount       = args.size();
123 
124         auto getParamQualifier = [&](size_t i) {
125             const TVariable &param     = *func->getParam(i);
126             const TType &paramType     = param.getType();
127             const TQualifier paramQual = paramType.getQualifier();
128             switch (paramQual)
129             {
130                 case TQualifier::EvqParamOut:
131                 case TQualifier::EvqParamInOut:
132                     if (!mSymbolEnv.isReference(param))
133                     {
134                         mSymbolEnv.markAsReference(param, AddressSpace::Thread);
135                     }
136                     break;
137                 default:
138                     break;
139             }
140             return paramQual;
141         };
142 
143         bool mightAlias = false;
144 
145         for (size_t i = 0; i < argCount; ++i)
146         {
147             const TQualifier paramQual = getParamQualifier(i);
148 
149             switch (paramQual)
150             {
151                 case TQualifier::EvqParamOut:
152                 case TQualifier::EvqParamInOut:
153                 {
154                     const TVariable *var = GetVariable(*args[i]);
155                     if (mVarBuffer.insert(var).count > 1)
156                     {
157                         mightAlias = true;
158                         i          = argCount;
159                     }
160                 }
161                 break;
162 
163                 default:
164                     break;
165             }
166         }
167 
168         const bool hasIndeterminateVar = mVarBuffer.find(nullptr);
169 
170         if (!mightAlias)
171         {
172             mightAlias = hasIndeterminateVar && mVarBuffer.uniqueSize() > 1;
173         }
174 
175         if (mightAlias)
176         {
177             for (size_t i = 0; i < argCount; ++i)
178             {
179                 TIntermTyped *arg = args[i]->getAsTyped();
180                 ASSERT(arg);
181                 if (!argAlreadyProcessed(arg))
182                 {
183                     const TVariable *var       = GetVariable(*arg);
184                     const TQualifier paramQual = getParamQualifier(i);
185 
186                     if (hasIndeterminateVar || mVarBuffer.multiplicity(var) > 1)
187                     {
188                         switch (paramQual)
189                         {
190                             case TQualifier::EvqParamOut:
191                                 args[i] = &mSymbolEnv.callFunctionOverload(
192                                     Name("out"), arg->getType(), *new TIntermSequence{arg});
193                                 break;
194 
195                             case TQualifier::EvqParamInOut:
196                                 args[i] = &mSymbolEnv.callFunctionOverload(
197                                     Name("inout"), arg->getType(), *new TIntermSequence{arg});
198                                 break;
199 
200                             default:
201                                 break;
202                         }
203                     }
204                 }
205             }
206         }
207 
208         mVarBuffer.clear();
209 
210         return aggregateNode;
211     }
212 };
213 
214 }  // anonymous namespace
215 
RewriteOutArgs(TCompiler & compiler,TIntermBlock & root,SymbolEnv & symbolEnv)216 bool sh::RewriteOutArgs(TCompiler &compiler, TIntermBlock &root, SymbolEnv &symbolEnv)
217 {
218     Rewriter rewriter(compiler, symbolEnv);
219     if (!rewriter.rebuildRoot(root))
220     {
221         return false;
222     }
223     return true;
224 }
225