1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 // Analysis of the AST needed for HLSL generation
8
9 #include "compiler/translator/ASTMetadataHLSL.h"
10
11 #include "compiler/translator/CallDAG.h"
12 #include "compiler/translator/SymbolTable.h"
13 #include "compiler/translator/tree_util/IntermTraverse.h"
14
15 namespace sh
16 {
17
18 namespace
19 {
20
21 // Class used to traverse the AST of a function definition, checking if the
22 // function uses a gradient, and writing the set of control flow using gradients.
23 // It assumes that the analysis has already been made for the function's
24 // callees.
25 class PullGradient : public TIntermTraverser
26 {
27 public:
PullGradient(MetadataList * metadataList,size_t index,const CallDAG & dag)28 PullGradient(MetadataList *metadataList, size_t index, const CallDAG &dag)
29 : TIntermTraverser(true, false, true),
30 mMetadataList(metadataList),
31 mMetadata(&(*metadataList)[index]),
32 mIndex(index),
33 mDag(dag)
34 {
35 ASSERT(index < metadataList->size());
36
37 // ESSL 100 builtin gradient functions
38 mGradientBuiltinFunctions.insert(ImmutableString("texture2D"));
39 mGradientBuiltinFunctions.insert(ImmutableString("texture2DProj"));
40 mGradientBuiltinFunctions.insert(ImmutableString("textureCube"));
41
42 // ESSL 300 builtin gradient functions
43 mGradientBuiltinFunctions.insert(ImmutableString("texture"));
44 mGradientBuiltinFunctions.insert(ImmutableString("textureProj"));
45 mGradientBuiltinFunctions.insert(ImmutableString("textureOffset"));
46 mGradientBuiltinFunctions.insert(ImmutableString("textureProjOffset"));
47
48 // ESSL 310 doesn't add builtin gradient functions
49 }
50
traverse(TIntermFunctionDefinition * node)51 void traverse(TIntermFunctionDefinition *node)
52 {
53 node->traverse(this);
54 ASSERT(mParents.empty());
55 }
56
57 // Called when a gradient operation or a call to a function using a gradient is found.
onGradient()58 void onGradient()
59 {
60 mMetadata->mUsesGradient = true;
61 // Mark the latest control flow as using a gradient.
62 if (!mParents.empty())
63 {
64 mMetadata->mControlFlowsContainingGradient.insert(mParents.back());
65 }
66 }
67
visitControlFlow(Visit visit,TIntermNode * node)68 void visitControlFlow(Visit visit, TIntermNode *node)
69 {
70 if (visit == PreVisit)
71 {
72 mParents.push_back(node);
73 }
74 else if (visit == PostVisit)
75 {
76 ASSERT(mParents.back() == node);
77 mParents.pop_back();
78 // A control flow's using a gradient means its parents are too.
79 if (mMetadata->mControlFlowsContainingGradient.count(node) > 0 && !mParents.empty())
80 {
81 mMetadata->mControlFlowsContainingGradient.insert(mParents.back());
82 }
83 }
84 }
85
visitLoop(Visit visit,TIntermLoop * loop)86 bool visitLoop(Visit visit, TIntermLoop *loop) override
87 {
88 visitControlFlow(visit, loop);
89 return true;
90 }
91
visitIfElse(Visit visit,TIntermIfElse * ifElse)92 bool visitIfElse(Visit visit, TIntermIfElse *ifElse) override
93 {
94 visitControlFlow(visit, ifElse);
95 return true;
96 }
97
visitUnary(Visit visit,TIntermUnary * node)98 bool visitUnary(Visit visit, TIntermUnary *node) override
99 {
100 if (visit == PreVisit)
101 {
102 switch (node->getOp())
103 {
104 case EOpDFdx:
105 case EOpDFdy:
106 case EOpFwidth:
107 onGradient();
108 break;
109 default:
110 break;
111 }
112 }
113
114 return true;
115 }
116
visitAggregate(Visit visit,TIntermAggregate * node)117 bool visitAggregate(Visit visit, TIntermAggregate *node) override
118 {
119 if (visit == PreVisit)
120 {
121 if (node->getOp() == EOpCallFunctionInAST)
122 {
123 size_t calleeIndex = mDag.findIndex(node->getFunction()->uniqueId());
124 ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
125
126 if ((*mMetadataList)[calleeIndex].mUsesGradient)
127 {
128 onGradient();
129 }
130 }
131 else if (node->getOp() == EOpCallBuiltInFunction)
132 {
133 if (mGradientBuiltinFunctions.find(node->getFunction()->name()) !=
134 mGradientBuiltinFunctions.end())
135 {
136 onGradient();
137 }
138 }
139 }
140
141 return true;
142 }
143
144 private:
145 MetadataList *mMetadataList;
146 ASTMetadataHLSL *mMetadata;
147 size_t mIndex;
148 const CallDAG &mDag;
149
150 // Contains a stack of the control flow nodes that are parents of the node being
151 // currently visited. It is used to mark control flows using a gradient.
152 std::vector<TIntermNode *> mParents;
153
154 // A list of builtin functions that use gradients
155 std::set<ImmutableString> mGradientBuiltinFunctions;
156 };
157
158 // Traverses the AST of a function definition to compute the the discontinuous loops
159 // and the if statements containing gradient loops. It assumes that the gradient loops
160 // (loops that contain a gradient) have already been computed and that it has already
161 // traversed the current function's callees.
162 class PullComputeDiscontinuousAndGradientLoops : public TIntermTraverser
163 {
164 public:
PullComputeDiscontinuousAndGradientLoops(MetadataList * metadataList,size_t index,const CallDAG & dag)165 PullComputeDiscontinuousAndGradientLoops(MetadataList *metadataList,
166 size_t index,
167 const CallDAG &dag)
168 : TIntermTraverser(true, false, true),
169 mMetadataList(metadataList),
170 mMetadata(&(*metadataList)[index]),
171 mIndex(index),
172 mDag(dag)
173 {}
174
traverse(TIntermFunctionDefinition * node)175 void traverse(TIntermFunctionDefinition *node)
176 {
177 node->traverse(this);
178 ASSERT(mLoopsAndSwitches.empty());
179 ASSERT(mIfs.empty());
180 }
181
182 // Called when traversing a gradient loop or a call to a function with a
183 // gradient loop in its call graph.
onGradientLoop()184 void onGradientLoop()
185 {
186 mMetadata->mHasGradientLoopInCallGraph = true;
187 // Mark the latest if as using a discontinuous loop.
188 if (!mIfs.empty())
189 {
190 mMetadata->mIfsContainingGradientLoop.insert(mIfs.back());
191 }
192 }
193
visitLoop(Visit visit,TIntermLoop * loop)194 bool visitLoop(Visit visit, TIntermLoop *loop) override
195 {
196 if (visit == PreVisit)
197 {
198 mLoopsAndSwitches.push_back(loop);
199
200 if (mMetadata->hasGradientInCallGraph(loop))
201 {
202 onGradientLoop();
203 }
204 }
205 else if (visit == PostVisit)
206 {
207 ASSERT(mLoopsAndSwitches.back() == loop);
208 mLoopsAndSwitches.pop_back();
209 }
210
211 return true;
212 }
213
visitIfElse(Visit visit,TIntermIfElse * node)214 bool visitIfElse(Visit visit, TIntermIfElse *node) override
215 {
216 if (visit == PreVisit)
217 {
218 mIfs.push_back(node);
219 }
220 else if (visit == PostVisit)
221 {
222 ASSERT(mIfs.back() == node);
223 mIfs.pop_back();
224 // An if using a discontinuous loop means its parents ifs are also discontinuous.
225 if (mMetadata->mIfsContainingGradientLoop.count(node) > 0 && !mIfs.empty())
226 {
227 mMetadata->mIfsContainingGradientLoop.insert(mIfs.back());
228 }
229 }
230
231 return true;
232 }
233
visitBranch(Visit visit,TIntermBranch * node)234 bool visitBranch(Visit visit, TIntermBranch *node) override
235 {
236 if (visit == PreVisit)
237 {
238 switch (node->getFlowOp())
239 {
240 case EOpBreak:
241 {
242 ASSERT(!mLoopsAndSwitches.empty());
243 TIntermLoop *loop = mLoopsAndSwitches.back()->getAsLoopNode();
244 if (loop != nullptr)
245 {
246 mMetadata->mDiscontinuousLoops.insert(loop);
247 }
248 }
249 break;
250 case EOpContinue:
251 {
252 ASSERT(!mLoopsAndSwitches.empty());
253 TIntermLoop *loop = nullptr;
254 size_t i = mLoopsAndSwitches.size();
255 while (loop == nullptr && i > 0)
256 {
257 --i;
258 loop = mLoopsAndSwitches.at(i)->getAsLoopNode();
259 }
260 ASSERT(loop != nullptr);
261 mMetadata->mDiscontinuousLoops.insert(loop);
262 }
263 break;
264 case EOpKill:
265 case EOpReturn:
266 // A return or discard jumps out of all the enclosing loops
267 if (!mLoopsAndSwitches.empty())
268 {
269 for (TIntermNode *intermNode : mLoopsAndSwitches)
270 {
271 TIntermLoop *loop = intermNode->getAsLoopNode();
272 if (loop)
273 {
274 mMetadata->mDiscontinuousLoops.insert(loop);
275 }
276 }
277 }
278 break;
279 default:
280 UNREACHABLE();
281 }
282 }
283
284 return true;
285 }
286
visitAggregate(Visit visit,TIntermAggregate * node)287 bool visitAggregate(Visit visit, TIntermAggregate *node) override
288 {
289 if (visit == PreVisit && node->getOp() == EOpCallFunctionInAST)
290 {
291 size_t calleeIndex = mDag.findIndex(node->getFunction()->uniqueId());
292 ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
293
294 if ((*mMetadataList)[calleeIndex].mHasGradientLoopInCallGraph)
295 {
296 onGradientLoop();
297 }
298 }
299
300 return true;
301 }
302
visitSwitch(Visit visit,TIntermSwitch * node)303 bool visitSwitch(Visit visit, TIntermSwitch *node) override
304 {
305 if (visit == PreVisit)
306 {
307 mLoopsAndSwitches.push_back(node);
308 }
309 else if (visit == PostVisit)
310 {
311 ASSERT(mLoopsAndSwitches.back() == node);
312 mLoopsAndSwitches.pop_back();
313 }
314 return true;
315 }
316
317 private:
318 MetadataList *mMetadataList;
319 ASTMetadataHLSL *mMetadata;
320 size_t mIndex;
321 const CallDAG &mDag;
322
323 std::vector<TIntermNode *> mLoopsAndSwitches;
324 std::vector<TIntermIfElse *> mIfs;
325 };
326
327 // Tags all the functions called in a discontinuous loop
328 class PushDiscontinuousLoops : public TIntermTraverser
329 {
330 public:
PushDiscontinuousLoops(MetadataList * metadataList,size_t index,const CallDAG & dag)331 PushDiscontinuousLoops(MetadataList *metadataList, size_t index, const CallDAG &dag)
332 : TIntermTraverser(true, true, true),
333 mMetadataList(metadataList),
334 mMetadata(&(*metadataList)[index]),
335 mIndex(index),
336 mDag(dag),
337 mNestedDiscont(mMetadata->mCalledInDiscontinuousLoop ? 1 : 0)
338 {}
339
traverse(TIntermFunctionDefinition * node)340 void traverse(TIntermFunctionDefinition *node)
341 {
342 node->traverse(this);
343 ASSERT(mNestedDiscont == (mMetadata->mCalledInDiscontinuousLoop ? 1 : 0));
344 }
345
visitLoop(Visit visit,TIntermLoop * loop)346 bool visitLoop(Visit visit, TIntermLoop *loop) override
347 {
348 bool isDiscontinuous = mMetadata->mDiscontinuousLoops.count(loop) > 0;
349
350 if (visit == PreVisit && isDiscontinuous)
351 {
352 mNestedDiscont++;
353 }
354 else if (visit == PostVisit && isDiscontinuous)
355 {
356 mNestedDiscont--;
357 }
358
359 return true;
360 }
361
visitAggregate(Visit visit,TIntermAggregate * node)362 bool visitAggregate(Visit visit, TIntermAggregate *node) override
363 {
364 switch (node->getOp())
365 {
366 case EOpCallFunctionInAST:
367 if (visit == PreVisit && mNestedDiscont > 0)
368 {
369 size_t calleeIndex = mDag.findIndex(node->getFunction()->uniqueId());
370 ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
371
372 (*mMetadataList)[calleeIndex].mCalledInDiscontinuousLoop = true;
373 }
374 break;
375 default:
376 break;
377 }
378 return true;
379 }
380
381 private:
382 MetadataList *mMetadataList;
383 ASTMetadataHLSL *mMetadata;
384 size_t mIndex;
385 const CallDAG &mDag;
386
387 int mNestedDiscont;
388 };
389 } // namespace
390
hasGradientInCallGraph(TIntermLoop * node)391 bool ASTMetadataHLSL::hasGradientInCallGraph(TIntermLoop *node)
392 {
393 return mControlFlowsContainingGradient.count(node) > 0;
394 }
395
hasGradientLoop(TIntermIfElse * node)396 bool ASTMetadataHLSL::hasGradientLoop(TIntermIfElse *node)
397 {
398 return mIfsContainingGradientLoop.count(node) > 0;
399 }
400
CreateASTMetadataHLSL(TIntermNode * root,const CallDAG & callDag)401 MetadataList CreateASTMetadataHLSL(TIntermNode *root, const CallDAG &callDag)
402 {
403 MetadataList metadataList(callDag.size());
404
405 // Compute all the information related to when gradient operations are used.
406 // We want to know for each function and control flow operation if they have
407 // a gradient operation in their call graph (shortened to "using a gradient"
408 // in the rest of the file).
409 //
410 // This computation is logically split in three steps:
411 // 1 - For each function compute if it uses a gradient in its body, ignoring
412 // calls to other user-defined functions.
413 // 2 - For each function determine if it uses a gradient in its call graph,
414 // using the result of step 1 and the CallDAG to know its callees.
415 // 3 - For each control flow statement of each function, check if it uses a
416 // gradient in the function's body, or if it calls a user-defined function that
417 // uses a gradient.
418 //
419 // We take advantage of the call graph being a DAG and instead compute 1, 2 and 3
420 // for leaves first, then going down the tree. This is correct because 1 doesn't
421 // depend on other functions, and 2 and 3 depend only on callees.
422 for (size_t i = 0; i < callDag.size(); i++)
423 {
424 PullGradient pull(&metadataList, i, callDag);
425 pull.traverse(callDag.getRecordFromIndex(i).node);
426 }
427
428 // Compute which loops are discontinuous and which function are called in
429 // these loops. The same way computing gradient usage is a "pull" process,
430 // computing "bing used in a discont. loop" is a push process. However we also
431 // need to know what ifs have a discontinuous loop inside so we do the same type
432 // of callgraph analysis as for the gradient.
433
434 // First compute which loops are discontinuous (no specific order) and pull
435 // the ifs and functions using a gradient loop.
436 for (size_t i = 0; i < callDag.size(); i++)
437 {
438 PullComputeDiscontinuousAndGradientLoops pull(&metadataList, i, callDag);
439 pull.traverse(callDag.getRecordFromIndex(i).node);
440 }
441
442 // Then push the information to callees, either from the a local discontinuous
443 // loop or from the caller being called in a discontinuous loop already
444 for (size_t i = callDag.size(); i-- > 0;)
445 {
446 PushDiscontinuousLoops push(&metadataList, i, callDag);
447 push.traverse(callDag.getRecordFromIndex(i).node);
448 }
449
450 // We create "Lod0" version of functions with the gradient operations replaced
451 // by non-gradient operations so that the D3D compiler is happier with discont
452 // loops.
453 for (auto &metadata : metadataList)
454 {
455 metadata.mNeedsLod0 = metadata.mCalledInDiscontinuousLoop && metadata.mUsesGradient;
456 }
457
458 return metadataList;
459 }
460
461 } // namespace sh
462