• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/TranslatorMetalDirect/RewriteOutArgs.h"
8 #include "compiler/translator/tree_util/IntermRebuild.h"
9 
10 using namespace sh;
11 
12 namespace
13 {
14 
// A small multiset stored as a flat vector of (element, count) pairs. Each distinct
// element appears in exactly one entry, and its count is the number of times it has been
// inserted. All lookups are linear scans, so it is intended for a handful of elements.
template <typename T>
class SmallMultiSet
{
  public:
    struct Entry
    {
        T elem;
        size_t count;
    };

    // Returns the entry holding an element equal to |x|, or nullptr if there is none.
    const Entry *find(const T &x) const
    {
        const size_t slot = indexOf(x);
        if (slot == mEntries.size())
        {
            return nullptr;
        }
        return &mEntries[slot];
    }

    // Returns how many times |x| has been inserted since the last clear() (0 if never).
    size_t multiplicity(const T &x) const
    {
        const Entry *hit = find(x);
        if (hit == nullptr)
        {
            return 0;
        }
        return hit->count;
    }

    // Adds one occurrence of |x| and returns its (possibly new) entry. The returned
    // reference points into the backing vector, so a later insert() may invalidate it.
    const Entry &insert(const T &x)
    {
        const size_t slot = indexOf(x);
        if (slot == mEntries.size())
        {
            // First occurrence: the new entry lands exactly at index |slot|.
            mEntries.push_back({x, 1});
        }
        else
        {
            mEntries[slot].count += 1;
        }
        return mEntries[slot];
    }

    // Removes all entries (the vector keeps its capacity, so the buffer can be reused).
    void clear() { mEntries.clear(); }

    // True when no element has been inserted since the last clear().
    bool empty() const { return mEntries.empty(); }

    // Number of distinct elements currently stored.
    size_t uniqueSize() const { return mEntries.size(); }

  private:
    // Index of the entry whose element equals |x|, or mEntries.size() if there is none.
    size_t indexOf(const T &x) const
    {
        size_t i = 0;
        while (i < mEntries.size() && !(x == mEntries[i].elem))
        {
            ++i;
        }
        return i;
    }

    std::vector<Entry> mEntries;
};
70 
GetVariable(TIntermNode & node)71 const TVariable *GetVariable(TIntermNode &node)
72 {
73     TIntermTyped *tyNode = node.getAsTyped();
74     ASSERT(tyNode);
75     if (TIntermSymbol *symbol = tyNode->getAsSymbolNode())
76     {
77         return &symbol->variable();
78     }
79     return nullptr;
80 }
81 
82 class Rewriter : public TIntermRebuild
83 {
84     SmallMultiSet<const TVariable *> mVarBuffer;  // reusable buffer
85     SymbolEnv &mSymbolEnv;
86 
87   public:
~Rewriter()88     ~Rewriter() override { ASSERT(mVarBuffer.empty()); }
89 
Rewriter(TCompiler & compiler,SymbolEnv & symbolEnv)90     Rewriter(TCompiler &compiler, SymbolEnv &symbolEnv)
91         : TIntermRebuild(compiler, false, true), mSymbolEnv(symbolEnv)
92     {}
93 
argAlreadyProcessed(TIntermTyped * arg)94     static bool argAlreadyProcessed(TIntermTyped *arg)
95     {
96         if (arg->getAsAggregate())
97         {
98             const TFunction *func = arg->getAsAggregate()->getFunction();
99             if (func && func->symbolType() == SymbolType::AngleInternal &&
100                 func->name() == "swizzle_ref")
101             {
102                 return true;
103             }
104         }
105         return false;
106     }
107 
visitAggregatePost(TIntermAggregate & aggregateNode)108     PostResult visitAggregatePost(TIntermAggregate &aggregateNode) override
109     {
110         ASSERT(mVarBuffer.empty());
111 
112         const TFunction *func = aggregateNode.getFunction();
113         if (!func)
114         {
115             return aggregateNode;
116         }
117 
118         TIntermSequence &args = *aggregateNode.getSequence();
119         size_t argCount       = args.size();
120 
121         auto getParamQualifier = [&](size_t i) {
122             const TVariable &param     = *func->getParam(i);
123             const TType &paramType     = param.getType();
124             const TQualifier paramQual = paramType.getQualifier();
125             switch (paramQual)
126             {
127                 case TQualifier::EvqParamOut:
128                 case TQualifier::EvqParamInOut:
129                     if (!mSymbolEnv.isReference(param))
130                     {
131                         mSymbolEnv.markAsReference(param, AddressSpace::Thread);
132                     }
133                     break;
134                 default:
135                     break;
136             }
137             return paramQual;
138         };
139 
140         bool mightAlias = false;
141 
142         for (size_t i = 0; i < argCount; ++i)
143         {
144             const TQualifier paramQual = getParamQualifier(i);
145 
146             switch (paramQual)
147             {
148                 case TQualifier::EvqParamOut:
149                 case TQualifier::EvqParamInOut:
150                 {
151                     const TVariable *var = GetVariable(*args[i]);
152                     if (mVarBuffer.insert(var).count > 1)
153                     {
154                         mightAlias = true;
155                         i          = argCount;
156                     }
157                 }
158                 break;
159 
160                 default:
161                     break;
162             }
163         }
164 
165         const bool hasIndeterminateVar = mVarBuffer.find(nullptr);
166 
167         if (!mightAlias)
168         {
169             mightAlias = hasIndeterminateVar && mVarBuffer.uniqueSize() > 1;
170         }
171 
172         if (mightAlias)
173         {
174             for (size_t i = 0; i < argCount; ++i)
175             {
176                 TIntermTyped *arg = args[i]->getAsTyped();
177                 ASSERT(arg);
178                 if (!argAlreadyProcessed(arg))
179                 {
180                     const TVariable *var       = GetVariable(*arg);
181                     const TQualifier paramQual = getParamQualifier(i);
182 
183                     if (hasIndeterminateVar || mVarBuffer.multiplicity(var) > 1)
184                     {
185                         switch (paramQual)
186                         {
187                             case TQualifier::EvqParamOut:
188                                 args[i] = &mSymbolEnv.callFunctionOverload(
189                                     Name("out"), arg->getType(), *new TIntermSequence{arg});
190                                 break;
191 
192                             case TQualifier::EvqParamInOut:
193                                 args[i] = &mSymbolEnv.callFunctionOverload(
194                                     Name("inout"), arg->getType(), *new TIntermSequence{arg});
195                                 break;
196 
197                             default:
198                                 break;
199                         }
200                     }
201                 }
202             }
203         }
204 
205         mVarBuffer.clear();
206 
207         return aggregateNode;
208     }
209 };
210 
211 }  // anonymous namespace
212 
RewriteOutArgs(TCompiler & compiler,TIntermBlock & root,SymbolEnv & symbolEnv)213 bool sh::RewriteOutArgs(TCompiler &compiler, TIntermBlock &root, SymbolEnv &symbolEnv)
214 {
215     Rewriter rewriter(compiler, symbolEnv);
216     if (!rewriter.rebuildRoot(root))
217     {
218         return false;
219     }
220     return true;
221 }
222