• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Top level driver for models and examples generated by test_generator.py
18 
19 #include "NeuralNetworksWrapper.h"
20 #include "TestHarness.h"
21 
22 #include <gtest/gtest.h>
23 #include <cassert>
24 #include <cmath>
25 #include <iostream>
26 #include <map>
27 
28 namespace generated_tests {
29 using namespace android::nn::wrapper;
30 
31 template <typename T>
32 class Example {
33    public:
34     typedef T ElementType;
35     typedef std::pair<std::map<int, std::vector<T>>,
36                       std::map<int, std::vector<T>>>
37         ExampleType;
38 
Execute(std::function<void (Model *)> create_model,std::vector<ExampleType> & examples,std::function<bool (const T,const T)> compare)39     static bool Execute(std::function<void(Model*)> create_model,
40                         std::vector<ExampleType>& examples,
41                         std::function<bool(const T, const T)> compare) {
42         Model model;
43         create_model(&model);
44         model.finish();
45 
46         int example_no = 1;
47         bool error = false;
48         for (auto& example : examples) {
49             Compilation compilation(&model);
50             compilation.finish();
51             Execution execution(&compilation);
52 
53             // Go through all inputs
54             for (auto& i : example.first) {
55                 std::vector<T>& input = i.second;
56                 // We interpret an empty vector as an optional argument
57                 // that has been omitted.
58                 if (input.size() == 0) {
59                     execution.setInput(i.first, nullptr, 0);
60                 } else {
61                     execution.setInput(i.first, (const void*)input.data(),
62                                        input.size() * sizeof(T));
63                 }
64             }
65 
66             std::map<int, std::vector<T>> test_outputs;
67 
68             assert(example.second.size() == 1);
69             int output_no = 0;
70             for (auto& i : example.second) {
71                 std::vector<T>& output = i.second;
72                 test_outputs[i.first].resize(output.size());
73                 std::vector<T>& test_output = test_outputs[i.first];
74                 execution.setOutput(output_no++, (void*)test_output.data(),
75                                     test_output.size() * sizeof(T));
76             }
77             Result r = execution.compute();
78             if (r != Result::NO_ERROR)
79                 std::cerr << "Execution was not completed normally\n";
80             bool mismatch = false;
81             for (auto& i : example.second) {
82                 const std::vector<T>& test = test_outputs[i.first];
83                 const std::vector<T>& golden = i.second;
84                 for (unsigned i = 0; i < golden.size(); i++) {
85                     if (compare(golden[i], test[i])) {
86                         std::cerr << " output[" << i << "] = " << (float)test[i]
87                                   << " (should be " << (float)golden[i]
88                                   << ")\n";
89                         error = error || true;
90                         mismatch = mismatch || true;
91                     }
92                 }
93             }
94             if (mismatch) {
95                 std::cerr << "Example: " << example_no++;
96                 std::cerr << " failed\n";
97             }
98         }
99         return error;
100     }
101 
102     // Test driver for those generated from ml/nn/runtime/test/spec
Execute(std::function<void (Model *)> create_model,std::function<bool (int)> is_ignored,std::vector<MixedTypedExampleType> & examples)103     static void Execute(std::function<void(Model*)> create_model,
104                         std::function<bool(int)> is_ignored,
105                         std::vector<MixedTypedExampleType>& examples) {
106         Model model;
107         create_model(&model);
108         model.finish();
109 
110         int example_no = 1;
111         for (auto& example : examples) {
112             SCOPED_TRACE(example_no++);
113             MixedTyped inputs = example.first;
114             const MixedTyped& golden = example.second;
115 
116             Compilation compilation(&model);
117             compilation.finish();
118             Execution execution(&compilation);
119 
120             // Set all inputs
121             for_all(inputs, [&execution](int idx, const void* p, size_t s) {
122                 ASSERT_EQ(Result::NO_ERROR, execution.setInput(idx, p, s));
123             });
124 
125             MixedTyped test;
126             // Go through all typed outputs
127             resize_accordingly(golden, test);
128             for_all(test, [&execution](int idx, void* p, size_t s) {
129                 ASSERT_EQ(Result::NO_ERROR, execution.setOutput(idx, p, s));
130             });
131 
132             Result r = execution.compute();
133             ASSERT_EQ(Result::NO_ERROR, r);
134             // Filter out don't cares
135             MixedTyped filtered_golden = filter(golden, is_ignored);
136             MixedTyped filtered_test = filter(test, is_ignored);
137             // We want "close-enough" results for float
138             compare(filtered_golden, filtered_test);
139         }
140     }
141 };
142 };  // namespace generated_tests
143 
144 using namespace android::nn::wrapper;
145 // Float32 examples
146 typedef generated_tests::Example<float>::ExampleType Example;
147 // Mixed-typed examples
148 typedef generated_tests::MixedTypedExampleType MixedTypedExample;
149 
Execute(std::function<void (Model *)> create_model,std::function<bool (int)> is_ignored,std::vector<MixedTypedExample> & examples)150 void Execute(std::function<void(Model*)> create_model,
151              std::function<bool(int)> is_ignored,
152              std::vector<MixedTypedExample>& examples) {
153     generated_tests::Example<float>::Execute(create_model, is_ignored,
154                                              examples);
155 }
156 
157 class GeneratedTests : public ::testing::Test {
158    protected:
SetUp()159     virtual void SetUp() {}
160 };
161 
162 // Testcases generated from runtime/test/specs/*.mod.py
163 #include "generated/all_generated_tests.cpp"
164 // End of testcases generated from runtime/test/specs/*.mod.py
165 
// Below are test cases generated from TFLite test cases.
167 namespace conv_1_h3_w2_SAME {
168 std::vector<Example> examples = {
169 // Converted examples
170 #include "generated/examples/conv_1_h3_w2_SAME_tests.example.cc"
171 };
172 // Generated model constructor
173 #include "generated/models/conv_1_h3_w2_SAME.model.cpp"
174 }  // namespace conv_1_h3_w2_SAME
175 
176 namespace conv_1_h3_w2_VALID {
177 std::vector<Example> examples = {
178 // Converted examples
179 #include "generated/examples/conv_1_h3_w2_VALID_tests.example.cc"
180 };
181 // Generated model constructor
182 #include "generated/models/conv_1_h3_w2_VALID.model.cpp"
183 }  // namespace conv_1_h3_w2_VALID
184 
185 namespace conv_3_h3_w2_SAME {
186 std::vector<Example> examples = {
187 // Converted examples
188 #include "generated/examples/conv_3_h3_w2_SAME_tests.example.cc"
189 };
190 // Generated model constructor
191 #include "generated/models/conv_3_h3_w2_SAME.model.cpp"
192 }  // namespace conv_3_h3_w2_SAME
193 
194 namespace conv_3_h3_w2_VALID {
195 std::vector<Example> examples = {
196 // Converted examples
197 #include "generated/examples/conv_3_h3_w2_VALID_tests.example.cc"
198 };
199 // Generated model constructor
200 #include "generated/models/conv_3_h3_w2_VALID.model.cpp"
201 }  // namespace conv_3_h3_w2_VALID
202 
203 namespace depthwise_conv {
204 std::vector<Example> examples = {
205 // Converted examples
206 #include "generated/examples/depthwise_conv_tests.example.cc"
207 };
208 // Generated model constructor
209 #include "generated/models/depthwise_conv.model.cpp"
210 }  // namespace depthwise_conv
211 
212 namespace mobilenet {
213 std::vector<Example> examples = {
214 // Converted examples
215 #include "generated/examples/mobilenet_224_gender_basic_fixed_tests.example.cc"
216 };
217 // Generated model constructor
218 #include "generated/models/mobilenet_224_gender_basic_fixed.model.cpp"
219 }  // namespace mobilenet
220 
221 namespace {
Execute(std::function<void (Model *)> create_model,std::vector<Example> & examples)222 bool Execute(std::function<void(Model*)> create_model,
223              std::vector<Example>& examples) {
224     return generated_tests::Example<float>::Execute(
225         create_model, examples, [](float golden, float test) {
226             return std::fabs(golden - test) > 1.5e-5f;
227         });
228 }
229 }  // namespace
230 
// Execute() returns true when any example fails, so the test expects false.
TEST_F(GeneratedTests, conv_1_h3_w2_SAME) {
    ASSERT_FALSE(
        Execute(conv_1_h3_w2_SAME::CreateModel, conv_1_h3_w2_SAME::examples));
}
236 
// Execute() returns true when any example fails, so the test expects false.
TEST_F(GeneratedTests, conv_1_h3_w2_VALID) {
    ASSERT_FALSE(
        Execute(conv_1_h3_w2_VALID::CreateModel, conv_1_h3_w2_VALID::examples));
}
242 
// Execute() returns true when any example fails, so the test expects false.
TEST_F(GeneratedTests, conv_3_h3_w2_SAME) {
    ASSERT_FALSE(
        Execute(conv_3_h3_w2_SAME::CreateModel, conv_3_h3_w2_SAME::examples));
}
248 
// Execute() returns true when any example fails, so the test expects false.
TEST_F(GeneratedTests, conv_3_h3_w2_VALID) {
    ASSERT_FALSE(
        Execute(conv_3_h3_w2_VALID::CreateModel, conv_3_h3_w2_VALID::examples));
}
254 
// Execute() returns true when any example fails, so the test expects false.
TEST_F(GeneratedTests, depthwise_conv) {
    ASSERT_FALSE(Execute(depthwise_conv::CreateModel, depthwise_conv::examples));
}
259 
// Execute() returns true when any example fails, so the test expects false.
TEST_F(GeneratedTests, mobilenet) {
    ASSERT_FALSE(Execute(mobilenet::CreateModel, mobilenet::examples));
}
263