• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <cstdarg>
16 #include <cstdio>
17 #include <cstdlib>
18 #include <fstream>
19 #include <map>
20 #include <sstream>
21 #include <string>
22 #include <tuple>
23 #include <utility>
24 
25 #include <gtest/gtest.h>
26 #include "re2/re2.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/platform/subprocess.h"
29 #include "tensorflow/core/util/command_line_flags.h"
30 #include "tensorflow/lite/testing/parse_testdata.h"
31 #include "tensorflow/lite/testing/tflite_driver.h"
32 #include "tensorflow/lite/testing/util.h"
33 
34 namespace tflite {
35 namespace testing {
36 
namespace {
// When true (the default), a test matching an entry of the known-broken tables
// is expected to fail; when false, it is expected to pass unless its entry is
// marked `always_ignore`.
bool FLAGS_ignore_known_bugs = true;

// Locations of the test archive. Archive file names are test-specific, so no
// default is possible; one of the two must be supplied (zip wins when both
// are set). Both formats are supported because a stock Android image has
// tar but not unzip.
string* FLAGS_zip_file_path = new string();
string* FLAGS_tar_file_path = new string();

// Default locations of the extraction tools differ between host and Android.
#ifdef __ANDROID__
string* FLAGS_unzip_binary_path = new string("/system/bin/unzip");
string* FLAGS_tar_binary_path = new string("/system/bin/tar");
#else
string* FLAGS_unzip_binary_path = new string("/usr/bin/unzip");
string* FLAGS_tar_binary_path = new string("/bin/tar");
#endif

// Run the models through the NNAPI delegate.
bool FLAGS_use_nnapi = false;
// Together with use_nnapi: accept a failing run whose error message is
// exactly "Failed to invoke interpreter".
bool FLAGS_ignore_unsupported_nnapi = false;
}  // namespace
55 
56 // TensorFlow system environment for file system called.
57 tensorflow::Env* env = tensorflow::Env::Default();
58 
// Tests that are already known to be broken.
//
// Each key is an RE2 pattern that is partially matched against the test name,
// i.e. the label from its last '/' onwards (hence the `^\/` anchors). Each
// value pairs the tracking bug number with an `always_ignore` flag:
//   * With the default --ignore_known_bugs=true, a matching test is expected
//     to FAIL; if it starts passing, the bug can be closed.
//   * With --test_arg=--ignore_known_bugs=false, a matching test is expected
//     to pass, unless `always_ignore` is true, in which case its result is
//     not checked at all.
using BrokenTestMap =
    std::map</* test_name_regex */ string,
             std::pair</* bug_number */ string, /* always_ignore */ bool>>;

// TODO(ahentz): make sure we clean this list up frequently.
const BrokenTestMap& GetKnownBrokenTests() {
  // Built on first use and intentionally never freed, so the returned
  // reference stays valid for the rest of the process.
  static const BrokenTestMap* const kBrokenTests = [] {
    // TODO(b/194364155): TF and TFLite have different behaviors when output
    // nan values in LocalResponseNorm ops.
    const string kLrnNanBug = "194364155";
    auto* tests = new BrokenTestMap;
    tests->emplace(R"(^\/local_response_norm.*alpha=-3.*beta=2)",
                   std::make_pair(kLrnNanBug, true));
    tests->emplace(
        R"(^\/local_response_norm.*alpha=(None|2).*beta=2.*bias=-0\.1.*depth_radius=(0|1).*input_shape=\[3,15,14,3\])",
        std::make_pair(kLrnNanBug, true));
    return tests;
  }();
  return *kBrokenTests;
}
78 
// Additional known-broken tests that apply only with
//   --test_arg=--use_nnapi=true
// Their entries are merged into the GetKnownBrokenTests() table with
// `always_ignore` = false, so the same expectations apply: they must fail by
// default and must pass with --test_arg=--ignore_known_bugs=false.
// Lack of NNAPI support for a particular op is handled separately; this list
// is specifically for cases where execution produces broken output.
// Each key is an RE2 pattern partially matched against the test name; each
// value is a bug number.
const std::map<string, string>& GetKnownBrokenNnapiTests() {
  // Built on first use and intentionally never freed.
  static const std::map<string, string>* const kBrokenNnapiTests = [] {
    const string kBug = "122987564";
    const char* const kPatterns[] = {
        // Certain NNAPI kernels silently fail with int32 types.
        R"(^\/add.*dtype=tf\.int32)",
        R"(^\/concat.*dtype=tf\.int32)",
        R"(^\/mul.*dtype=tf\.int32)",
        R"(^\/space_to_depth.*dtype=tf\.int32)",
        // Certain NNAPI fully_connected shape permutations fail.
        R"(^\/fully_connected_constant_filter=True.*shape1=\[3,3\])",
        R"(^\/fully_connected_constant_filter=True.*shape1=\[4,4\])",
        R"(^\/fully_connected.*shape1=\[3,3\].*transpose_b=True)",
        R"(^\/fully_connected.*shape1=\[4,4\].*shape2=\[4,1\])",
    };
    auto* tests = new std::map<string, string>;
    for (const char* pattern : kPatterns) tests->emplace(pattern, kBug);
    return tests;
  }();
  return *kBrokenNnapiTests;
}
108 
// Fully-quantized tests that are likely to fail: quantized TFLite models can
// show a large numeric difference from the float TensorFlow models. A failing
// fully_quantize=True test is tolerated only if its name matches one of these
// entries. Each key is an RE2 pattern partially matched against the test name;
// each value is a bug number.
// TODO(b/134594898): Remove these bugs and corresponding codes or move them to
// kBrokenTests after b/134594898 is fixed.
const std::map<string, string>& GetKnownQuantizeBrokenTests() {
  // Built on first use and intentionally never freed.
  static const std::map<string, string>* const kQuantizeBrokenTests = [] {
    auto* tests = new std::map<string, string>;
    auto& table = *tests;
    table[R"(^\/conv.*fully_quantize=True)"] = "134594898";
    table[R"(^\/depthwiseconv.*fully_quantize=True)"] = "134594898";
    table[R"(^\/sum.*fully_quantize=True)"] = "134594898";
    table[R"(^\/l2norm.*fully_quantize=True)"] = "134594898";
    table[R"(^\/prelu.*fully_quantize=True)"] = "156112683";
    return tests;
  }();
  return *kQuantizeBrokenTests;
}
125 
126 // Allows test data to be unarchived into a temporary directory and makes
127 // sure those temporary directories are removed later.
128 class ArchiveEnvironment : public ::testing::Environment {
129  public:
~ArchiveEnvironment()130   ~ArchiveEnvironment() override {}
131 
132   // Delete all temporary directories on teardown.
TearDown()133   void TearDown() override {
134     for (const auto& dir : temporary_directories_) {
135       int64_t undeleted_dirs, undeleted_files;
136       TF_CHECK_OK(
137           env->DeleteRecursively(dir, &undeleted_dirs, &undeleted_files));
138     }
139     temporary_directories_.clear();
140   }
141 
142   // Unarchive `archive` file into a new temporary directory  `out_dir`.
UnArchive(const string & zip,const string & tar,string * out_dir)143   tensorflow::Status UnArchive(const string& zip, const string& tar,
144                                string* out_dir) {
145     string dir;
146     TF_CHECK_OK(MakeTemporaryDirectory(&dir));
147     tensorflow::SubProcess proc;
148     if (!zip.empty()) {
149       string unzip_binary = *FLAGS_unzip_binary_path;
150       TF_CHECK_OK(env->FileExists(unzip_binary));
151       TF_CHECK_OK(env->FileExists(zip));
152       proc.SetProgram(unzip_binary, {"unzip", "-d", dir, zip});
153     } else {
154       string tar_binary = *FLAGS_tar_binary_path;
155       TF_CHECK_OK(env->FileExists(tar_binary));
156       TF_CHECK_OK(env->FileExists(tar));
157       // 'o' needs to be explicitly set on Android so that
158       // untarring works as non-root (otherwise tries to chown
159       // files, which fails)
160       proc.SetProgram(tar_binary, {"tar", "xfo", tar, "-C", dir});
161     }
162     proc.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE);
163     proc.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE);
164     if (!proc.Start())
165       return tensorflow::Status(tensorflow::error::UNKNOWN,
166                                 "unzip couldn't start");
167     string out, err;
168     int status = proc.Communicate(nullptr, &out, &err);
169     if (WEXITSTATUS(status) == 0) {
170       *out_dir = dir;
171       return ::tensorflow::OkStatus();
172     } else {
173       return tensorflow::Status(tensorflow::error::UNKNOWN,
174                                 "unzip failed. "
175                                 "stdout:\n" +
176                                     out + "\nstderr:\n" + err);
177     }
178   }
179 
180  private:
181   // Make a temporary directory and return its name in `temporary`.
MakeTemporaryDirectory(string * temporary)182   tensorflow::Status MakeTemporaryDirectory(string* temporary) {
183     if (env->LocalTempFilename(temporary)) {
184       TF_CHECK_OK(env->CreateDir(*temporary));
185       temporary_directories_.push_back(*temporary);
186       return ::tensorflow::OkStatus();
187     }
188     return tensorflow::Status(tensorflow::error::UNKNOWN,
189                               "make temporary directory failed");
190   }
191 
192   std::vector<string> temporary_directories_;
193 };
194 
195 // Return the singleton archive_environment.
archive_environment()196 ArchiveEnvironment* archive_environment() {
197   static ArchiveEnvironment* env = new ArchiveEnvironment;
198   return env;
199 }
200 
201 // Read the manifest.txt out of the unarchived archive file. Specifically
202 // `original_file` is the original zip file for error messages. `dir` is
203 // the temporary directory where the archive file has been unarchived and
204 // `test_paths` is the list of test prefixes that were in the manifest.
205 // Note, it is an error for a manifest to contain no tests.
ReadManifest(const string & original_file,const string & dir,std::vector<string> * test_paths)206 tensorflow::Status ReadManifest(const string& original_file, const string& dir,
207                                 std::vector<string>* test_paths) {
208   // Read the newline delimited list of entries in the manifest.
209   std::ifstream manifest_fp(dir + "/manifest.txt");
210   string manifest((std::istreambuf_iterator<char>(manifest_fp)),
211                   std::istreambuf_iterator<char>());
212   size_t pos = 0;
213   int added = 0;
214   while (true) {
215     size_t end_pos = manifest.find('\n', pos);
216     if (end_pos == string::npos) break;
217     string filename = manifest.substr(pos, end_pos - pos);
218     test_paths->push_back(dir + "/" + filename);
219     pos = end_pos + 1;
220     added += 1;
221   }
222   if (!added) {
223     string message = "Test had no examples: " + original_file;
224     return tensorflow::Status(tensorflow::error::UNKNOWN, message);
225   }
226   return ::tensorflow::OkStatus();
227 }
228 
229 // Get a list of tests from either zip or tar file
UnarchiveAndFindTestNames(const string & zip_file,const string & tar_file)230 std::vector<string> UnarchiveAndFindTestNames(const string& zip_file,
231                                               const string& tar_file) {
232   if (zip_file.empty() && tar_file.empty()) {
233     TF_CHECK_OK(tensorflow::Status(tensorflow::error::UNKNOWN,
234                                    "Neither zip_file nor tar_file was given"));
235   }
236   string decompress_tmp_dir;
237   TF_CHECK_OK(archive_environment()->UnArchive(zip_file, tar_file,
238                                                &decompress_tmp_dir));
239   std::vector<string> stuff;
240   if (!zip_file.empty()) {
241     TF_CHECK_OK(ReadManifest(zip_file, decompress_tmp_dir, &stuff));
242   } else {
243     TF_CHECK_OK(ReadManifest(tar_file, decompress_tmp_dir, &stuff));
244   }
245   return stuff;
246 }
247 
248 class OpsTest : public ::testing::TestWithParam<string> {};
249 
TEST_P(OpsTest,RunZipTests)250 TEST_P(OpsTest, RunZipTests) {
251   string test_path_and_label = GetParam();
252   string test_path = test_path_and_label;
253   string label = test_path_and_label;
254   size_t end_pos = test_path_and_label.find(' ');
255   if (end_pos != string::npos) {
256     test_path = test_path_and_label.substr(0, end_pos);
257     label = test_path_and_label.substr(end_pos + 1);
258   }
259   string tflite_test_case = test_path + "_tests.txt";
260   string tflite_dir = test_path.substr(0, test_path.find_last_of('/'));
261   string test_name = label.substr(label.find_last_of('/'));
262 
263   std::ifstream tflite_stream(tflite_test_case);
264   ASSERT_TRUE(tflite_stream.is_open()) << tflite_test_case;
265   tflite::testing::TfLiteDriver test_driver(
266       FLAGS_use_nnapi ? TfLiteDriver::DelegateType::kNnapi
267                       : TfLiteDriver::DelegateType::kNone);
268   bool fully_quantize = false;
269   if (label.find("fully_quantize=True") != std::string::npos) {
270     // TODO(b/134594898): Tighten this constraint.
271     test_driver.SetThreshold(0.2, 0.1);
272     fully_quantize = true;
273   }
274 
275   test_driver.SetModelBaseDir(tflite_dir);
276 
277   auto broken_tests = GetKnownBrokenTests();
278   if (FLAGS_use_nnapi) {
279     for (const auto& t : GetKnownBrokenNnapiTests()) {
280       broken_tests[t.first] =
281           std::make_pair(t.second, /* always ignore */ false);
282     }
283   }
284   auto quantize_broken_tests = GetKnownQuantizeBrokenTests();
285 
286   bool result = tflite::testing::ParseAndRunTests(&tflite_stream, &test_driver);
287   string message = test_driver.GetErrorMessage();
288 
289   if (!fully_quantize) {
290     string bug_number;
291     bool always_ignore;
292     for (const auto& p : broken_tests) {
293       if (RE2::PartialMatch(test_name, p.first)) {
294         std::tie(bug_number, always_ignore) = p.second;
295         break;
296       }
297     }
298     if (bug_number.empty()) {
299       if (FLAGS_use_nnapi && FLAGS_ignore_unsupported_nnapi && !result) {
300         EXPECT_EQ(message, string("Failed to invoke interpreter")) << message;
301       } else {
302         EXPECT_TRUE(result) << message;
303       }
304     } else {
305       if (FLAGS_ignore_known_bugs) {
306         EXPECT_FALSE(result) << "Test was expected to fail but is now passing; "
307                                 "you can mark http://b/"
308                              << bug_number << " as fixed! Yay!";
309       } else {
310         EXPECT_TRUE(result || always_ignore)
311             << message << ": Possibly due to http://b/" << bug_number;
312       }
313     }
314   } else {
315     if (!result) {
316       string bug_number;
317       // See if the tests are potential quantize failures.
318       for (const auto& p : quantize_broken_tests) {
319         if (RE2::PartialMatch(test_name, p.first)) {
320           bug_number = p.second;
321           break;
322         }
323       }
324       EXPECT_FALSE(bug_number.empty());
325     }
326   }
327 }
328 
329 struct ZipPathParamName {
330   template <class ParamType>
operator ()tflite::testing::ZipPathParamName331   string operator()(const ::testing::TestParamInfo<ParamType>& info) const {
332     string param_name = info.param;
333     size_t last_slash = param_name.find_last_of("\\/");
334     if (last_slash != string::npos) {
335       param_name = param_name.substr(last_slash);
336     }
337     for (size_t index = 0; index < param_name.size(); ++index) {
338       if (!isalnum(param_name[index]) && param_name[index] != '_')
339         param_name[index] = '_';
340     }
341     return param_name;
342   }
343 };
344 
345 INSTANTIATE_TEST_CASE_P(tests, OpsTest,
346                         ::testing::ValuesIn(UnarchiveAndFindTestNames(
347                             *FLAGS_zip_file_path, *FLAGS_tar_file_path)),
348                         ZipPathParamName());
349 
350 }  // namespace testing
351 }  // namespace tflite
352 
main(int argc,char ** argv)353 int main(int argc, char** argv) {
354   ::testing::AddGlobalTestEnvironment(tflite::testing::archive_environment());
355 
356   std::vector<tensorflow::Flag> flags = {
357       tensorflow::Flag(
358           "ignore_known_bugs", &tflite::testing::FLAGS_ignore_known_bugs,
359           "If a particular model is affected by a known bug, the "
360           "corresponding test should expect the outputs to not match."),
361       tensorflow::Flag(
362           "tar_file_path", tflite::testing::FLAGS_tar_file_path,
363           "Required (or zip_file_path): Location of the test tar file."),
364       tensorflow::Flag(
365           "zip_file_path", tflite::testing::FLAGS_zip_file_path,
366           "Required (or tar_file_path): Location of the test zip file."),
367       tensorflow::Flag("unzip_binary_path",
368                        tflite::testing::FLAGS_unzip_binary_path,
369                        "Location of a suitable unzip binary."),
370       tensorflow::Flag("tar_binary_path",
371                        tflite::testing::FLAGS_tar_binary_path,
372                        "Location of a suitable tar binary."),
373       tensorflow::Flag("use_nnapi", &tflite::testing::FLAGS_use_nnapi,
374                        "Whether to enable the NNAPI delegate"),
375       tensorflow::Flag("ignore_unsupported_nnapi",
376                        &tflite::testing::FLAGS_ignore_unsupported_nnapi,
377                        "Don't fail tests just because delegation to NNAPI "
378                        "is not possible")};
379   bool success = tensorflow::Flags::Parse(&argc, argv, flags);
380   if (!success || (argc == 2 && !strcmp(argv[1], "--helpfull"))) {
381     fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
382     return 1;
383   }
384 
385   if (!tflite::testing::TfLiteDriver::InitTestDelegateProviders(
386           &argc, const_cast<const char**>(argv))) {
387     return EXIT_FAILURE;
388   }
389 
390   ::tflite::LogToStderr();
391   // TODO(mikie): googletest arguments do not work - maybe the tensorflow flags
392   // parser removes them?
393   ::testing::InitGoogleTest(&argc, argv);
394   return RUN_ALL_TESTS();
395 }
396