• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2016 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Implementation of the integer pow expressions HLSL bug workaround.
7 // See header for more info.
8 
9 #include "compiler/translator/tree_ops/d3d/ExpandIntegerPowExpressions.h"
10 
11 #include <cmath>
12 #include <cstdlib>
13 
14 #include "compiler/translator/tree_util/IntermNode_util.h"
15 #include "compiler/translator/tree_util/IntermTraverse.h"
16 
17 namespace sh
18 {
19 
20 namespace
21 {
22 
23 class Traverser : public TIntermTraverser
24 {
25   public:
26     ANGLE_NO_DISCARD static bool Apply(TCompiler *compiler,
27                                        TIntermNode *root,
28                                        TSymbolTable *symbolTable);
29 
30   private:
31     Traverser(TSymbolTable *symbolTable);
32     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
33     void nextIteration();
34 
35     bool mFound = false;
36 };
37 
38 // static
Apply(TCompiler * compiler,TIntermNode * root,TSymbolTable * symbolTable)39 bool Traverser::Apply(TCompiler *compiler, TIntermNode *root, TSymbolTable *symbolTable)
40 {
41     Traverser traverser(symbolTable);
42     do
43     {
44         traverser.nextIteration();
45         root->traverse(&traverser);
46         if (traverser.mFound)
47         {
48             if (!traverser.updateTree(compiler, root))
49             {
50                 return false;
51             }
52         }
53     } while (traverser.mFound);
54 
55     return true;
56 }
57 
Traverser(TSymbolTable * symbolTable)58 Traverser::Traverser(TSymbolTable *symbolTable) : TIntermTraverser(true, false, false, symbolTable)
59 {}
60 
nextIteration()61 void Traverser::nextIteration()
62 {
63     mFound = false;
64 }
65 
visitAggregate(Visit visit,TIntermAggregate * node)66 bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
67 {
68     if (mFound)
69     {
70         return false;
71     }
72 
73     // Test 0: skip non-pow operators.
74     if (node->getOp() != EOpPow)
75     {
76         return true;
77     }
78 
79     const TIntermSequence *sequence = node->getSequence();
80     ASSERT(sequence->size() == 2u);
81     const TIntermConstantUnion *constantExponent = sequence->at(1)->getAsConstantUnion();
82 
83     // Test 1: check for a single constant.
84     if (!constantExponent || constantExponent->getNominalSize() != 1)
85     {
86         return true;
87     }
88 
89     float exponentValue = constantExponent->getConstantValue()->getFConst();
90 
91     // Test 2: exponentValue is in the problematic range.
92     if (exponentValue < -5.0f || exponentValue > 9.0f)
93     {
94         return true;
95     }
96 
97     // Test 3: exponentValue is integer or pretty close to an integer.
98     if (std::abs(exponentValue - std::round(exponentValue)) > 0.0001f)
99     {
100         return true;
101     }
102 
103     // Test 4: skip -1, 0, and 1
104     int exponent = static_cast<int>(std::round(exponentValue));
105     int n        = std::abs(exponent);
106     if (n < 2)
107     {
108         return true;
109     }
110 
111     // Potential problem case detected, apply workaround.
112 
113     TIntermTyped *lhs = sequence->at(0)->getAsTyped();
114     ASSERT(lhs);
115 
116     TIntermDeclaration *lhsVariableDeclaration = nullptr;
117     TVariable *lhsVariable =
118         DeclareTempVariable(mSymbolTable, lhs, EvqTemporary, &lhsVariableDeclaration);
119     insertStatementInParentBlock(lhsVariableDeclaration);
120 
121     // Create a chain of n-1 multiples.
122     TIntermTyped *current = CreateTempSymbolNode(lhsVariable);
123     for (int i = 1; i < n; ++i)
124     {
125         TIntermBinary *mul = new TIntermBinary(EOpMul, current, CreateTempSymbolNode(lhsVariable));
126         mul->setLine(node->getLine());
127         current = mul;
128     }
129 
130     // For negative pow, compute the reciprocal of the positive pow.
131     if (exponent < 0)
132     {
133         TConstantUnion *oneVal = new TConstantUnion();
134         oneVal->setFConst(1.0f);
135         TIntermConstantUnion *oneNode = new TIntermConstantUnion(oneVal, node->getType());
136         TIntermBinary *div            = new TIntermBinary(EOpDiv, oneNode, current);
137         current                       = div;
138     }
139 
140     queueReplacement(current, OriginalNode::IS_DROPPED);
141     mFound = true;
142     return false;
143 }
144 
145 }  // anonymous namespace
146 
ExpandIntegerPowExpressions(TCompiler * compiler,TIntermNode * root,TSymbolTable * symbolTable)147 bool ExpandIntegerPowExpressions(TCompiler *compiler, TIntermNode *root, TSymbolTable *symbolTable)
148 {
149     return Traverser::Apply(compiler, root, symbolTable);
150 }
151 
152 }  // namespace sh
153