• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <cstdarg>
16 #include <cstdio>
17 #include <cstdlib>
18 #include <fstream>
19 #include <map>
20 #include <sstream>
21 #include <gtest/gtest.h>
22 #include "re2/re2.h"
23 #include "tensorflow/lite/testing/parse_testdata.h"
24 #include "tensorflow/lite/testing/tflite_driver.h"
25 #include "tensorflow/lite/testing/util.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/platform/subprocess.h"
29 #include "tensorflow/core/util/command_line_flags.h"
30 
31 namespace tflite {
32 namespace testing {
33 
34 namespace {
35 bool FLAGS_ignore_known_bugs = true;
36 // As archive file names are test-specific, no default is possible.
37 //
38 // This test supports input as both zip and tar, as a stock android image does
39 // not have unzip but does have tar.
40 string* FLAGS_zip_file_path = new string;
41 string* FLAGS_tar_file_path = new string;
42 #ifndef __ANDROID__
43 string* FLAGS_unzip_binary_path = new string("/usr/bin/unzip");
44 string* FLAGS_tar_binary_path = new string("/bin/tar");
45 #else
46 string* FLAGS_unzip_binary_path = new string("/system/bin/unzip");
47 string* FLAGS_tar_binary_path = new string("/system/bin/tar");
48 #endif
49 bool FLAGS_use_nnapi = false;
50 bool FLAGS_ignore_unsupported_nnapi = false;
51 }  // namespace
52 
53 // TensorFlow system environment for file system called.
54 tensorflow::Env* env = tensorflow::Env::Default();
55 
56 // List of tests that are expected to fail when
57 //   --test_arg=--ignore_known_bugs=false
58 // Key is a substring of the test name and value is a bug number.
59 // TODO(ahentz): make sure we clean this list up frequently.
60 std::map<string, string> kBrokenTests = {
61     // L2Norm only supports tensors with 4D or fewer.
62     {R"(^\/l2norm_dim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"},
63 
64     // SpaceToBatchND only supports 4D tensors.
65     {R"(^\/space_to_batch_nd.*input_shape=\[1,4,4,4,1,1\])", "70848787"},
66 
67     // L2Norm only works for dim=-1.
68     {R"(^\/l2norm_dim=-2,epsilon=.*,input_shape=\[.,.\])", "67963812"},
69     {R"(^\/l2norm_dim=0,epsilon=.*,input_shape=\[.,.\])", "67963812"},
70     {R"(^\/l2norm_dim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
71     {R"(^\/l2norm_dim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
72     {R"(^\/l2norm_dim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
73     {R"(^\/l2norm_dim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
74     {R"(^\/l2norm_dim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
75     {R"(^\/l2norm_dim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
76     {R"(^\/l2norm_dim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
77     {R"(^\/l2norm_dim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
78     {R"(^\/l2norm_dim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])",
79      "67963812"},
80     {R"(^\/l2norm_dim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
81 
82     // ResizeBilinear looks completely incompatible with Tensorflow
83     {R"(^\/resize_bilinear.*dtype=tf.int32)", "72401107"},
84 
85     // Transpose only supports 1D-4D input tensors.
86     {R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"},
87 
88     // No Support for float.
89     {R"(^\/floor_div.*dtype=tf\.float32)", "112859002"},
90 
91     // Relu does not support int32.
92     // These test cases appends a Relu after the tested ops when
93     // activation=True. The tests are failing since Relu doesn't support int32.
94     {R"(^\/div.*activation=True.*dtype=tf\.int32)", "112968789"},
95     {R"(^\/floor_div.*activation=True.*dtype=tf\.int32)", "112968789"},
96     {R"(^\/floor_mod.*activation=True.*dtype=tf\.int32)", "112968789"},
97     {R"(^\/floor_mod.*activation=True.*dtype=tf\.int64)", "112968789"},
98 
99     {R"(^\/sub.*dtype=tf\.int64)", "119126484"},
100     {R"(^\/div.*dtype=tf\.int64)", "119126484"},
101     {R"(^\/mul.*dtype=tf\.int64)", "119126484"},
102     {R"(^\/add.*dtype=tf\.int64)", "119126484"},
103     {R"(^\/floor_div.*dtype=tf\.int64)", "119126484"},
104     {R"(^\/squared_difference.*dtype=tf\.int64)", "119126484"},
105 };
106 
107 // Additional list of tests that are expected to fail when
108 //   --test_arg=--ignore_known_bugs=false
109 // and
110 //   --test_arg=--use_nnapi=true
111 // Note that issues related to lack of NNAPI support for a particular op are
112 // handled separately; this list is specifically for broken cases where
113 // execution produces broken output.
114 // Key is a substring of the test name and value is a bug number.
115 std::map<string, string> kBrokenNnapiTests = {
116     // Certain NNAPI kernels silently fail with int32 types.
117     {R"(^\/add.*dtype=tf\.int32)", "122987564"},
118     {R"(^\/concat.*dtype=tf\.int32)", "122987564"},
119     {R"(^\/mul.*dtype=tf\.int32)", "122987564"},
120     {R"(^\/space_to_depth.*dtype=tf\.int32)", "122987564"},
121 
122     // Certain NNAPI fully_connected shape permutations fail.
123     {R"(^\/fully_connected_constant_filter=True.*shape1=\[3,3\])", "122987564"},
124     {R"(^\/fully_connected_constant_filter=True.*shape1=\[4,4\])", "122987564"},
125     {R"(^\/fully_connected.*shape1=\[3,3\].*transpose_b=True)", "122987564"},
126     {R"(^\/fully_connected.*shape1=\[4,4\].*shape2=\[4,1\])", "122987564"},
127 };
128 
129 // Allows test data to be unarchived into a temporary directory and makes
130 // sure those temporary directories are removed later.
131 class ArchiveEnvironment : public ::testing::Environment {
132  public:
~ArchiveEnvironment()133   ~ArchiveEnvironment() override {}
134 
135   // Delete all temporary directories on teardown.
TearDown()136   void TearDown() override {
137     for (const auto& dir : temporary_directories_) {
138       tensorflow::int64 undeleted_dirs, undeleted_files;
139       TF_CHECK_OK(
140           env->DeleteRecursively(dir, &undeleted_dirs, &undeleted_files));
141     }
142     temporary_directories_.clear();
143   }
144 
145   // Unarchive `archive` file into a new temporary directory  `out_dir`.
UnArchive(const string & zip,const string & tar,string * out_dir)146   tensorflow::Status UnArchive(const string& zip, const string& tar,
147                                string* out_dir) {
148     string dir;
149     TF_CHECK_OK(MakeTemporaryDirectory(&dir));
150     tensorflow::SubProcess proc;
151     if (!zip.empty()) {
152       string unzip_binary = *FLAGS_unzip_binary_path;
153       TF_CHECK_OK(env->FileExists(unzip_binary));
154       TF_CHECK_OK(env->FileExists(zip));
155       proc.SetProgram(unzip_binary, {"unzip", "-d", dir, zip});
156     } else {
157       string tar_binary = *FLAGS_tar_binary_path;
158       TF_CHECK_OK(env->FileExists(tar_binary));
159       TF_CHECK_OK(env->FileExists(tar));
160       // 'o' needs to be explicitly set on Android so that
161       // untarring works as non-root (otherwise tries to chown
162       // files, which fails)
163       proc.SetProgram(tar_binary, {"tar", "xfo", tar, "-C", dir});
164     }
165     proc.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE);
166     proc.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE);
167     if (!proc.Start())
168       return tensorflow::Status(tensorflow::error::UNKNOWN,
169                                 "unzip couldn't start");
170     string out, err;
171     int status = proc.Communicate(nullptr, &out, &err);
172     if (WEXITSTATUS(status) == 0) {
173       *out_dir = dir;
174       return tensorflow::Status::OK();
175     } else {
176       return tensorflow::Status(tensorflow::error::UNKNOWN,
177                                 "unzip failed. "
178                                 "stdout:\n" +
179                                     out + "\nstderr:\n" + err);
180     }
181   }
182 
183  private:
184   // Make a temporary directory and return its name in `temporary`.
MakeTemporaryDirectory(string * temporary)185   tensorflow::Status MakeTemporaryDirectory(string* temporary) {
186     if (env->LocalTempFilename(temporary)) {
187       TF_CHECK_OK(env->CreateDir(*temporary));
188       temporary_directories_.push_back(*temporary);
189       return tensorflow::Status::OK();
190     }
191     return tensorflow::Status(tensorflow::error::UNKNOWN,
192                               "make temporary directory failed");
193   }
194 
195   std::vector<string> temporary_directories_;
196 };
197 
198 // Return the singleton archive_environment.
archive_environment()199 ArchiveEnvironment* archive_environment() {
200   static ArchiveEnvironment* env = new ArchiveEnvironment;
201   return env;
202 }
203 
204 // Read the manifest.txt out of the unarchived archive file. Specifically
205 // `original_file` is the original zip file for error messages. `dir` is
206 // the temporary directory where the archive file has been unarchived and
207 // `test_paths` is the list of test prefixes that were in the manifest.
208 // Note, it is an error for a manifest to contain no tests.
ReadManifest(const string & original_file,const string & dir,std::vector<string> * test_paths)209 tensorflow::Status ReadManifest(const string& original_file, const string& dir,
210                                 std::vector<string>* test_paths) {
211   // Read the newline delimited list of entries in the manifest.
212   std::ifstream manifest_fp(dir + "/manifest.txt");
213   string manifest((std::istreambuf_iterator<char>(manifest_fp)),
214                   std::istreambuf_iterator<char>());
215   size_t pos = 0;
216   int added = 0;
217   while (true) {
218     size_t end_pos = manifest.find("\n", pos);
219     if (end_pos == string::npos) break;
220     string filename = manifest.substr(pos, end_pos - pos);
221     test_paths->push_back(dir + "/" + filename);
222     pos = end_pos + 1;
223     added += 1;
224   }
225   if (!added) {
226     string message = "Test had no examples: " + original_file;
227     return tensorflow::Status(tensorflow::error::UNKNOWN, message);
228   }
229   return tensorflow::Status::OK();
230 }
231 
232 // Get a list of tests from either zip or tar file
UnarchiveAndFindTestNames(const string & zip_file,const string & tar_file)233 std::vector<string> UnarchiveAndFindTestNames(const string& zip_file,
234                                               const string& tar_file) {
235   if (zip_file.empty() && tar_file.empty()) {
236     TF_CHECK_OK(tensorflow::Status(tensorflow::error::UNKNOWN,
237                                    "Neither zip_file nor tar_file was given"));
238   }
239   string decompress_tmp_dir;
240   TF_CHECK_OK(archive_environment()->UnArchive(zip_file, tar_file,
241                                                &decompress_tmp_dir));
242   std::vector<string> stuff;
243   if (!zip_file.empty()) {
244     TF_CHECK_OK(ReadManifest(zip_file, decompress_tmp_dir, &stuff));
245   } else {
246     TF_CHECK_OK(ReadManifest(tar_file, decompress_tmp_dir, &stuff));
247   }
248   return stuff;
249 }
250 
251 class OpsTest : public ::testing::TestWithParam<string> {};
252 
TEST_P(OpsTest,RunZipTests)253 TEST_P(OpsTest, RunZipTests) {
254   string test_path = GetParam();
255   string tflite_test_case = test_path + "_tests.txt";
256   string tflite_dir = test_path.substr(0, test_path.find_last_of("/"));
257   string test_name = test_path.substr(test_path.find_last_of('/'));
258 
259   std::ifstream tflite_stream(tflite_test_case);
260   ASSERT_TRUE(tflite_stream.is_open()) << tflite_test_case;
261   tflite::testing::TfLiteDriver test_driver(FLAGS_use_nnapi);
262   test_driver.SetModelBaseDir(tflite_dir);
263 
264   auto broken_tests = kBrokenTests;
265   if (FLAGS_use_nnapi) {
266     broken_tests.insert(kBrokenNnapiTests.begin(), kBrokenNnapiTests.end());
267   }
268 
269   string bug_number;
270   for (const auto& p : broken_tests) {
271     if (RE2::PartialMatch(test_name, p.first)) {
272       bug_number = p.second;
273     }
274   }
275 
276   bool result = tflite::testing::ParseAndRunTests(&tflite_stream, &test_driver);
277   string message = test_driver.GetErrorMessage();
278   if (bug_number.empty()) {
279     if (FLAGS_use_nnapi && FLAGS_ignore_unsupported_nnapi && !result) {
280       EXPECT_EQ(message, string("Failed to invoke interpreter")) << message;
281     } else {
282       EXPECT_TRUE(result) << message;
283     }
284   } else {
285     if (FLAGS_ignore_known_bugs) {
286       EXPECT_FALSE(result) << "Test was expected to fail but is now passing; "
287                               "you can mark http://b/"
288                            << bug_number << " as fixed! Yay!";
289     } else {
290       EXPECT_TRUE(result) << message << ": Possibly due to http://b/"
291                           << bug_number;
292     }
293   }
294 }
295 
296 struct ZipPathParamName {
297   template <class ParamType>
operator ()tflite::testing::ZipPathParamName298   string operator()(const ::testing::TestParamInfo<ParamType>& info) const {
299     string param_name = info.param;
300     size_t last_slash = param_name.find_last_of("\\/");
301     if (last_slash != string::npos) {
302       param_name = param_name.substr(last_slash);
303     }
304     for (size_t index = 0; index < param_name.size(); ++index) {
305       if (!isalnum(param_name[index]) && param_name[index] != '_')
306         param_name[index] = '_';
307     }
308     return param_name;
309   }
310 };
311 
312 INSTANTIATE_TEST_CASE_P(tests, OpsTest,
313                         ::testing::ValuesIn(UnarchiveAndFindTestNames(
314                             *FLAGS_zip_file_path, *FLAGS_tar_file_path)),
315                         ZipPathParamName());
316 
317 }  // namespace testing
318 }  // namespace tflite
319 
main(int argc,char ** argv)320 int main(int argc, char** argv) {
321   ::testing::AddGlobalTestEnvironment(tflite::testing::archive_environment());
322 
323   std::vector<tensorflow::Flag> flags = {
324       tensorflow::Flag(
325           "ignore_known_bugs", &tflite::testing::FLAGS_ignore_known_bugs,
326           "If a particular model is affected by a known bug, the "
327           "corresponding test should expect the outputs to not match."),
328       tensorflow::Flag(
329           "tar_file_path", tflite::testing::FLAGS_tar_file_path,
330           "Required (or zip_file_path): Location of the test tar file."),
331       tensorflow::Flag(
332           "zip_file_path", tflite::testing::FLAGS_zip_file_path,
333           "Required (or tar_file_path): Location of the test zip file."),
334       tensorflow::Flag("unzip_binary_path",
335                        tflite::testing::FLAGS_unzip_binary_path,
336                        "Location of a suitable unzip binary."),
337       tensorflow::Flag("tar_binary_path",
338                        tflite::testing::FLAGS_tar_binary_path,
339                        "Location of a suitable tar binary."),
340       tensorflow::Flag("use_nnapi", &tflite::testing::FLAGS_use_nnapi,
341                        "Whether to enable the NNAPI delegate"),
342       tensorflow::Flag("ignore_unsupported_nnapi",
343                        &tflite::testing::FLAGS_ignore_unsupported_nnapi,
344                        "Don't fail tests just because delegation to NNAPI "
345                        "is not possible")};
346   bool success = tensorflow::Flags::Parse(&argc, argv, flags);
347   if (!success || (argc == 2 && !strcmp(argv[1], "--helpfull"))) {
348     fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
349     return 1;
350   }
351 
352   ::tflite::LogToStderr();
353   // TODO(mikie): googletest arguments do not work - maybe the tensorflow flags
354   // parser removes them?
355   ::testing::InitGoogleTest(&argc, argv);
356   return RUN_ALL_TESTS();
357 }
358