1 // 2 // Copyright 2016 The ANGLE Project Authors. All rights reserved. 3 // Use of this source code is governed by a BSD-style license that can be 4 // found in the LICENSE file. 5 // 6 // ConstantFoldingTest.h: 7 // Utilities for constant folding tests. 8 // 9 10 #ifndef TESTS_TEST_UTILS_CONSTANTFOLDINGTEST_H_ 11 #define TESTS_TEST_UTILS_CONSTANTFOLDINGTEST_H_ 12 13 #include <vector> 14 15 #include "common/mathutil.h" 16 #include "compiler/translator/tree_util/FindMain.h" 17 #include "compiler/translator/tree_util/FindSymbolNode.h" 18 #include "compiler/translator/tree_util/IntermTraverse.h" 19 #include "tests/test_utils/ShaderCompileTreeTest.h" 20 21 namespace sh 22 { 23 24 class TranslatorESSL; 25 26 template <typename T> 27 class ConstantFinder : public TIntermTraverser 28 { 29 public: ConstantFinder(const std::vector<T> & constantVector)30 ConstantFinder(const std::vector<T> &constantVector) 31 : TIntermTraverser(true, false, false), 32 mConstantVector(constantVector), 33 mFaultTolerance(T()), 34 mFound(false) 35 {} 36 ConstantFinder(const std::vector<T> & constantVector,const T & faultTolerance)37 ConstantFinder(const std::vector<T> &constantVector, const T &faultTolerance) 38 : TIntermTraverser(true, false, false), 39 mConstantVector(constantVector), 40 mFaultTolerance(faultTolerance), 41 mFound(false) 42 {} 43 ConstantFinder(const T & value)44 ConstantFinder(const T &value) 45 : TIntermTraverser(true, false, false), mFaultTolerance(T()), mFound(false) 46 { 47 mConstantVector.push_back(value); 48 } 49 visitConstantUnion(TIntermConstantUnion * node)50 void visitConstantUnion(TIntermConstantUnion *node) 51 { 52 if (node->getType().getObjectSize() == mConstantVector.size()) 53 { 54 bool found = true; 55 for (size_t i = 0; i < mConstantVector.size(); i++) 56 { 57 if (!isEqual(node->getConstantValue()[i], mConstantVector[i])) 58 { 59 found = false; 60 break; 61 } 62 } 63 if (found) 64 { 65 mFound = found; 66 } 67 } 68 } 69 found()70 bool found() const { return mFound; } 71 72 private: isEqual(const TConstantUnion & node,const float & value)73 bool isEqual(const TConstantUnion &node, const float &value) const 74 { 75 if (node.getType() != EbtFloat) 76 { 77 return false; 78 } 79 if (value == std::numeric_limits<float>::infinity()) 80 { 81 return gl::isInf(node.getFConst()) && node.getFConst() > 0; 82 } 83 else if (value == -std::numeric_limits<float>::infinity()) 84 { 85 return gl::isInf(node.getFConst()) && node.getFConst() < 0; 86 } 87 else if (gl::isNaN(value)) 88 { 89 // All NaNs are treated as equal. 90 return gl::isNaN(node.getFConst()); 91 } 92 return mFaultTolerance >= fabsf(node.getFConst() - value); 93 } 94 isEqual(const TConstantUnion & node,const int & value)95 bool isEqual(const TConstantUnion &node, const int &value) const 96 { 97 if (node.getType() != EbtInt) 98 { 99 return false; 100 } 101 ASSERT(mFaultTolerance < std::numeric_limits<int>::max()); 102 // abs() returns 0 at least on some platforms when the minimum int value is passed in (it 103 // doesn't have a positive counterpart). 104 return mFaultTolerance >= abs(node.getIConst() - value) && 105 (node.getIConst() - value) != std::numeric_limits<int>::min(); 106 } 107 isEqual(const TConstantUnion & node,const unsigned int & value)108 bool isEqual(const TConstantUnion &node, const unsigned int &value) const 109 { 110 if (node.getType() != EbtUInt) 111 { 112 return false; 113 } 114 ASSERT(mFaultTolerance < static_cast<unsigned int>(std::numeric_limits<int>::max())); 115 return static_cast<int>(mFaultTolerance) >= 116 abs(static_cast<int>(node.getUConst() - value)) && 117 static_cast<int>(node.getUConst() - value) != std::numeric_limits<int>::min(); 118 } 119 isEqual(const TConstantUnion & node,const bool & value)120 bool isEqual(const TConstantUnion &node, const bool &value) const 121 { 122 if (node.getType() != EbtBool) 123 { 124 return false; 125 } 126 return node.getBConst() == value; 127 } 128 129 std::vector<T> mConstantVector; 130 T mFaultTolerance; 131 bool mFound; 132 }; 133 134 class ConstantFoldingTest : public ShaderCompileTreeTest 135 { 136 public: ConstantFoldingTest()137 ConstantFoldingTest() {} 138 139 protected: getShaderType()140 ::GLenum getShaderType() const override { return GL_FRAGMENT_SHADER; } getShaderSpec()141 ShShaderSpec getShaderSpec() const override { return SH_GLES3_1_SPEC; } 142 143 template <typename T> constantFoundInAST(T constant)144 bool constantFoundInAST(T constant) 145 { 146 ConstantFinder<T> finder(constant); 147 mASTRoot->traverse(&finder); 148 return finder.found(); 149 } 150 151 template <typename T> constantVectorFoundInAST(const std::vector<T> & constantVector)152 bool constantVectorFoundInAST(const std::vector<T> &constantVector) 153 { 154 ConstantFinder<T> finder(constantVector); 155 mASTRoot->traverse(&finder); 156 return finder.found(); 157 } 158 159 template <typename T> constantColumnMajorMatrixFoundInAST(const std::vector<T> & constantMatrix)160 bool constantColumnMajorMatrixFoundInAST(const std::vector<T> &constantMatrix) 161 { 162 return constantVectorFoundInAST(constantMatrix); 163 } 164 165 template <typename T> constantVectorNearFoundInAST(const std::vector<T> & constantVector,const T & faultTolerance)166 bool constantVectorNearFoundInAST(const std::vector<T> &constantVector, const T &faultTolerance) 167 { 168 ConstantFinder<T> finder(constantVector, faultTolerance); 169 mASTRoot->traverse(&finder); 170 return finder.found(); 171 } 172 symbolFoundInAST(const char * symbolName)173 bool symbolFoundInAST(const char *symbolName) 174 { 175 return FindSymbolNode(mASTRoot, ImmutableString(symbolName)) != nullptr; 176 } 177 symbolFoundInMain(const char * symbolName)178 bool symbolFoundInMain(const char *symbolName) 179 { 180 return FindSymbolNode(FindMain(mASTRoot), ImmutableString(symbolName)) != nullptr; 181 } 182 }; 183 184 class ConstantFoldingExpressionTest : public ConstantFoldingTest 185 { 186 public: ConstantFoldingExpressionTest()187 ConstantFoldingExpressionTest() {} 188 189 void evaluateFloat(const std::string &floatExpression); 190 void evaluateInt(const std::string &intExpression); 191 void evaluateUint(const std::string &uintExpression); 192 }; 193 194 } // namespace sh 195 196 #endif // TESTS_TEST_UTILS_CONSTANTFOLDINGTEST_H_ 197