1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include <algorithm>
8 #include <functional>
9 #include <unordered_map>
10 #include <unordered_set>
11 #include <vector>
12
13 #include "compiler/translator/ImmutableStringBuilder.h"
14 #include "compiler/translator/msl/AstHelpers.h"
15 #include "compiler/translator/msl/ToposortStructs.h"
16 #include "compiler/translator/tree_util/IntermNode_util.h"
17 #include "compiler/translator/tree_util/IntermTraverse.h"
18
19 using namespace sh;
20
21 ////////////////////////////////////////////////////////////////////////////////
22
23 namespace
24 {
25
// Outgoing edges of a single node: the set of nodes it points to.
template <class Node>
using Edges = std::unordered_set<Node>;

// Directed graph stored as an adjacency map from each node to its outgoing edges.
template <class Node>
using Graph = std::unordered_map<Node, Edges<Node>>;
31
32 struct EdgeComparator
33 {
operator ()__anon5cf736b40111::EdgeComparator34 bool operator()(const TStructure *s1, const TStructure *s2) { return s2->name() < s1->name(); }
35 };
36
BuildGraphImpl(SymbolEnv & symbolEnv,Graph<const TStructure * > & g,const TStructure * s)37 void BuildGraphImpl(SymbolEnv &symbolEnv, Graph<const TStructure *> &g, const TStructure *s)
38 {
39 if (g.find(s) != g.end())
40 {
41 return;
42 }
43
44 Edges<const TStructure *> &es = g[s];
45
46 const TFieldList &fs = s->fields();
47 for (const TField *f : fs)
48 {
49 if (const TStructure *z = symbolEnv.remap(f->type()->getStruct()))
50 {
51 es.insert(z);
52 BuildGraphImpl(symbolEnv, g, z);
53 Edges<const TStructure *> &ez = g[z];
54 es.insert(ez.begin(), ez.end());
55 }
56 }
57 }
58
BuildGraph(SymbolEnv & symbolEnv,const std::vector<const TStructure * > & structs)59 Graph<const TStructure *> BuildGraph(SymbolEnv &symbolEnv,
60 const std::vector<const TStructure *> &structs)
61 {
62 Graph<const TStructure *> g;
63 for (const TStructure *s : structs)
64 {
65 BuildGraphImpl(symbolEnv, g, s);
66 }
67 return g;
68 }
69
SortEdges(const std::unordered_set<const TStructure * > & structs)70 std::vector<const TStructure *> SortEdges(const std::unordered_set<const TStructure *> &structs)
71 {
72 std::vector<const TStructure *> sorted;
73 sorted.reserve(structs.size());
74 sorted.insert(sorted.begin(), structs.begin(), structs.end());
75 std::sort(sorted.begin(), sorted.end(), EdgeComparator());
76 return sorted;
77 }
78
79 // Algorthm: https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
80 // Note that the algorithm is modified to visit nodes in sorted order. This
81 // ensures consistent results. Without this, the returned order (in so far as
82 // leaf nodes) is undefined, because iterating over an unordered_set of pointers
83 // depends upon the actual pointer values. Consistent results is important for
84 // code that keys off the string of shaders for caching.
85 template <typename T>
Toposort(const Graph<T> & g)86 std::vector<T> Toposort(const Graph<T> &g)
87 {
88 // nodes with temporary mark
89 std::unordered_set<T> temps;
90
91 // nodes without permanent mark
92 std::unordered_set<T> invPerms;
93 for (const auto &entry : g)
94 {
95 invPerms.insert(entry.first);
96 }
97
98 // L <- Empty list that will contain the sorted elements
99 std::vector<T> L;
100
101 // function visit(node n)
102 std::function<void(T)> visit = [&](T n) -> void {
103 // if n has a permanent mark then
104 if (invPerms.find(n) == invPerms.end())
105 {
106 // return
107 return;
108 }
109 // if n has a temporary mark then
110 if (temps.find(n) != temps.end())
111 {
112 // stop (not a DAG)
113 UNREACHABLE();
114 }
115
116 // mark n with a temporary mark
117 temps.insert(n);
118
119 // for each node m with an edge from n to m do
120 auto enIter = g.find(n);
121 ASSERT(enIter != g.end());
122
123 std::vector<T> sorted = SortEdges(enIter->second);
124 for (T m : sorted)
125 {
126 // visit(m)
127 visit(m);
128 }
129
130 // remove temporary mark from n
131 temps.erase(n);
132 // mark n with a permanent mark
133 invPerms.erase(n);
134 // add n to head of L
135 L.push_back(n);
136 };
137
138 // while exists nodes without a permanent mark do
139 while (!invPerms.empty())
140 {
141 // select an unmarked node n
142 std::vector<T> sorted = SortEdges(invPerms);
143 T n = *sorted.begin();
144 // visit(n)
145 visit(n);
146 }
147
148 return L;
149 }
150
CreateStructEqualityFunction(TSymbolTable & symbolTable,const TStructure & aStructType)151 TIntermFunctionDefinition *CreateStructEqualityFunction(TSymbolTable &symbolTable,
152 const TStructure &aStructType)
153 {
154 ////////////////////
155
156 auto &funcEquality =
157 *new TFunction(&symbolTable, ImmutableString("equal"), SymbolType::AngleInternal,
158 new TType(TBasicType::EbtBool), true);
159 auto &aStruct = CreateInstanceVariable(symbolTable, aStructType, Name("a"));
160 auto &bStruct = CreateInstanceVariable(symbolTable, aStructType, Name("b"));
161 funcEquality.addParameter(&aStruct);
162 funcEquality.addParameter(&bStruct);
163
164 auto &bodyEquality = *new TIntermBlock();
165 std::vector<TIntermTyped *> andNodes;
166 ////////////////////
167
168 const TFieldList &aFields = aStructType.fields();
169 const size_t size = aFields.size();
170
171 auto testEquality = [&](TIntermTyped &a, TIntermTyped &b) -> TIntermTyped * {
172 ASSERT(a.getType() == b.getType());
173 const TType &type = a.getType();
174 if (type.isVector() || type.isMatrix() || type.getStruct())
175 {
176 auto *func =
177 new TFunction(&symbolTable, ImmutableString("equal"), SymbolType::AngleInternal,
178 new TType(TBasicType::EbtBool), true);
179 return TIntermAggregate::CreateFunctionCall(*func, new TIntermSequence{&a, &b});
180 }
181 else
182 {
183 return new TIntermBinary(TOperator::EOpEqual, &a, &b);
184 }
185 };
186
187 for (size_t idx = 0; idx < size; ++idx)
188 {
189 const TField &aField = *aFields[idx];
190 const TType &aFieldType = *aField.type();
191 auto &aFieldName = aField.name();
192
193 if (aFieldType.isArray())
194 {
195 ASSERT(!aFieldType.isArrayOfArrays()); // TODO
196 int dim = aFieldType.getOutermostArraySize();
197 for (int d = 0; d < dim; ++d)
198 {
199 auto &aAccess = AccessIndex(AccessField(aStruct, aFieldName), d);
200 auto &bAccess = AccessIndex(AccessField(bStruct, aFieldName), d);
201 auto *eqNode = testEquality(bAccess, aAccess);
202 andNodes.push_back(eqNode);
203 }
204 }
205 else
206 {
207 auto &aAccess = AccessField(aStruct, aFieldName);
208 auto &bAccess = AccessField(bStruct, aFieldName);
209 auto *eqNode = testEquality(bAccess, aAccess);
210 andNodes.push_back(eqNode);
211 }
212 }
213
214 ASSERT(andNodes.size() > 0); // Empty structs are not allowed in GLSL
215 TIntermTyped *outNode = andNodes.back();
216 andNodes.pop_back();
217 for (TIntermTyped *andNode : andNodes)
218 {
219 outNode = new TIntermBinary(TOperator::EOpLogicalAnd, andNode, outNode);
220 }
221 bodyEquality.appendStatement(new TIntermBranch(TOperator::EOpReturn, outNode));
222 auto *funcProtoEquality = new TIntermFunctionPrototype(&funcEquality);
223 return new TIntermFunctionDefinition(funcProtoEquality, &bodyEquality);
224 }
225
226 struct DeclaredStructure
227 {
228 TIntermDeclaration *declNode;
229 TIntermFunctionDefinition *equalityFunctionDefinition;
230 const TStructure *structure;
231 };
232
GetAsDeclaredStructure(SymbolEnv & symbolEnv,TIntermNode & node,DeclaredStructure & out,TSymbolTable & symbolTable,const std::unordered_set<const TStructure * > & usedStructs)233 bool GetAsDeclaredStructure(SymbolEnv &symbolEnv,
234 TIntermNode &node,
235 DeclaredStructure &out,
236 TSymbolTable &symbolTable,
237 const std::unordered_set<const TStructure *> &usedStructs)
238 {
239 if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
240 {
241 ASSERT(declNode->getChildCount() == 1);
242 TIntermNode &childNode = *declNode->getChildNode(0);
243
244 if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
245 {
246 const TVariable &var = symbolNode->variable();
247 const TType &type = var.getType();
248 if (const TStructure *structure = symbolEnv.remap(type.getStruct()))
249 {
250 if (type.isStructSpecifier())
251 {
252 out.declNode = declNode;
253 out.structure = structure;
254 out.equalityFunctionDefinition =
255 usedStructs.find(structure) == usedStructs.end()
256 ? nullptr
257 : CreateStructEqualityFunction(symbolTable, *structure);
258 return true;
259 }
260 }
261 }
262 }
263 return false;
264 }
265
266 class FindStructEqualityUse : public TIntermTraverser
267 {
268 public:
269 SymbolEnv &mSymbolEnv;
270 std::unordered_set<const TStructure *> mUsedStructs;
271
FindStructEqualityUse(SymbolEnv & symbolEnv)272 FindStructEqualityUse(SymbolEnv &symbolEnv)
273 : TIntermTraverser(false, false, true), mSymbolEnv(symbolEnv)
274 {}
275
visitBinary(Visit,TIntermBinary * binary)276 bool visitBinary(Visit, TIntermBinary *binary) override
277 {
278 const TOperator op = binary->getOp();
279
280 switch (op)
281 {
282 case TOperator::EOpEqual:
283 case TOperator::EOpNotEqual:
284 {
285 const TType &leftType = binary->getLeft()->getType();
286 const TType &rightType = binary->getRight()->getType();
287 ASSERT(leftType.getStruct() == rightType.getStruct());
288 if (const TStructure *structure = mSymbolEnv.remap(leftType.getStruct()))
289 {
290 useStruct(*structure);
291 }
292 }
293 break;
294
295 default:
296 break;
297 }
298
299 return true;
300 }
301
302 private:
useStruct(const TStructure & structure)303 void useStruct(const TStructure &structure)
304 {
305 if (mUsedStructs.insert(&structure).second)
306 {
307 for (const TField *field : structure.fields())
308 {
309 if (const TStructure *subStruct = mSymbolEnv.remap(field->type()->getStruct()))
310 {
311 useStruct(*subStruct);
312 }
313 }
314 }
315 }
316 };
317
318 } // anonymous namespace
319
320 ////////////////////////////////////////////////////////////////////////////////
321
ToposortStructs(TCompiler & compiler,SymbolEnv & symbolEnv,TIntermBlock & root,ProgramPreludeConfig & ppc)322 bool sh::ToposortStructs(TCompiler &compiler,
323 SymbolEnv &symbolEnv,
324 TIntermBlock &root,
325 ProgramPreludeConfig &ppc)
326 {
327 FindStructEqualityUse finder(symbolEnv);
328 root.traverse(&finder);
329 ppc.hasStructEq = !finder.mUsedStructs.empty();
330
331 std::vector<DeclaredStructure> declaredStructs;
332 std::vector<TIntermNode *> nonStructStmtNodes;
333
334 {
335 DeclaredStructure declaredStruct;
336 const size_t stmtCount = root.getChildCount();
337 for (size_t i = 0; i < stmtCount; ++i)
338 {
339 TIntermNode &stmtNode = *root.getChildNode(i);
340 if (GetAsDeclaredStructure(symbolEnv, stmtNode, declaredStruct,
341 compiler.getSymbolTable(), finder.mUsedStructs))
342 {
343 declaredStructs.push_back(declaredStruct);
344 }
345 else
346 {
347 nonStructStmtNodes.push_back(&stmtNode);
348 }
349 }
350 }
351
352 {
353 std::vector<const TStructure *> structs;
354 std::unordered_map<const TStructure *, DeclaredStructure> rawToDeclared;
355
356 for (const DeclaredStructure &d : declaredStructs)
357 {
358 structs.push_back(d.structure);
359 ASSERT(rawToDeclared.find(d.structure) == rawToDeclared.end());
360 rawToDeclared[d.structure] = d;
361 }
362
363 // Note: Graph may contain more than only explicitly declared structures.
364 Graph<const TStructure *> g = BuildGraph(symbolEnv, structs);
365 std::vector<const TStructure *> sortedStructs = Toposort(g);
366 ASSERT(declaredStructs.size() <= sortedStructs.size());
367
368 declaredStructs.clear();
369 for (const TStructure *s : sortedStructs)
370 {
371 auto it = rawToDeclared.find(s);
372 if (it != rawToDeclared.end())
373 {
374 auto &d = it->second;
375 ASSERT(d.declNode);
376 declaredStructs.push_back(d);
377 }
378 }
379 }
380
381 {
382 TIntermSequence newStmtNodes;
383
384 for (DeclaredStructure &declaredStruct : declaredStructs)
385 {
386 ASSERT(declaredStruct.declNode);
387 newStmtNodes.push_back(declaredStruct.declNode);
388 if (declaredStruct.equalityFunctionDefinition)
389 {
390 newStmtNodes.push_back(declaredStruct.equalityFunctionDefinition);
391 }
392 }
393
394 for (TIntermNode *stmtNode : nonStructStmtNodes)
395 {
396 ASSERT(stmtNode);
397 newStmtNodes.push_back(stmtNode);
398 }
399
400 *root.getSequence() = newStmtNodes;
401 }
402
403 return compiler.validateAST(&root);
404 }
405