1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 /* Header-only library for various helpers of test harness 18 * See frameworks/ml/nn/runtime/test/TestGenerated.cpp for how this is used. 19 */ 20 #ifndef ANDROID_ML_NN_TOOLS_TEST_GENERATOR_TEST_HARNESS_H 21 #define ANDROID_ML_NN_TOOLS_TEST_GENERATOR_TEST_HARNESS_H 22 23 #include <gtest/gtest.h> 24 25 #include <functional> 26 #include <map> 27 #include <tuple> 28 #include <vector> 29 30 namespace generated_tests { 31 typedef std::map<int, std::vector<float>> Float32Operands; 32 typedef std::map<int, std::vector<int32_t>> Int32Operands; 33 typedef std::map<int, std::vector<uint8_t>> Quant8Operands; 34 typedef std::tuple<Float32Operands, // ANEURALNETWORKS_TENSOR_FLOAT32 35 Int32Operands, // ANEURALNETWORKS_TENSOR_INT32 36 Quant8Operands // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM 37 > 38 MixedTyped; 39 typedef std::pair<MixedTyped, MixedTyped> MixedTypedExampleType; 40 41 template <typename T> 42 struct MixedTypedIndex {}; 43 44 template <> 45 struct MixedTypedIndex<float> { 46 static constexpr size_t index = 0; 47 }; 48 template <> 49 struct MixedTypedIndex<int32_t> { 50 static constexpr size_t index = 1; 51 }; 52 template <> 53 struct MixedTypedIndex<uint8_t> { 54 static constexpr size_t index = 2; 55 }; 56 57 // Go through all index-value pairs of a given input type 58 template <typename T> 59 inline void for_each(const MixedTyped& idx_and_data, 60 std::function<void(int, const std::vector<T>&)> execute) { 61 for (auto& i : std::get<MixedTypedIndex<T>::index>(idx_and_data)) { 62 execute(i.first, i.second); 63 } 64 } 65 66 // non-const variant of for_each 67 template <typename T> 68 inline void for_each(MixedTyped& idx_and_data, 69 std::function<void(int, std::vector<T>&)> execute) { 70 for (auto& i : std::get<MixedTypedIndex<T>::index>(idx_and_data)) { 71 execute(i.first, i.second); 72 } 73 } 74 75 // internal helper for for_all 76 template <typename T> 77 inline void for_all_internal( 78 MixedTyped& idx_and_data, 79 std::function<void(int, void*, size_t)> execute_this) { 80 for_each<T>(idx_and_data, [&execute_this](int idx, std::vector<T>& m) { 81 execute_this(idx, static_cast<void*>(m.data()), m.size() * sizeof(T)); 82 }); 83 } 84 85 // Go through all index-value pairs of all input types 86 // expects a functor that takes (int index, void *raw data, size_t sz) 87 inline void for_all(MixedTyped& idx_and_data, 88 std::function<void(int, void*, size_t)> execute_this) { 89 for_all_internal<float>(idx_and_data, execute_this); 90 for_all_internal<int32_t>(idx_and_data, execute_this); 91 for_all_internal<uint8_t>(idx_and_data, execute_this); 92 } 93 94 // Const variant of internal helper for for_all 95 template <typename T> 96 inline void for_all_internal( 97 const MixedTyped& idx_and_data, 98 std::function<void(int, const void*, size_t)> execute_this) { 99 for_each<T>(idx_and_data, [&execute_this](int idx, const std::vector<T>& m) { 100 execute_this(idx, static_cast<const void*>(m.data()), m.size() * sizeof(T)); 101 }); 102 } 103 104 // Go through all index-value pairs (const variant) 105 // expects a functor that takes (int index, const void *raw data, size_t sz) 106 inline void for_all( 107 const MixedTyped& idx_and_data, 108 std::function<void(int, const void*, size_t)> execute_this) { 109 for_all_internal<float>(idx_and_data, execute_this); 110 for_all_internal<int32_t>(idx_and_data, execute_this); 111 for_all_internal<uint8_t>(idx_and_data, execute_this); 112 } 113 114 // Helper template - resize test output per golden 115 template <typename ty, size_t tuple_index> 116 void resize_accordingly_(const MixedTyped& golden, MixedTyped& test) { 117 std::function<void(int, const std::vector<ty>&)> execute = 118 [&test](int index, const std::vector<ty>& m) { 119 auto& t = std::get<tuple_index>(test); 120 t[index].resize(m.size()); 121 }; 122 for_each<ty>(golden, execute); 123 } 124 125 inline void resize_accordingly(const MixedTyped& golden, MixedTyped& test) { 126 resize_accordingly_<float, 0>(golden, test); 127 resize_accordingly_<int32_t, 1>(golden, test); 128 resize_accordingly_<uint8_t, 2>(golden, test); 129 } 130 131 template <typename ty, size_t tuple_index> 132 void filter_internal(const MixedTyped& golden, MixedTyped* filtered, 133 std::function<bool(int)> is_ignored) { 134 for_each<ty>(golden, 135 [filtered, &is_ignored](int index, const std::vector<ty>& m) { 136 auto& g = std::get<tuple_index>(*filtered); 137 if (!is_ignored(index)) g[index] = m; 138 }); 139 } 140 141 inline MixedTyped filter(const MixedTyped& golden, 142 std::function<bool(int)> is_ignored) { 143 MixedTyped filtered; 144 filter_internal<float, 0>(golden, &filtered, is_ignored); 145 filter_internal<int32_t, 1>(golden, &filtered, is_ignored); 146 filter_internal<uint8_t, 2>(golden, &filtered, is_ignored); 147 return filtered; 148 } 149 150 // Compare results 151 #define VECTOR_TYPE(x) \ 152 typename std::tuple_element<x, MixedTyped>::type::mapped_type 153 #define VALUE_TYPE(x) VECTOR_TYPE(x)::value_type 154 template <size_t tuple_index> 155 void compare_( 156 const MixedTyped& golden, const MixedTyped& test, 157 std::function<void(VALUE_TYPE(tuple_index), VALUE_TYPE(tuple_index))> 158 cmp) { 159 for_each<VALUE_TYPE(tuple_index)>( 160 golden, 161 [&test, &cmp](int index, const VECTOR_TYPE(tuple_index) & m) { 162 const auto& test_operands = std::get<tuple_index>(test); 163 const auto& test_ty = test_operands.find(index); 164 ASSERT_NE(test_ty, test_operands.end()); 165 for (unsigned int i = 0; i < m.size(); i++) { 166 SCOPED_TRACE(testing::Message() 167 << "When comparing element " << i); 168 cmp(m[i], test_ty->second[i]); 169 } 170 }); 171 } 172 #undef VALUE_TYPE 173 #undef VECTOR_TYPE 174 inline void compare(const MixedTyped& golden, const MixedTyped& test) { 175 compare_<0>(golden, test, 176 [](float g, float t) { EXPECT_NEAR(g, t, 1.e-5f); }); 177 compare_<1>(golden, test, [](int32_t g, int32_t t) { EXPECT_EQ(g, t); }); 178 compare_<2>(golden, test, [](uint8_t g, uint8_t t) { EXPECT_NEAR(g, t, 1); }); 179 } 180 181 }; // namespace generated_tests 182 183 #endif // ANDROID_ML_NN_TOOLS_TEST_GENERATOR_TEST_HARNESS_H 184