1 //
2 // Copyright 2015 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RemovePow_test.cpp:
7 // Tests for removing pow() function calls from the AST.
8 //
9
10 #include "GLSLANG/ShaderLang.h"
11 #include "angle_gl.h"
12 #include "compiler/translator/TranslatorGLSL.h"
13 #include "compiler/translator/tree_util/NodeSearch.h"
14 #include "gtest/gtest.h"
15
16 using namespace sh;
17
18 class RemovePowTest : public testing::Test
19 {
20 public:
RemovePowTest()21 RemovePowTest() {}
22
23 protected:
SetUp()24 void SetUp() override
25 {
26 allocator.push();
27 SetGlobalPoolAllocator(&allocator);
28 ShBuiltInResources resources;
29 sh::InitBuiltInResources(&resources);
30 mTranslatorGLSL =
31 new sh::TranslatorGLSL(GL_FRAGMENT_SHADER, SH_GLES2_SPEC, SH_GLSL_COMPATIBILITY_OUTPUT);
32 ASSERT_TRUE(mTranslatorGLSL->Init(resources));
33 }
34
TearDown()35 void TearDown() override
36 {
37 SafeDelete(mTranslatorGLSL);
38 SetGlobalPoolAllocator(nullptr);
39 allocator.pop();
40 }
41
compile(const std::string & shaderString)42 void compile(const std::string &shaderString)
43 {
44 const char *shaderStrings[] = {shaderString.c_str()};
45 mASTRoot = mTranslatorGLSL->compileTreeForTesting(
46 shaderStrings, 1, SH_OBJECT_CODE | SH_REMOVE_POW_WITH_CONSTANT_EXPONENT);
47 if (!mASTRoot)
48 {
49 TInfoSink &infoSink = mTranslatorGLSL->getInfoSink();
50 FAIL() << "Shader compilation into ESSL failed " << infoSink.info.c_str();
51 }
52 }
53
54 template <class T>
foundInAST()55 bool foundInAST()
56 {
57 return T::search(mASTRoot);
58 }
59
60 private:
61 sh::TranslatorGLSL *mTranslatorGLSL;
62 TIntermNode *mASTRoot;
63
64 angle::PoolAllocator allocator;
65 };
66
67 // Check if there's a pow() node anywhere in the tree.
68 class FindPow : public sh::NodeSearchTraverser<FindPow>
69 {
70 public:
visitBinary(Visit visit,TIntermBinary * node)71 bool visitBinary(Visit visit, TIntermBinary *node) override
72 {
73 if (node->getOp() == EOpPow)
74 {
75 mFound = true;
76 }
77 return !mFound;
78 }
79 };
80
81 // Check if the tree starting at node corresponds to exp2(y * log2(x))
82 // If the tree matches, set base to the node corresponding to x.
IsPowWorkaround(TIntermNode * node,TIntermNode ** base)83 bool IsPowWorkaround(TIntermNode *node, TIntermNode **base)
84 {
85 TIntermUnary *exp = node->getAsUnaryNode();
86 if (exp != nullptr && exp->getOp() == EOpExp2)
87 {
88 TIntermBinary *mul = exp->getOperand()->getAsBinaryNode();
89 if (mul != nullptr && mul->isMultiplication())
90 {
91 TIntermUnary *log = mul->getRight()->getAsUnaryNode();
92 if (mul->getLeft()->getAsConstantUnion() && log != nullptr)
93 {
94 if (log->getOp() == EOpLog2)
95 {
96 if (base)
97 *base = log->getOperand();
98 return true;
99 }
100 }
101 }
102 }
103 return false;
104 }
105
106 // Check if there's a node with the correct workaround to pow anywhere in the tree.
107 class FindPowWorkaround : public sh::NodeSearchTraverser<FindPowWorkaround>
108 {
109 public:
visitUnary(Visit visit,TIntermUnary * node)110 bool visitUnary(Visit visit, TIntermUnary *node) override
111 {
112 mFound = IsPowWorkaround(node, nullptr);
113 return !mFound;
114 }
115 };
116
117 // Check if there's a node with the correct workaround to pow with another workaround to pow
118 // nested within it anywhere in the tree.
119 class FindNestedPowWorkaround : public sh::NodeSearchTraverser<FindNestedPowWorkaround>
120 {
121 public:
visitUnary(Visit visit,TIntermUnary * node)122 bool visitUnary(Visit visit, TIntermUnary *node) override
123 {
124 TIntermNode *base = nullptr;
125 bool oneFound = IsPowWorkaround(node, &base);
126 if (oneFound && base)
127 mFound = IsPowWorkaround(base, nullptr);
128 return !mFound;
129 }
130 };
131
// A single pow() with a constant exponent must be replaced by exp2(y * log2(x)), with no
// pow() left behind and no nested workaround produced.
TEST_F(RemovePowTest, PowWithConstantExponent)
{
    compile(
        "precision mediump float;\n"
        "uniform float u;\n"
        "void main() {\n"
        " gl_FragColor = pow(vec4(u), vec4(0.5));\n"
        "}\n");

    ASSERT_FALSE(foundInAST<FindPow>());
    ASSERT_TRUE(foundInAST<FindPowWorkaround>());
    ASSERT_FALSE(foundInAST<FindNestedPowWorkaround>());
}
145
// Nested pow() calls with constant exponents must both be rewritten, so the outer workaround's
// base is itself a workaround and no pow() remains.
TEST_F(RemovePowTest, NestedPowWithConstantExponent)
{
    compile(
        "precision mediump float;\n"
        "uniform float u;\n"
        "void main() {\n"
        " gl_FragColor = pow(pow(vec4(u), vec4(2.0)), vec4(0.5));\n"
        "}\n");

    ASSERT_FALSE(foundInAST<FindPow>());
    ASSERT_TRUE(foundInAST<FindNestedPowWorkaround>());
}
158