1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <sys/wait.h>

#include <cctype>
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <map>
#include <sstream>
#include <string>

#include <gtest/gtest.h>
#include "re2/re2.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/subprocess.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/lite/testing/parse_testdata.h"
#include "tensorflow/lite/testing/tflite_driver.h"
#include "tensorflow/lite/testing/util.h"
30
31 namespace tflite {
32 namespace testing {
33
namespace {
// Storage for the command line flags registered in main(). The string-valued
// flags are heap-allocated and passed to the flag parser as pointers.

// When set, a test matching a known-bug pattern is expected to fail.
bool FLAGS_ignore_known_bugs = true;

// NNAPI delegate switches.
bool FLAGS_use_nnapi = false;
bool FLAGS_ignore_unsupported_nnapi = false;

// Archive file names are test-specific, so there is no sensible default.
//
// Both zip and tar inputs are supported because a stock Android image ships
// tar but not unzip.
string* FLAGS_zip_file_path = new string();
string* FLAGS_tar_file_path = new string();

// Default locations of the extraction tools on the current platform.
#ifdef __ANDROID__
string* FLAGS_unzip_binary_path = new string("/system/bin/unzip");
string* FLAGS_tar_binary_path = new string("/system/bin/tar");
#else
string* FLAGS_unzip_binary_path = new string("/usr/bin/unzip");
string* FLAGS_tar_binary_path = new string("/bin/tar");
#endif
}  // namespace
52
53 // TensorFlow system environment for file system called.
54 tensorflow::Env* env = tensorflow::Env::Default();
55
// Tests that are known to be broken.
//
// Each key is an RE2 regular expression that is partially matched against the
// test name (the trailing "/<name>" component of the test path); each value is
// the number of the bug that tracks the failure.
//
// With --ignore_known_bugs=true (the default) a matching test is expected to
// fail, and it is reported if it unexpectedly passes. With
// --test_arg=--ignore_known_bugs=false a matching test is expected to pass, so
// the known failures show up as test failures.
// TODO(ahentz): make sure we clean this list up frequently.
const std::map<string, string> kBrokenTests = {
    // L2Norm only supports tensors with 4D or fewer.
    {R"(^\/l2norm_dim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"},

    // SpaceToBatchND only supports 4D tensors.
    {R"(^\/space_to_batch_nd.*input_shape=\[1,4,4,4,1,1\])", "70848787"},

    // L2Norm only works for dim=-1.
    {R"(^\/l2norm_dim=-2,epsilon=.*,input_shape=\[.,.\])", "67963812"},
    {R"(^\/l2norm_dim=0,epsilon=.*,input_shape=\[.,.\])", "67963812"},
    {R"(^\/l2norm_dim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
    {R"(^\/l2norm_dim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
    {R"(^\/l2norm_dim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
    {R"(^\/l2norm_dim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
    {R"(^\/l2norm_dim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
    {R"(^\/l2norm_dim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
    {R"(^\/l2norm_dim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
    {R"(^\/l2norm_dim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
    {R"(^\/l2norm_dim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])",
     "67963812"},
    {R"(^\/l2norm_dim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},

    // ResizeBilinear looks completely incompatible with Tensorflow
    {R"(^\/resize_bilinear.*dtype=tf.int32)", "72401107"},

    // Transpose only supports 1D-4D input tensors.
    {R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"},

    // No Support for float.
    {R"(^\/floor_div.*dtype=tf\.float32)", "112859002"},

    // Relu does not support int32.
    // These test cases append a Relu after the tested ops when
    // activation=True. The tests are failing since Relu doesn't support int32.
    {R"(^\/div.*activation=True.*dtype=tf\.int32)", "112968789"},
    {R"(^\/floor_div.*activation=True.*dtype=tf\.int32)", "112968789"},
    {R"(^\/floor_mod.*activation=True.*dtype=tf\.int32)", "112968789"},
    {R"(^\/floor_mod.*activation=True.*dtype=tf\.int64)", "112968789"},

    {R"(^\/sub.*dtype=tf\.int64)", "119126484"},
    {R"(^\/div.*dtype=tf\.int64)", "119126484"},
    {R"(^\/mul.*dtype=tf\.int64)", "119126484"},
    {R"(^\/add.*dtype=tf\.int64)", "119126484"},
    {R"(^\/floor_div.*dtype=tf\.int64)", "119126484"},
    {R"(^\/squared_difference.*dtype=tf\.int64)", "119126484"},
};
106
// Additional known-broken tests that apply only when the NNAPI delegate is
// enabled (--test_arg=--use_nnapi=true); they are merged into the general
// known-bug list in that case and follow the same --ignore_known_bugs rules.
//
// Note that issues related to lack of NNAPI support for a particular op are
// handled separately; this list is specifically for broken cases where
// execution produces broken output.
//
// Each key is an RE2 regular expression that is partially matched against the
// test name; each value is a bug number.
const std::map<string, string> kBrokenNnapiTests = {
    // Certain NNAPI kernels silently fail with int32 types.
    {R"(^\/add.*dtype=tf\.int32)", "122987564"},
    {R"(^\/concat.*dtype=tf\.int32)", "122987564"},
    {R"(^\/mul.*dtype=tf\.int32)", "122987564"},
    {R"(^\/space_to_depth.*dtype=tf\.int32)", "122987564"},

    // Certain NNAPI fully_connected shape permutations fail.
    {R"(^\/fully_connected_constant_filter=True.*shape1=\[3,3\])", "122987564"},
    {R"(^\/fully_connected_constant_filter=True.*shape1=\[4,4\])", "122987564"},
    {R"(^\/fully_connected.*shape1=\[3,3\].*transpose_b=True)", "122987564"},
    {R"(^\/fully_connected.*shape1=\[4,4\].*shape2=\[4,1\])", "122987564"},
};
128
129 // Allows test data to be unarchived into a temporary directory and makes
130 // sure those temporary directories are removed later.
131 class ArchiveEnvironment : public ::testing::Environment {
132 public:
~ArchiveEnvironment()133 ~ArchiveEnvironment() override {}
134
135 // Delete all temporary directories on teardown.
TearDown()136 void TearDown() override {
137 for (const auto& dir : temporary_directories_) {
138 tensorflow::int64 undeleted_dirs, undeleted_files;
139 TF_CHECK_OK(
140 env->DeleteRecursively(dir, &undeleted_dirs, &undeleted_files));
141 }
142 temporary_directories_.clear();
143 }
144
145 // Unarchive `archive` file into a new temporary directory `out_dir`.
UnArchive(const string & zip,const string & tar,string * out_dir)146 tensorflow::Status UnArchive(const string& zip, const string& tar,
147 string* out_dir) {
148 string dir;
149 TF_CHECK_OK(MakeTemporaryDirectory(&dir));
150 tensorflow::SubProcess proc;
151 if (!zip.empty()) {
152 string unzip_binary = *FLAGS_unzip_binary_path;
153 TF_CHECK_OK(env->FileExists(unzip_binary));
154 TF_CHECK_OK(env->FileExists(zip));
155 proc.SetProgram(unzip_binary, {"unzip", "-d", dir, zip});
156 } else {
157 string tar_binary = *FLAGS_tar_binary_path;
158 TF_CHECK_OK(env->FileExists(tar_binary));
159 TF_CHECK_OK(env->FileExists(tar));
160 // 'o' needs to be explicitly set on Android so that
161 // untarring works as non-root (otherwise tries to chown
162 // files, which fails)
163 proc.SetProgram(tar_binary, {"tar", "xfo", tar, "-C", dir});
164 }
165 proc.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE);
166 proc.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE);
167 if (!proc.Start())
168 return tensorflow::Status(tensorflow::error::UNKNOWN,
169 "unzip couldn't start");
170 string out, err;
171 int status = proc.Communicate(nullptr, &out, &err);
172 if (WEXITSTATUS(status) == 0) {
173 *out_dir = dir;
174 return tensorflow::Status::OK();
175 } else {
176 return tensorflow::Status(tensorflow::error::UNKNOWN,
177 "unzip failed. "
178 "stdout:\n" +
179 out + "\nstderr:\n" + err);
180 }
181 }
182
183 private:
184 // Make a temporary directory and return its name in `temporary`.
MakeTemporaryDirectory(string * temporary)185 tensorflow::Status MakeTemporaryDirectory(string* temporary) {
186 if (env->LocalTempFilename(temporary)) {
187 TF_CHECK_OK(env->CreateDir(*temporary));
188 temporary_directories_.push_back(*temporary);
189 return tensorflow::Status::OK();
190 }
191 return tensorflow::Status(tensorflow::error::UNKNOWN,
192 "make temporary directory failed");
193 }
194
195 std::vector<string> temporary_directories_;
196 };
197
198 // Return the singleton archive_environment.
archive_environment()199 ArchiveEnvironment* archive_environment() {
200 static ArchiveEnvironment* env = new ArchiveEnvironment;
201 return env;
202 }
203
204 // Read the manifest.txt out of the unarchived archive file. Specifically
205 // `original_file` is the original zip file for error messages. `dir` is
206 // the temporary directory where the archive file has been unarchived and
207 // `test_paths` is the list of test prefixes that were in the manifest.
208 // Note, it is an error for a manifest to contain no tests.
ReadManifest(const string & original_file,const string & dir,std::vector<string> * test_paths)209 tensorflow::Status ReadManifest(const string& original_file, const string& dir,
210 std::vector<string>* test_paths) {
211 // Read the newline delimited list of entries in the manifest.
212 std::ifstream manifest_fp(dir + "/manifest.txt");
213 string manifest((std::istreambuf_iterator<char>(manifest_fp)),
214 std::istreambuf_iterator<char>());
215 size_t pos = 0;
216 int added = 0;
217 while (true) {
218 size_t end_pos = manifest.find("\n", pos);
219 if (end_pos == string::npos) break;
220 string filename = manifest.substr(pos, end_pos - pos);
221 test_paths->push_back(dir + "/" + filename);
222 pos = end_pos + 1;
223 added += 1;
224 }
225 if (!added) {
226 string message = "Test had no examples: " + original_file;
227 return tensorflow::Status(tensorflow::error::UNKNOWN, message);
228 }
229 return tensorflow::Status::OK();
230 }
231
232 // Get a list of tests from either zip or tar file
UnarchiveAndFindTestNames(const string & zip_file,const string & tar_file)233 std::vector<string> UnarchiveAndFindTestNames(const string& zip_file,
234 const string& tar_file) {
235 if (zip_file.empty() && tar_file.empty()) {
236 TF_CHECK_OK(tensorflow::Status(tensorflow::error::UNKNOWN,
237 "Neither zip_file nor tar_file was given"));
238 }
239 string decompress_tmp_dir;
240 TF_CHECK_OK(archive_environment()->UnArchive(zip_file, tar_file,
241 &decompress_tmp_dir));
242 std::vector<string> stuff;
243 if (!zip_file.empty()) {
244 TF_CHECK_OK(ReadManifest(zip_file, decompress_tmp_dir, &stuff));
245 } else {
246 TF_CHECK_OK(ReadManifest(tar_file, decompress_tmp_dir, &stuff));
247 }
248 return stuff;
249 }
250
251 class OpsTest : public ::testing::TestWithParam<string> {};
252
TEST_P(OpsTest,RunZipTests)253 TEST_P(OpsTest, RunZipTests) {
254 string test_path = GetParam();
255 string tflite_test_case = test_path + "_tests.txt";
256 string tflite_dir = test_path.substr(0, test_path.find_last_of("/"));
257 string test_name = test_path.substr(test_path.find_last_of('/'));
258
259 std::ifstream tflite_stream(tflite_test_case);
260 ASSERT_TRUE(tflite_stream.is_open()) << tflite_test_case;
261 tflite::testing::TfLiteDriver test_driver(FLAGS_use_nnapi);
262 test_driver.SetModelBaseDir(tflite_dir);
263
264 auto broken_tests = kBrokenTests;
265 if (FLAGS_use_nnapi) {
266 broken_tests.insert(kBrokenNnapiTests.begin(), kBrokenNnapiTests.end());
267 }
268
269 string bug_number;
270 for (const auto& p : broken_tests) {
271 if (RE2::PartialMatch(test_name, p.first)) {
272 bug_number = p.second;
273 }
274 }
275
276 bool result = tflite::testing::ParseAndRunTests(&tflite_stream, &test_driver);
277 string message = test_driver.GetErrorMessage();
278 if (bug_number.empty()) {
279 if (FLAGS_use_nnapi && FLAGS_ignore_unsupported_nnapi && !result) {
280 EXPECT_EQ(message, string("Failed to invoke interpreter")) << message;
281 } else {
282 EXPECT_TRUE(result) << message;
283 }
284 } else {
285 if (FLAGS_ignore_known_bugs) {
286 EXPECT_FALSE(result) << "Test was expected to fail but is now passing; "
287 "you can mark http://b/"
288 << bug_number << " as fixed! Yay!";
289 } else {
290 EXPECT_TRUE(result) << message << ": Possibly due to http://b/"
291 << bug_number;
292 }
293 }
294 }
295
296 struct ZipPathParamName {
297 template <class ParamType>
operator ()tflite::testing::ZipPathParamName298 string operator()(const ::testing::TestParamInfo<ParamType>& info) const {
299 string param_name = info.param;
300 size_t last_slash = param_name.find_last_of("\\/");
301 if (last_slash != string::npos) {
302 param_name = param_name.substr(last_slash);
303 }
304 for (size_t index = 0; index < param_name.size(); ++index) {
305 if (!isalnum(param_name[index]) && param_name[index] != '_')
306 param_name[index] = '_';
307 }
308 return param_name;
309 }
310 };
311
312 INSTANTIATE_TEST_CASE_P(tests, OpsTest,
313 ::testing::ValuesIn(UnarchiveAndFindTestNames(
314 *FLAGS_zip_file_path, *FLAGS_tar_file_path)),
315 ZipPathParamName());
316
317 } // namespace testing
318 } // namespace tflite
319
main(int argc,char ** argv)320 int main(int argc, char** argv) {
321 ::testing::AddGlobalTestEnvironment(tflite::testing::archive_environment());
322
323 std::vector<tensorflow::Flag> flags = {
324 tensorflow::Flag(
325 "ignore_known_bugs", &tflite::testing::FLAGS_ignore_known_bugs,
326 "If a particular model is affected by a known bug, the "
327 "corresponding test should expect the outputs to not match."),
328 tensorflow::Flag(
329 "tar_file_path", tflite::testing::FLAGS_tar_file_path,
330 "Required (or zip_file_path): Location of the test tar file."),
331 tensorflow::Flag(
332 "zip_file_path", tflite::testing::FLAGS_zip_file_path,
333 "Required (or tar_file_path): Location of the test zip file."),
334 tensorflow::Flag("unzip_binary_path",
335 tflite::testing::FLAGS_unzip_binary_path,
336 "Location of a suitable unzip binary."),
337 tensorflow::Flag("tar_binary_path",
338 tflite::testing::FLAGS_tar_binary_path,
339 "Location of a suitable tar binary."),
340 tensorflow::Flag("use_nnapi", &tflite::testing::FLAGS_use_nnapi,
341 "Whether to enable the NNAPI delegate"),
342 tensorflow::Flag("ignore_unsupported_nnapi",
343 &tflite::testing::FLAGS_ignore_unsupported_nnapi,
344 "Don't fail tests just because delegation to NNAPI "
345 "is not possible")};
346 bool success = tensorflow::Flags::Parse(&argc, argv, flags);
347 if (!success || (argc == 2 && !strcmp(argv[1], "--helpfull"))) {
348 fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
349 return 1;
350 }
351
352 ::tflite::LogToStderr();
353 // TODO(mikie): googletest arguments do not work - maybe the tensorflow flags
354 // parser removes them?
355 ::testing::InitGoogleTest(&argc, argv);
356 return RUN_ALL_TESTS();
357 }
358