• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <algorithm>
8 #include <functional>
9 #include <unordered_map>
10 #include <unordered_set>
11 #include <vector>
12 
13 #include "compiler/translator/ImmutableStringBuilder.h"
14 #include "compiler/translator/msl/AstHelpers.h"
15 #include "compiler/translator/msl/ToposortStructs.h"
16 #include "compiler/translator/tree_util/IntermNode_util.h"
17 #include "compiler/translator/tree_util/IntermTraverse.h"
18 
19 using namespace sh;
20 
21 ////////////////////////////////////////////////////////////////////////////////
22 
23 namespace
24 {
25 
// Set of nodes that a given node has outgoing edges to.
template <class NodeT>
using Edges = std::unordered_set<NodeT>;

// Directed graph: every node maps to the set of nodes it has edges to.
template <class NodeT>
using Graph = std::unordered_map<NodeT, Edges<NodeT>>;
31 
32 struct EdgeComparator
33 {
operator ()__anon5cf736b40111::EdgeComparator34     bool operator()(const TStructure *s1, const TStructure *s2) { return s2->name() < s1->name(); }
35 };
36 
BuildGraphImpl(SymbolEnv & symbolEnv,Graph<const TStructure * > & g,const TStructure * s)37 void BuildGraphImpl(SymbolEnv &symbolEnv, Graph<const TStructure *> &g, const TStructure *s)
38 {
39     if (g.find(s) != g.end())
40     {
41         return;
42     }
43 
44     Edges<const TStructure *> &es = g[s];
45 
46     const TFieldList &fs = s->fields();
47     for (const TField *f : fs)
48     {
49         if (const TStructure *z = symbolEnv.remap(f->type()->getStruct()))
50         {
51             es.insert(z);
52             BuildGraphImpl(symbolEnv, g, z);
53             Edges<const TStructure *> &ez = g[z];
54             es.insert(ez.begin(), ez.end());
55         }
56     }
57 }
58 
BuildGraph(SymbolEnv & symbolEnv,const std::vector<const TStructure * > & structs)59 Graph<const TStructure *> BuildGraph(SymbolEnv &symbolEnv,
60                                      const std::vector<const TStructure *> &structs)
61 {
62     Graph<const TStructure *> g;
63     for (const TStructure *s : structs)
64     {
65         BuildGraphImpl(symbolEnv, g, s);
66     }
67     return g;
68 }
69 
SortEdges(const std::unordered_set<const TStructure * > & structs)70 std::vector<const TStructure *> SortEdges(const std::unordered_set<const TStructure *> &structs)
71 {
72     std::vector<const TStructure *> sorted;
73     sorted.reserve(structs.size());
74     sorted.insert(sorted.begin(), structs.begin(), structs.end());
75     std::sort(sorted.begin(), sorted.end(), EdgeComparator());
76     return sorted;
77 }
78 
79 // Algorthm: https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
80 // Note that the algorithm is modified to visit nodes in sorted order. This
81 // ensures consistent results. Without this, the returned order (in so far as
82 // leaf nodes) is undefined, because iterating over an unordered_set of pointers
83 // depends upon the actual pointer values. Consistent results is important for
84 // code that keys off the string of shaders for caching.
85 template <typename T>
Toposort(const Graph<T> & g)86 std::vector<T> Toposort(const Graph<T> &g)
87 {
88     // nodes with temporary mark
89     std::unordered_set<T> temps;
90 
91     // nodes without permanent mark
92     std::unordered_set<T> invPerms;
93     for (const auto &entry : g)
94     {
95         invPerms.insert(entry.first);
96     }
97 
98     // L <- Empty list that will contain the sorted elements
99     std::vector<T> L;
100 
101     // function visit(node n)
102     std::function<void(T)> visit = [&](T n) -> void {
103         // if n has a permanent mark then
104         if (invPerms.find(n) == invPerms.end())
105         {
106             // return
107             return;
108         }
109         // if n has a temporary mark then
110         if (temps.find(n) != temps.end())
111         {
112             // stop   (not a DAG)
113             UNREACHABLE();
114         }
115 
116         // mark n with a temporary mark
117         temps.insert(n);
118 
119         // for each node m with an edge from n to m do
120         auto enIter = g.find(n);
121         ASSERT(enIter != g.end());
122 
123         std::vector<T> sorted = SortEdges(enIter->second);
124         for (T m : sorted)
125         {
126             // visit(m)
127             visit(m);
128         }
129 
130         // remove temporary mark from n
131         temps.erase(n);
132         // mark n with a permanent mark
133         invPerms.erase(n);
134         // add n to head of L
135         L.push_back(n);
136     };
137 
138     // while exists nodes without a permanent mark do
139     while (!invPerms.empty())
140     {
141         // select an unmarked node n
142         std::vector<T> sorted = SortEdges(invPerms);
143         T n                   = *sorted.begin();
144         // visit(n)
145         visit(n);
146     }
147 
148     return L;
149 }
150 
CreateStructEqualityFunction(TSymbolTable & symbolTable,const TStructure & aStructType)151 TIntermFunctionDefinition *CreateStructEqualityFunction(TSymbolTable &symbolTable,
152                                                         const TStructure &aStructType)
153 {
154     ////////////////////
155 
156     auto &funcEquality =
157         *new TFunction(&symbolTable, ImmutableString("equal"), SymbolType::AngleInternal,
158                        new TType(TBasicType::EbtBool), true);
159     auto &aStruct = CreateInstanceVariable(symbolTable, aStructType, Name("a"));
160     auto &bStruct = CreateInstanceVariable(symbolTable, aStructType, Name("b"));
161     funcEquality.addParameter(&aStruct);
162     funcEquality.addParameter(&bStruct);
163 
164     auto &bodyEquality = *new TIntermBlock();
165     std::vector<TIntermTyped *> andNodes;
166     ////////////////////
167 
168     const TFieldList &aFields = aStructType.fields();
169     const size_t size         = aFields.size();
170 
171     auto testEquality = [&](TIntermTyped &a, TIntermTyped &b) -> TIntermTyped * {
172         ASSERT(a.getType() == b.getType());
173         const TType &type = a.getType();
174         if (type.isVector() || type.isMatrix() || type.getStruct())
175         {
176             auto *func =
177                 new TFunction(&symbolTable, ImmutableString("equal"), SymbolType::AngleInternal,
178                               new TType(TBasicType::EbtBool), true);
179             return TIntermAggregate::CreateFunctionCall(*func, new TIntermSequence{&a, &b});
180         }
181         else
182         {
183             return new TIntermBinary(TOperator::EOpEqual, &a, &b);
184         }
185     };
186 
187     for (size_t idx = 0; idx < size; ++idx)
188     {
189         const TField &aField    = *aFields[idx];
190         const TType &aFieldType = *aField.type();
191         auto &aFieldName        = aField.name();
192 
193         if (aFieldType.isArray())
194         {
195             ASSERT(!aFieldType.isArrayOfArrays());  // TODO
196             int dim = aFieldType.getOutermostArraySize();
197             for (int d = 0; d < dim; ++d)
198             {
199                 auto &aAccess = AccessIndex(AccessField(aStruct, aFieldName), d);
200                 auto &bAccess = AccessIndex(AccessField(bStruct, aFieldName), d);
201                 auto *eqNode  = testEquality(bAccess, aAccess);
202                 andNodes.push_back(eqNode);
203             }
204         }
205         else
206         {
207             auto &aAccess = AccessField(aStruct, aFieldName);
208             auto &bAccess = AccessField(bStruct, aFieldName);
209             auto *eqNode  = testEquality(bAccess, aAccess);
210             andNodes.push_back(eqNode);
211         }
212     }
213 
214     ASSERT(andNodes.size() > 0);  // Empty structs are not allowed in GLSL
215     TIntermTyped *outNode = andNodes.back();
216     andNodes.pop_back();
217     for (TIntermTyped *andNode : andNodes)
218     {
219         outNode = new TIntermBinary(TOperator::EOpLogicalAnd, andNode, outNode);
220     }
221     bodyEquality.appendStatement(new TIntermBranch(TOperator::EOpReturn, outNode));
222     auto *funcProtoEquality = new TIntermFunctionPrototype(&funcEquality);
223     return new TIntermFunctionDefinition(funcProtoEquality, &bodyEquality);
224 }
225 
226 struct DeclaredStructure
227 {
228     TIntermDeclaration *declNode;
229     TIntermFunctionDefinition *equalityFunctionDefinition;
230     const TStructure *structure;
231 };
232 
GetAsDeclaredStructure(SymbolEnv & symbolEnv,TIntermNode & node,DeclaredStructure & out,TSymbolTable & symbolTable,const std::unordered_set<const TStructure * > & usedStructs)233 bool GetAsDeclaredStructure(SymbolEnv &symbolEnv,
234                             TIntermNode &node,
235                             DeclaredStructure &out,
236                             TSymbolTable &symbolTable,
237                             const std::unordered_set<const TStructure *> &usedStructs)
238 {
239     if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
240     {
241         ASSERT(declNode->getChildCount() == 1);
242         TIntermNode &childNode = *declNode->getChildNode(0);
243 
244         if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
245         {
246             const TVariable &var = symbolNode->variable();
247             const TType &type    = var.getType();
248             if (const TStructure *structure = symbolEnv.remap(type.getStruct()))
249             {
250                 if (type.isStructSpecifier())
251                 {
252                     out.declNode  = declNode;
253                     out.structure = structure;
254                     out.equalityFunctionDefinition =
255                         usedStructs.find(structure) == usedStructs.end()
256                             ? nullptr
257                             : CreateStructEqualityFunction(symbolTable, *structure);
258                     return true;
259                 }
260             }
261         }
262     }
263     return false;
264 }
265 
266 class FindStructEqualityUse : public TIntermTraverser
267 {
268   public:
269     SymbolEnv &mSymbolEnv;
270     std::unordered_set<const TStructure *> mUsedStructs;
271 
FindStructEqualityUse(SymbolEnv & symbolEnv)272     FindStructEqualityUse(SymbolEnv &symbolEnv)
273         : TIntermTraverser(false, false, true), mSymbolEnv(symbolEnv)
274     {}
275 
visitBinary(Visit,TIntermBinary * binary)276     bool visitBinary(Visit, TIntermBinary *binary) override
277     {
278         const TOperator op = binary->getOp();
279 
280         switch (op)
281         {
282             case TOperator::EOpEqual:
283             case TOperator::EOpNotEqual:
284             {
285                 const TType &leftType  = binary->getLeft()->getType();
286                 const TType &rightType = binary->getRight()->getType();
287                 ASSERT(leftType.getStruct() == rightType.getStruct());
288                 if (const TStructure *structure = mSymbolEnv.remap(leftType.getStruct()))
289                 {
290                     useStruct(*structure);
291                 }
292             }
293             break;
294 
295             default:
296                 break;
297         }
298 
299         return true;
300     }
301 
302   private:
useStruct(const TStructure & structure)303     void useStruct(const TStructure &structure)
304     {
305         if (mUsedStructs.insert(&structure).second)
306         {
307             for (const TField *field : structure.fields())
308             {
309                 if (const TStructure *subStruct = mSymbolEnv.remap(field->type()->getStruct()))
310                 {
311                     useStruct(*subStruct);
312                 }
313             }
314         }
315     }
316 };
317 
318 }  // anonymous namespace
319 
320 ////////////////////////////////////////////////////////////////////////////////
321 
ToposortStructs(TCompiler & compiler,SymbolEnv & symbolEnv,TIntermBlock & root,ProgramPreludeConfig & ppc)322 bool sh::ToposortStructs(TCompiler &compiler,
323                          SymbolEnv &symbolEnv,
324                          TIntermBlock &root,
325                          ProgramPreludeConfig &ppc)
326 {
327     FindStructEqualityUse finder(symbolEnv);
328     root.traverse(&finder);
329     ppc.hasStructEq = !finder.mUsedStructs.empty();
330 
331     std::vector<DeclaredStructure> declaredStructs;
332     std::vector<TIntermNode *> nonStructStmtNodes;
333 
334     {
335         DeclaredStructure declaredStruct;
336         const size_t stmtCount = root.getChildCount();
337         for (size_t i = 0; i < stmtCount; ++i)
338         {
339             TIntermNode &stmtNode = *root.getChildNode(i);
340             if (GetAsDeclaredStructure(symbolEnv, stmtNode, declaredStruct,
341                                        compiler.getSymbolTable(), finder.mUsedStructs))
342             {
343                 declaredStructs.push_back(declaredStruct);
344             }
345             else
346             {
347                 nonStructStmtNodes.push_back(&stmtNode);
348             }
349         }
350     }
351 
352     {
353         std::vector<const TStructure *> structs;
354         std::unordered_map<const TStructure *, DeclaredStructure> rawToDeclared;
355 
356         for (const DeclaredStructure &d : declaredStructs)
357         {
358             structs.push_back(d.structure);
359             ASSERT(rawToDeclared.find(d.structure) == rawToDeclared.end());
360             rawToDeclared[d.structure] = d;
361         }
362 
363         // Note: Graph may contain more than only explicitly declared structures.
364         Graph<const TStructure *> g                   = BuildGraph(symbolEnv, structs);
365         std::vector<const TStructure *> sortedStructs = Toposort(g);
366         ASSERT(declaredStructs.size() <= sortedStructs.size());
367 
368         declaredStructs.clear();
369         for (const TStructure *s : sortedStructs)
370         {
371             auto it = rawToDeclared.find(s);
372             if (it != rawToDeclared.end())
373             {
374                 auto &d = it->second;
375                 ASSERT(d.declNode);
376                 declaredStructs.push_back(d);
377             }
378         }
379     }
380 
381     {
382         TIntermSequence newStmtNodes;
383 
384         for (DeclaredStructure &declaredStruct : declaredStructs)
385         {
386             ASSERT(declaredStruct.declNode);
387             newStmtNodes.push_back(declaredStruct.declNode);
388             if (declaredStruct.equalityFunctionDefinition)
389             {
390                 newStmtNodes.push_back(declaredStruct.equalityFunctionDefinition);
391             }
392         }
393 
394         for (TIntermNode *stmtNode : nonStructStmtNodes)
395         {
396             ASSERT(stmtNode);
397             newStmtNodes.push_back(stmtNode);
398         }
399 
400         *root.getSequence() = newStmtNodes;
401     }
402 
403     return compiler.validateAST(&root);
404 }
405