1 //
2 // Copyright 2016 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Implementation of the integer pow expressions HLSL bug workaround.
7 // See header for more info.
8
9 #include "compiler/translator/tree_ops/d3d/ExpandIntegerPowExpressions.h"
10
11 #include <cmath>
12 #include <cstdlib>
13
14 #include "compiler/translator/tree_util/IntermNode_util.h"
15 #include "compiler/translator/tree_util/IntermTraverse.h"
16
17 namespace sh
18 {
19
20 namespace
21 {
22
23 class Traverser : public TIntermTraverser
24 {
25 public:
26 ANGLE_NO_DISCARD static bool Apply(TCompiler *compiler,
27 TIntermNode *root,
28 TSymbolTable *symbolTable);
29
30 private:
31 Traverser(TSymbolTable *symbolTable);
32 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
33 void nextIteration();
34
35 bool mFound = false;
36 };
37
38 // static
Apply(TCompiler * compiler,TIntermNode * root,TSymbolTable * symbolTable)39 bool Traverser::Apply(TCompiler *compiler, TIntermNode *root, TSymbolTable *symbolTable)
40 {
41 Traverser traverser(symbolTable);
42 do
43 {
44 traverser.nextIteration();
45 root->traverse(&traverser);
46 if (traverser.mFound)
47 {
48 if (!traverser.updateTree(compiler, root))
49 {
50 return false;
51 }
52 }
53 } while (traverser.mFound);
54
55 return true;
56 }
57
Traverser(TSymbolTable * symbolTable)58 Traverser::Traverser(TSymbolTable *symbolTable) : TIntermTraverser(true, false, false, symbolTable)
59 {}
60
nextIteration()61 void Traverser::nextIteration()
62 {
63 mFound = false;
64 }
65
visitAggregate(Visit visit,TIntermAggregate * node)66 bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
67 {
68 if (mFound)
69 {
70 return false;
71 }
72
73 // Test 0: skip non-pow operators.
74 if (node->getOp() != EOpPow)
75 {
76 return true;
77 }
78
79 const TIntermSequence *sequence = node->getSequence();
80 ASSERT(sequence->size() == 2u);
81 const TIntermConstantUnion *constantExponent = sequence->at(1)->getAsConstantUnion();
82
83 // Test 1: check for a single constant.
84 if (!constantExponent || constantExponent->getNominalSize() != 1)
85 {
86 return true;
87 }
88
89 float exponentValue = constantExponent->getConstantValue()->getFConst();
90
91 // Test 2: exponentValue is in the problematic range.
92 if (exponentValue < -5.0f || exponentValue > 9.0f)
93 {
94 return true;
95 }
96
97 // Test 3: exponentValue is integer or pretty close to an integer.
98 if (std::abs(exponentValue - std::round(exponentValue)) > 0.0001f)
99 {
100 return true;
101 }
102
103 // Test 4: skip -1, 0, and 1
104 int exponent = static_cast<int>(std::round(exponentValue));
105 int n = std::abs(exponent);
106 if (n < 2)
107 {
108 return true;
109 }
110
111 // Potential problem case detected, apply workaround.
112
113 TIntermTyped *lhs = sequence->at(0)->getAsTyped();
114 ASSERT(lhs);
115
116 TIntermDeclaration *lhsVariableDeclaration = nullptr;
117 TVariable *lhsVariable =
118 DeclareTempVariable(mSymbolTable, lhs, EvqTemporary, &lhsVariableDeclaration);
119 insertStatementInParentBlock(lhsVariableDeclaration);
120
121 // Create a chain of n-1 multiples.
122 TIntermTyped *current = CreateTempSymbolNode(lhsVariable);
123 for (int i = 1; i < n; ++i)
124 {
125 TIntermBinary *mul = new TIntermBinary(EOpMul, current, CreateTempSymbolNode(lhsVariable));
126 mul->setLine(node->getLine());
127 current = mul;
128 }
129
130 // For negative pow, compute the reciprocal of the positive pow.
131 if (exponent < 0)
132 {
133 TConstantUnion *oneVal = new TConstantUnion();
134 oneVal->setFConst(1.0f);
135 TIntermConstantUnion *oneNode = new TIntermConstantUnion(oneVal, node->getType());
136 TIntermBinary *div = new TIntermBinary(EOpDiv, oneNode, current);
137 current = div;
138 }
139
140 queueReplacement(current, OriginalNode::IS_DROPPED);
141 mFound = true;
142 return false;
143 }
144
145 } // anonymous namespace
146
ExpandIntegerPowExpressions(TCompiler * compiler,TIntermNode * root,TSymbolTable * symbolTable)147 bool ExpandIntegerPowExpressions(TCompiler *compiler, TIntermNode *root, TSymbolTable *symbolTable)
148 {
149 return Traverser::Apply(compiler, root, symbolTable);
150 }
151
152 } // namespace sh
153