1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include "compiler/translator/tree_ops/msl/RewriteOutArgs.h"
8 #include "compiler/translator/msl/IntermRebuild.h"
9
10 using namespace sh;
11
12 namespace
13 {
14
// Minimal multiset: a flat vector of (element, occurrence count) pairs that is searched
// linearly. Elements are compared with operator==, so T needs no hashing or ordering; the
// linear search is meant for the handful of elements tracked at a time.
template <typename T>
class SmallMultiSet
{
  public:
    // One distinct element together with how many times it has been inserted.
    struct Entry
    {
        T elem;
        size_t count;
    };

    // Returns the entry for |x|, or nullptr if |x| has not been inserted since the last clear().
    const Entry *find(const T &x) const
    {
        const size_t total = mEntries.size();
        for (size_t idx = 0; idx != total; ++idx)
        {
            if (x == mEntries[idx].elem)
            {
                return &mEntries[idx];
            }
        }
        return nullptr;
    }

    // Returns how many times |x| has been inserted (0 if never).
    size_t multiplicity(const T &x) const
    {
        const Entry *hit = find(x);
        if (hit == nullptr)
        {
            return 0;
        }
        return hit->count;
    }

    // Adds one occurrence of |x| and returns its updated entry. The returned reference points
    // into the backing vector: a later insert() may invalidate it and clear() always does.
    const Entry &insert(const T &x)
    {
        Entry *existing = findMutable(x);
        if (existing == nullptr)
        {
            mEntries.push_back({x, 1});
            return mEntries.back();
        }
        existing->count += 1;
        return *existing;
    }

    // Removes every element.
    void clear() { mEntries.clear(); }

    // True when nothing has been inserted since construction or the last clear().
    bool empty() const { return mEntries.empty(); }

    // Number of distinct elements (not the total number of insertions).
    size_t uniqueSize() const { return mEntries.size(); }

  private:
    // Mutable lookup layered on the const find(); safe because *this is non-const here.
    Entry *findMutable(const T &x) { return const_cast<Entry *>(find(x)); }

    std::vector<Entry> mEntries;
};
70
GetVariable(TIntermNode & node)71 const TVariable *GetVariable(TIntermNode &node)
72 {
73 TIntermTyped *tyNode = node.getAsTyped();
74 ASSERT(tyNode);
75 if (TIntermSymbol *symbol = tyNode->getAsSymbolNode())
76 {
77 return &symbol->variable();
78 }
79 return nullptr;
80 }
81
82 class Rewriter : public TIntermRebuild
83 {
84 SmallMultiSet<const TVariable *> mVarBuffer; // reusable buffer
85 SymbolEnv &mSymbolEnv;
86
87 public:
~Rewriter()88 ~Rewriter() override { ASSERT(mVarBuffer.empty()); }
89
Rewriter(TCompiler & compiler,SymbolEnv & symbolEnv)90 Rewriter(TCompiler &compiler, SymbolEnv &symbolEnv)
91 : TIntermRebuild(compiler, false, true), mSymbolEnv(symbolEnv)
92 {}
93
argAlreadyProcessed(TIntermTyped * arg)94 static bool argAlreadyProcessed(TIntermTyped *arg)
95 {
96 if (arg->getAsAggregate())
97 {
98 const TFunction *func = arg->getAsAggregate()->getFunction();
99 // These two builtins already generate references, and the
100 // ANGLE_inout and ANGLE_out overloads in ProgramPrelude are both
101 // unnecessary and incompatible.
102 if (func && func->symbolType() == SymbolType::AngleInternal &&
103 (func->name() == "swizzle_ref" || func->name() == "elem_ref"))
104 {
105 return true;
106 }
107 }
108 return false;
109 }
110
visitAggregatePost(TIntermAggregate & aggregateNode)111 PostResult visitAggregatePost(TIntermAggregate &aggregateNode) override
112 {
113 ASSERT(mVarBuffer.empty());
114
115 const TFunction *func = aggregateNode.getFunction();
116 if (!func)
117 {
118 return aggregateNode;
119 }
120
121 TIntermSequence &args = *aggregateNode.getSequence();
122 size_t argCount = args.size();
123
124 auto getParamQualifier = [&](size_t i) {
125 const TVariable ¶m = *func->getParam(i);
126 const TType ¶mType = param.getType();
127 const TQualifier paramQual = paramType.getQualifier();
128 switch (paramQual)
129 {
130 case TQualifier::EvqParamOut:
131 case TQualifier::EvqParamInOut:
132 if (!mSymbolEnv.isReference(param))
133 {
134 mSymbolEnv.markAsReference(param, AddressSpace::Thread);
135 }
136 break;
137 default:
138 break;
139 }
140 return paramQual;
141 };
142
143 bool mightAlias = false;
144
145 for (size_t i = 0; i < argCount; ++i)
146 {
147 const TQualifier paramQual = getParamQualifier(i);
148
149 switch (paramQual)
150 {
151 case TQualifier::EvqParamOut:
152 case TQualifier::EvqParamInOut:
153 {
154 const TVariable *var = GetVariable(*args[i]);
155 if (mVarBuffer.insert(var).count > 1)
156 {
157 mightAlias = true;
158 i = argCount;
159 }
160 }
161 break;
162
163 default:
164 break;
165 }
166 }
167
168 const bool hasIndeterminateVar = mVarBuffer.find(nullptr);
169
170 if (!mightAlias)
171 {
172 mightAlias = hasIndeterminateVar && mVarBuffer.uniqueSize() > 1;
173 }
174
175 if (mightAlias)
176 {
177 for (size_t i = 0; i < argCount; ++i)
178 {
179 TIntermTyped *arg = args[i]->getAsTyped();
180 ASSERT(arg);
181 if (!argAlreadyProcessed(arg))
182 {
183 const TVariable *var = GetVariable(*arg);
184 const TQualifier paramQual = getParamQualifier(i);
185
186 if (hasIndeterminateVar || mVarBuffer.multiplicity(var) > 1)
187 {
188 switch (paramQual)
189 {
190 case TQualifier::EvqParamOut:
191 args[i] = &mSymbolEnv.callFunctionOverload(
192 Name("out"), arg->getType(), *new TIntermSequence{arg});
193 break;
194
195 case TQualifier::EvqParamInOut:
196 args[i] = &mSymbolEnv.callFunctionOverload(
197 Name("inout"), arg->getType(), *new TIntermSequence{arg});
198 break;
199
200 default:
201 break;
202 }
203 }
204 }
205 }
206 }
207
208 mVarBuffer.clear();
209
210 return aggregateNode;
211 }
212 };
213
214 } // anonymous namespace
215
RewriteOutArgs(TCompiler & compiler,TIntermBlock & root,SymbolEnv & symbolEnv)216 bool sh::RewriteOutArgs(TCompiler &compiler, TIntermBlock &root, SymbolEnv &symbolEnv)
217 {
218 Rewriter rewriter(compiler, symbolEnv);
219 if (!rewriter.rebuildRoot(root))
220 {
221 return false;
222 }
223 return true;
224 }
225