1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 // Top level driver for models and examples generated by test_generator.py
18
19 #include "NeuralNetworksWrapper.h"
20 #include "TestHarness.h"
21
22 #include <gtest/gtest.h>
23 #include <cassert>
24 #include <cmath>
25 #include <iostream>
26 #include <map>
27
28 namespace generated_tests {
29 using namespace android::nn::wrapper;
30
31 template <typename T>
32 class Example {
33 public:
34 typedef T ElementType;
35 typedef std::pair<std::map<int, std::vector<T>>,
36 std::map<int, std::vector<T>>>
37 ExampleType;
38
Execute(std::function<void (Model *)> create_model,std::vector<ExampleType> & examples,std::function<bool (const T,const T)> compare)39 static bool Execute(std::function<void(Model*)> create_model,
40 std::vector<ExampleType>& examples,
41 std::function<bool(const T, const T)> compare) {
42 Model model;
43 create_model(&model);
44 model.finish();
45
46 int example_no = 1;
47 bool error = false;
48 for (auto& example : examples) {
49 Compilation compilation(&model);
50 compilation.finish();
51 Execution execution(&compilation);
52
53 // Go through all inputs
54 for (auto& i : example.first) {
55 std::vector<T>& input = i.second;
56 // We interpret an empty vector as an optional argument
57 // that has been omitted.
58 if (input.size() == 0) {
59 execution.setInput(i.first, nullptr, 0);
60 } else {
61 execution.setInput(i.first, (const void*)input.data(),
62 input.size() * sizeof(T));
63 }
64 }
65
66 std::map<int, std::vector<T>> test_outputs;
67
68 assert(example.second.size() == 1);
69 int output_no = 0;
70 for (auto& i : example.second) {
71 std::vector<T>& output = i.second;
72 test_outputs[i.first].resize(output.size());
73 std::vector<T>& test_output = test_outputs[i.first];
74 execution.setOutput(output_no++, (void*)test_output.data(),
75 test_output.size() * sizeof(T));
76 }
77 Result r = execution.compute();
78 if (r != Result::NO_ERROR)
79 std::cerr << "Execution was not completed normally\n";
80 bool mismatch = false;
81 for (auto& i : example.second) {
82 const std::vector<T>& test = test_outputs[i.first];
83 const std::vector<T>& golden = i.second;
84 for (unsigned i = 0; i < golden.size(); i++) {
85 if (compare(golden[i], test[i])) {
86 std::cerr << " output[" << i << "] = " << (float)test[i]
87 << " (should be " << (float)golden[i]
88 << ")\n";
89 error = error || true;
90 mismatch = mismatch || true;
91 }
92 }
93 }
94 if (mismatch) {
95 std::cerr << "Example: " << example_no++;
96 std::cerr << " failed\n";
97 }
98 }
99 return error;
100 }
101
102 // Test driver for those generated from ml/nn/runtime/test/spec
Execute(std::function<void (Model *)> create_model,std::function<bool (int)> is_ignored,std::vector<MixedTypedExampleType> & examples)103 static void Execute(std::function<void(Model*)> create_model,
104 std::function<bool(int)> is_ignored,
105 std::vector<MixedTypedExampleType>& examples) {
106 Model model;
107 create_model(&model);
108 model.finish();
109
110 int example_no = 1;
111 for (auto& example : examples) {
112 SCOPED_TRACE(example_no++);
113 MixedTyped inputs = example.first;
114 const MixedTyped& golden = example.second;
115
116 Compilation compilation(&model);
117 compilation.finish();
118 Execution execution(&compilation);
119
120 // Set all inputs
121 for_all(inputs, [&execution](int idx, const void* p, size_t s) {
122 ASSERT_EQ(Result::NO_ERROR, execution.setInput(idx, p, s));
123 });
124
125 MixedTyped test;
126 // Go through all typed outputs
127 resize_accordingly(golden, test);
128 for_all(test, [&execution](int idx, void* p, size_t s) {
129 ASSERT_EQ(Result::NO_ERROR, execution.setOutput(idx, p, s));
130 });
131
132 Result r = execution.compute();
133 ASSERT_EQ(Result::NO_ERROR, r);
134 // Filter out don't cares
135 MixedTyped filtered_golden = filter(golden, is_ignored);
136 MixedTyped filtered_test = filter(test, is_ignored);
137 // We want "close-enough" results for float
138 compare(filtered_golden, filtered_test);
139 }
140 }
141 };
142 }; // namespace generated_tests
143
144 using namespace android::nn::wrapper;
145 // Float32 examples
146 typedef generated_tests::Example<float>::ExampleType Example;
147 // Mixed-typed examples
148 typedef generated_tests::MixedTypedExampleType MixedTypedExample;
149
Execute(std::function<void (Model *)> create_model,std::function<bool (int)> is_ignored,std::vector<MixedTypedExample> & examples)150 void Execute(std::function<void(Model*)> create_model,
151 std::function<bool(int)> is_ignored,
152 std::vector<MixedTypedExample>& examples) {
153 generated_tests::Example<float>::Execute(create_model, is_ignored,
154 examples);
155 }
156
157 class GeneratedTests : public ::testing::Test {
158 protected:
SetUp()159 virtual void SetUp() {}
160 };
161
162 // Testcases generated from runtime/test/specs/*.mod.py
163 #include "generated/all_generated_tests.cpp"
164 // End of testcases generated from runtime/test/specs/*.mod.py
165
// Below are testcases generated from TFLite testcases.
167 namespace conv_1_h3_w2_SAME {
168 std::vector<Example> examples = {
169 // Converted examples
170 #include "generated/examples/conv_1_h3_w2_SAME_tests.example.cc"
171 };
172 // Generated model constructor
173 #include "generated/models/conv_1_h3_w2_SAME.model.cpp"
174 } // namespace conv_1_h3_w2_SAME
175
176 namespace conv_1_h3_w2_VALID {
177 std::vector<Example> examples = {
178 // Converted examples
179 #include "generated/examples/conv_1_h3_w2_VALID_tests.example.cc"
180 };
181 // Generated model constructor
182 #include "generated/models/conv_1_h3_w2_VALID.model.cpp"
183 } // namespace conv_1_h3_w2_VALID
184
185 namespace conv_3_h3_w2_SAME {
186 std::vector<Example> examples = {
187 // Converted examples
188 #include "generated/examples/conv_3_h3_w2_SAME_tests.example.cc"
189 };
190 // Generated model constructor
191 #include "generated/models/conv_3_h3_w2_SAME.model.cpp"
192 } // namespace conv_3_h3_w2_SAME
193
194 namespace conv_3_h3_w2_VALID {
195 std::vector<Example> examples = {
196 // Converted examples
197 #include "generated/examples/conv_3_h3_w2_VALID_tests.example.cc"
198 };
199 // Generated model constructor
200 #include "generated/models/conv_3_h3_w2_VALID.model.cpp"
201 } // namespace conv_3_h3_w2_VALID
202
203 namespace depthwise_conv {
204 std::vector<Example> examples = {
205 // Converted examples
206 #include "generated/examples/depthwise_conv_tests.example.cc"
207 };
208 // Generated model constructor
209 #include "generated/models/depthwise_conv.model.cpp"
210 } // namespace depthwise_conv
211
212 namespace mobilenet {
213 std::vector<Example> examples = {
214 // Converted examples
215 #include "generated/examples/mobilenet_224_gender_basic_fixed_tests.example.cc"
216 };
217 // Generated model constructor
218 #include "generated/models/mobilenet_224_gender_basic_fixed.model.cpp"
219 } // namespace mobilenet
220
221 namespace {
Execute(std::function<void (Model *)> create_model,std::vector<Example> & examples)222 bool Execute(std::function<void(Model*)> create_model,
223 std::vector<Example>& examples) {
224 return generated_tests::Example<float>::Execute(
225 create_model, examples, [](float golden, float test) {
226 return std::fabs(golden - test) > 1.5e-5f;
227 });
228 }
229 } // namespace
230
// Execute() returns true when any example failed, so the test passes on false.
TEST_F(GeneratedTests, conv_1_h3_w2_SAME) {
    ASSERT_FALSE(
            Execute(conv_1_h3_w2_SAME::CreateModel, conv_1_h3_w2_SAME::examples));
}
236
// Execute() returns true when any example failed, so the test passes on false.
TEST_F(GeneratedTests, conv_1_h3_w2_VALID) {
    ASSERT_FALSE(
            Execute(conv_1_h3_w2_VALID::CreateModel, conv_1_h3_w2_VALID::examples));
}
242
// Execute() returns true when any example failed, so the test passes on false.
TEST_F(GeneratedTests, conv_3_h3_w2_SAME) {
    ASSERT_FALSE(
            Execute(conv_3_h3_w2_SAME::CreateModel, conv_3_h3_w2_SAME::examples));
}
248
// Execute() returns true when any example failed, so the test passes on false.
TEST_F(GeneratedTests, conv_3_h3_w2_VALID) {
    ASSERT_FALSE(
            Execute(conv_3_h3_w2_VALID::CreateModel, conv_3_h3_w2_VALID::examples));
}
254
// Execute() returns true when any example failed, so the test passes on false.
TEST_F(GeneratedTests, depthwise_conv) {
    ASSERT_FALSE(Execute(depthwise_conv::CreateModel, depthwise_conv::examples));
}
259
// Execute() returns true when any example failed, so the test passes on false.
TEST_F(GeneratedTests, mobilenet) {
    ASSERT_FALSE(Execute(mobilenet::CreateModel, mobilenet::examples));
}
263