• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
/* Header-only library of helper utilities for the generated-test harness.
 * See frameworks/ml/nn/runtime/test/TestGenerated.cpp for how this is used.
 */
20 #ifndef ANDROID_ML_NN_TOOLS_TEST_GENERATOR_TEST_HARNESS_H
21 #define ANDROID_ML_NN_TOOLS_TEST_GENERATOR_TEST_HARNESS_H
22 
23 #include <gtest/gtest.h>
24 
25 #include <functional>
26 #include <map>
27 #include <tuple>
28 #include <vector>
29 
30 namespace generated_tests {
// Per-element-type operand tables: integer operand index -> flat value buffer.
using Float32Operands = std::map<int, std::vector<float>>;
using Int32Operands = std::map<int, std::vector<int32_t>>;
using Quant8Operands = std::map<int, std::vector<uint8_t>>;

// One table per supported element type, bundled together. The slot order
// (float, int32_t, uint8_t) is what the std::get<N>() calls in this header
// rely on.
using MixedTyped =
        std::tuple<Float32Operands,  // ANEURALNETWORKS_TENSOR_FLOAT32
                   Int32Operands,    // ANEURALNETWORKS_TENSOR_INT32
                   Quant8Operands    // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM
                   >;

// Two MixedTyped bundles describing a single example.
// NOTE(review): presumably (inputs, expected outputs) - confirm against the
// generated test specs / TestGenerated.cpp.
using MixedTypedExampleType = std::pair<MixedTyped, MixedTyped>;
40 
// Compile-time map from an element type to its slot in the MixedTyped tuple.
// The primary template has no `index` member, so asking for the slot of an
// element type without a specialization below does not compile.
template <typename ElementType>
struct MixedTypedIndex {};

// float values live in tuple slot 0.
template <>
struct MixedTypedIndex<float> {
    static constexpr size_t index{0};
};

// int32_t values live in tuple slot 1.
template <>
struct MixedTypedIndex<int32_t> {
    static constexpr size_t index{1};
};

// uint8_t values live in tuple slot 2.
template <>
struct MixedTypedIndex<uint8_t> {
    static constexpr size_t index{2};
};
56 
57 // Go through all index-value pairs of a given input type
58 template <typename T>
59 inline void for_each(const MixedTyped& idx_and_data,
60                      std::function<void(int, const std::vector<T>&)> execute) {
61     for (auto& i : std::get<MixedTypedIndex<T>::index>(idx_and_data)) {
62         execute(i.first, i.second);
63     }
64 }
65 
66 // non-const variant of for_each
67 template <typename T>
68 inline void for_each(MixedTyped& idx_and_data,
69                      std::function<void(int, std::vector<T>&)> execute) {
70     for (auto& i : std::get<MixedTypedIndex<T>::index>(idx_and_data)) {
71         execute(i.first, i.second);
72     }
73 }
74 
75 // internal helper for for_all
76 template <typename T>
77 inline void for_all_internal(
78         MixedTyped& idx_and_data,
79         std::function<void(int, void*, size_t)> execute_this) {
80     for_each<T>(idx_and_data, [&execute_this](int idx, std::vector<T>& m) {
81         execute_this(idx, static_cast<void*>(m.data()), m.size() * sizeof(T));
82     });
83 }
84 
85 // Go through all index-value pairs of all input types
86 // expects a functor that takes (int index, void *raw data, size_t sz)
87 inline void for_all(MixedTyped& idx_and_data,
88                     std::function<void(int, void*, size_t)> execute_this) {
89     for_all_internal<float>(idx_and_data, execute_this);
90     for_all_internal<int32_t>(idx_and_data, execute_this);
91     for_all_internal<uint8_t>(idx_and_data, execute_this);
92 }
93 
94 // Const variant of internal helper for for_all
95 template <typename T>
96 inline void for_all_internal(
97         const MixedTyped& idx_and_data,
98         std::function<void(int, const void*, size_t)> execute_this) {
99     for_each<T>(idx_and_data, [&execute_this](int idx, const std::vector<T>& m) {
100         execute_this(idx, static_cast<const void*>(m.data()), m.size() * sizeof(T));
101     });
102 }
103 
104 // Go through all index-value pairs (const variant)
105 // expects a functor that takes (int index, const void *raw data, size_t sz)
106 inline void for_all(
107         const MixedTyped& idx_and_data,
108         std::function<void(int, const void*, size_t)> execute_this) {
109     for_all_internal<float>(idx_and_data, execute_this);
110     for_all_internal<int32_t>(idx_and_data, execute_this);
111     for_all_internal<uint8_t>(idx_and_data, execute_this);
112 }
113 
114 // Helper template - resize test output per golden
115 template <typename ty, size_t tuple_index>
116 void resize_accordingly_(const MixedTyped& golden, MixedTyped& test) {
117     std::function<void(int, const std::vector<ty>&)> execute =
118             [&test](int index, const std::vector<ty>& m) {
119                 auto& t = std::get<tuple_index>(test);
120                 t[index].resize(m.size());
121             };
122     for_each<ty>(golden, execute);
123 }
124 
125 inline void resize_accordingly(const MixedTyped& golden, MixedTyped& test) {
126     resize_accordingly_<float, 0>(golden, test);
127     resize_accordingly_<int32_t, 1>(golden, test);
128     resize_accordingly_<uint8_t, 2>(golden, test);
129 }
130 
131 template <typename ty, size_t tuple_index>
132 void filter_internal(const MixedTyped& golden, MixedTyped* filtered,
133                      std::function<bool(int)> is_ignored) {
134     for_each<ty>(golden,
135                  [filtered, &is_ignored](int index, const std::vector<ty>& m) {
136                      auto& g = std::get<tuple_index>(*filtered);
137                      if (!is_ignored(index)) g[index] = m;
138                  });
139 }
140 
141 inline MixedTyped filter(const MixedTyped& golden,
142                          std::function<bool(int)> is_ignored) {
143     MixedTyped filtered;
144     filter_internal<float, 0>(golden, &filtered, is_ignored);
145     filter_internal<int32_t, 1>(golden, &filtered, is_ignored);
146     filter_internal<uint8_t, 2>(golden, &filtered, is_ignored);
147     return filtered;
148 }
149 
150 // Compare results
151 #define VECTOR_TYPE(x) \
152     typename std::tuple_element<x, MixedTyped>::type::mapped_type
153 #define VALUE_TYPE(x) VECTOR_TYPE(x)::value_type
154 template <size_t tuple_index>
155 void compare_(
156         const MixedTyped& golden, const MixedTyped& test,
157         std::function<void(VALUE_TYPE(tuple_index), VALUE_TYPE(tuple_index))>
158                 cmp) {
159     for_each<VALUE_TYPE(tuple_index)>(
160             golden,
161             [&test, &cmp](int index, const VECTOR_TYPE(tuple_index) & m) {
162                 const auto& test_operands = std::get<tuple_index>(test);
163                 const auto& test_ty = test_operands.find(index);
164                 ASSERT_NE(test_ty, test_operands.end());
165                 for (unsigned int i = 0; i < m.size(); i++) {
166                     SCOPED_TRACE(testing::Message()
167                                  << "When comparing element " << i);
168                     cmp(m[i], test_ty->second[i]);
169                 }
170             });
171 }
172 #undef VALUE_TYPE
173 #undef VECTOR_TYPE
174 inline void compare(const MixedTyped& golden, const MixedTyped& test) {
175     compare_<0>(golden, test,
176                 [](float g, float t) { EXPECT_NEAR(g, t, 1.e-5f); });
177     compare_<1>(golden, test, [](int32_t g, int32_t t) { EXPECT_EQ(g, t); });
178     compare_<2>(golden, test, [](uint8_t g, uint8_t t) { EXPECT_NEAR(g, t, 1); });
179 }
180 
181 };  // namespace generated_tests
182 
183 #endif  // ANDROID_ML_NN_TOOLS_TEST_GENERATOR_TEST_HARNESS_H
184