1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include "compiler/translator/TranslatorMetalDirect/RewriteOutArgs.h"
8 #include "compiler/translator/tree_util/IntermRebuild.h"
9
10 using namespace sh;
11
12 namespace
13 {
14
// A minimal multiset for a handful of elements. Entries are stored in a std::vector and located
// with a linear scan, comparing candidates via `x == entry.elem`. Each distinct element is stored
// once, together with the number of times it has been inserted.
template <typename T>
class SmallMultiSet
{
  public:
    // One distinct element and the number of times it has been inserted.
    struct Entry
    {
        T elem;
        size_t count;
    };

    // Returns a pointer to the entry whose element equals |x|, or nullptr if |x| was never
    // inserted (or the set was cleared since).
    const Entry *find(const T &x) const
    {
        for (size_t index = 0; index < mEntries.size(); ++index)
        {
            if (x == mEntries[index].elem)
            {
                return &mEntries[index];
            }
        }
        return nullptr;
    }

    // Returns how many times |x| has been inserted; 0 if it is absent.
    size_t multiplicity(const T &x) const
    {
        if (const Entry *match = find(x))
        {
            return match->count;
        }
        return 0;
    }

    // Adds one occurrence of |x| and returns its (updated) entry. A new element is appended with
    // a count of 1. The returned reference points into the backing vector, so it may be
    // invalidated by a later insert() that grows the storage; read what you need immediately.
    const Entry &insert(const T &x)
    {
        for (Entry &existing : mEntries)
        {
            if (x == existing.elem)
            {
                existing.count += 1;
                return existing;
            }
        }
        mEntries.push_back({x, 1});
        return mEntries.back();
    }

    // Removes every entry.
    void clear()
    {
        mEntries.clear();
    }

    // True when no element has been inserted since construction or the last clear().
    bool empty() const
    {
        return mEntries.empty();
    }

    // Number of distinct elements (not the total number of insertions).
    size_t uniqueSize() const
    {
        return mEntries.size();
    }

  private:
    std::vector<Entry> mEntries;
};
70
GetVariable(TIntermNode & node)71 const TVariable *GetVariable(TIntermNode &node)
72 {
73 TIntermTyped *tyNode = node.getAsTyped();
74 ASSERT(tyNode);
75 if (TIntermSymbol *symbol = tyNode->getAsSymbolNode())
76 {
77 return &symbol->variable();
78 }
79 return nullptr;
80 }
81
82 class Rewriter : public TIntermRebuild
83 {
84 SmallMultiSet<const TVariable *> mVarBuffer; // reusable buffer
85 SymbolEnv &mSymbolEnv;
86
87 public:
~Rewriter()88 ~Rewriter() override { ASSERT(mVarBuffer.empty()); }
89
Rewriter(TCompiler & compiler,SymbolEnv & symbolEnv)90 Rewriter(TCompiler &compiler, SymbolEnv &symbolEnv)
91 : TIntermRebuild(compiler, false, true), mSymbolEnv(symbolEnv)
92 {}
93
argAlreadyProcessed(TIntermTyped * arg)94 static bool argAlreadyProcessed(TIntermTyped *arg)
95 {
96 if (arg->getAsAggregate())
97 {
98 const TFunction *func = arg->getAsAggregate()->getFunction();
99 if (func && func->symbolType() == SymbolType::AngleInternal &&
100 func->name() == "swizzle_ref")
101 {
102 return true;
103 }
104 }
105 return false;
106 }
107
visitAggregatePost(TIntermAggregate & aggregateNode)108 PostResult visitAggregatePost(TIntermAggregate &aggregateNode) override
109 {
110 ASSERT(mVarBuffer.empty());
111
112 const TFunction *func = aggregateNode.getFunction();
113 if (!func)
114 {
115 return aggregateNode;
116 }
117
118 TIntermSequence &args = *aggregateNode.getSequence();
119 size_t argCount = args.size();
120
121 auto getParamQualifier = [&](size_t i) {
122 const TVariable ¶m = *func->getParam(i);
123 const TType ¶mType = param.getType();
124 const TQualifier paramQual = paramType.getQualifier();
125 switch (paramQual)
126 {
127 case TQualifier::EvqParamOut:
128 case TQualifier::EvqParamInOut:
129 if (!mSymbolEnv.isReference(param))
130 {
131 mSymbolEnv.markAsReference(param, AddressSpace::Thread);
132 }
133 break;
134 default:
135 break;
136 }
137 return paramQual;
138 };
139
140 bool mightAlias = false;
141
142 for (size_t i = 0; i < argCount; ++i)
143 {
144 const TQualifier paramQual = getParamQualifier(i);
145
146 switch (paramQual)
147 {
148 case TQualifier::EvqParamOut:
149 case TQualifier::EvqParamInOut:
150 {
151 const TVariable *var = GetVariable(*args[i]);
152 if (mVarBuffer.insert(var).count > 1)
153 {
154 mightAlias = true;
155 i = argCount;
156 }
157 }
158 break;
159
160 default:
161 break;
162 }
163 }
164
165 const bool hasIndeterminateVar = mVarBuffer.find(nullptr);
166
167 if (!mightAlias)
168 {
169 mightAlias = hasIndeterminateVar && mVarBuffer.uniqueSize() > 1;
170 }
171
172 if (mightAlias)
173 {
174 for (size_t i = 0; i < argCount; ++i)
175 {
176 TIntermTyped *arg = args[i]->getAsTyped();
177 ASSERT(arg);
178 if (!argAlreadyProcessed(arg))
179 {
180 const TVariable *var = GetVariable(*arg);
181 const TQualifier paramQual = getParamQualifier(i);
182
183 if (hasIndeterminateVar || mVarBuffer.multiplicity(var) > 1)
184 {
185 switch (paramQual)
186 {
187 case TQualifier::EvqParamOut:
188 args[i] = &mSymbolEnv.callFunctionOverload(
189 Name("out"), arg->getType(), *new TIntermSequence{arg});
190 break;
191
192 case TQualifier::EvqParamInOut:
193 args[i] = &mSymbolEnv.callFunctionOverload(
194 Name("inout"), arg->getType(), *new TIntermSequence{arg});
195 break;
196
197 default:
198 break;
199 }
200 }
201 }
202 }
203 }
204
205 mVarBuffer.clear();
206
207 return aggregateNode;
208 }
209 };
210
211 } // anonymous namespace
212
RewriteOutArgs(TCompiler & compiler,TIntermBlock & root,SymbolEnv & symbolEnv)213 bool sh::RewriteOutArgs(TCompiler &compiler, TIntermBlock &root, SymbolEnv &symbolEnv)
214 {
215 Rewriter rewriter(compiler, symbolEnv);
216 if (!rewriter.rebuildRoot(root))
217 {
218 return false;
219 }
220 return true;
221 }
222