• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/tools/command_line_flags.h"
17 
18 #include <string>
19 
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22 #include "tensorflow/lite/tools/tool_params.h"
23 
24 namespace tflite {
25 namespace {
26 
// Parses named flags of several types (int32, int64, bool, string, float)
// together with a required flag and a positional flag, then checks every bound
// value. A flag absent from argv keeps its initial value, and argc drops to 1
// once every argument has been consumed.
TEST(CommandLineFlagsTest, BasicUsage) {
  int some_int32 = 10;
  int some_int1 = 8;  // Not provided via arguments, the value should remain.
  int some_int2 = 9;  // Required flag.
  int64_t some_int64 = 21474836470;  // max int32 is 2147483647
  bool some_switch = false;
  std::string some_name = "something_a";
  float some_float = -23.23f;
  float float_1 = -23.23f;  // positional flag.
  bool some_bool = false;
  bool some_numeric_bool = true;
  const char* argv_strings[] = {"program_name",
                                "12.2",
                                "--some_int32=20",
                                "--some_int2=5",
                                "--some_int64=214748364700",
                                "--some_switch=true",
                                "--some_name=somethingelse",
                                "--some_float=42.0",
                                "--some_bool=true",
                                "--some_numeric_bool=0"};
  // Derive argc from the array so the two can never drift out of sync.
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  // argv_strings already decays to const char**; no cast is needed.
  bool parsed_ok = Flags::Parse(
      &argc, argv_strings,
      {
          Flag::CreateFlag("some_int32", &some_int32, "some int32"),
          Flag::CreateFlag("some_int64", &some_int64, "some int64"),
          Flag::CreateFlag("some_switch", &some_switch, "some switch"),
          Flag::CreateFlag("some_name", &some_name, "some name"),
          Flag::CreateFlag("some_float", &some_float, "some float"),
          Flag::CreateFlag("some_bool", &some_bool, "some bool"),
          Flag::CreateFlag("some_numeric_bool", &some_numeric_bool,
                           "some numeric bool"),
          Flag::CreateFlag("some_int1", &some_int1, "some int"),
          Flag::CreateFlag("some_int2", &some_int2, "some int",
                           Flag::kRequired),
          Flag::CreateFlag("float_1", &float_1, "some float",
                           Flag::kPositional),
      });

  EXPECT_TRUE(parsed_ok);
  EXPECT_EQ(20, some_int32);
  EXPECT_EQ(8, some_int1);
  EXPECT_EQ(5, some_int2);
  EXPECT_EQ(214748364700, some_int64);
  EXPECT_TRUE(some_switch);
  EXPECT_EQ("somethingelse", some_name);
  EXPECT_NEAR(42.0f, some_float, 1e-5f);
  EXPECT_NEAR(12.2f, float_1, 1e-5f);
  EXPECT_TRUE(some_bool);
  EXPECT_FALSE(some_numeric_bool);
  // Only the program name remains once every argument has been consumed.
  EXPECT_EQ(1, argc);
}
80 
// An explicitly empty value ("--some_string=") parses successfully and
// overwrites the string flag's previous contents with "".
TEST(CommandLineFlagsTest, EmptyStringFlag) {
  std::string some_string = "invalid";
  const char* argv_strings[] = {"program_name", "--some_string="};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  bool parsed_ok = Flags::Parse(
      &argc, argv_strings,
      {Flag::CreateFlag("some_string", &some_string, "some string")});

  EXPECT_TRUE(parsed_ok);
  EXPECT_EQ("", some_string);
  EXPECT_EQ(1, argc);
}
93 
// A non-numeric value for an int flag fails the parse and leaves the bound
// variable at its initial value; argc still drops to 1.
TEST(CommandLineFlagsTest, BadIntValue) {
  int some_int = 10;
  const char* argv_strings[] = {"program_name", "--some_int=notanumber"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  bool parsed_ok =
      Flags::Parse(&argc, argv_strings,
                   {Flag::CreateFlag("some_int", &some_int, "some int")});

  EXPECT_FALSE(parsed_ok);
  EXPECT_EQ(10, some_int);
  EXPECT_EQ(1, argc);
}
106 
// A value that is not a recognizable bool fails the parse and leaves the bound
// variable at its initial value; argc still drops to 1.
TEST(CommandLineFlagsTest, BadBoolValue) {
  bool some_switch = false;
  const char* argv_strings[] = {"program_name", "--some_switch=notabool"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  bool parsed_ok = Flags::Parse(
      &argc, argv_strings,
      {Flag::CreateFlag("some_switch", &some_switch, "some switch")});

  EXPECT_FALSE(parsed_ok);
  EXPECT_FALSE(some_switch);
  EXPECT_EQ(1, argc);
}
119 
// A non-numeric value for a float flag fails the parse and leaves the bound
// variable at its initial value; argc still drops to 1.
TEST(CommandLineFlagsTest, BadFloatValue) {
  float some_float = -23.23f;
  const char* argv_strings[] = {"program_name", "--some_float=notanumber"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  bool parsed_ok =
      Flags::Parse(&argc, argv_strings,
                   {Flag::CreateFlag("some_float", &some_float, "some float")});

  EXPECT_FALSE(parsed_ok);
  EXPECT_NEAR(-23.23f, some_float, 1e-5f);
  EXPECT_EQ(1, argc);
}
132 
// A required flag that does not appear in argv fails the parse. The unrelated
// "--flag=12" argument is not consumed, so argc stays at 2.
TEST(CommandLineFlagsTest, RequiredFlagNotFound) {
  float some_float = -23.23f;
  const char* argv_strings[] = {"program_name", "--flag=12"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  bool parsed_ok = Flags::Parse(
      &argc, argv_strings,
      {Flag::CreateFlag("some_flag", &some_float, "", Flag::kRequired)});

  EXPECT_FALSE(parsed_ok);
  EXPECT_NEAR(-23.23f, some_float, 1e-5f);
  EXPECT_EQ(2, argc);
}
145 
// With only the program name in argv, a required flag is necessarily missing,
// so the parse fails and the bound value is unchanged.
TEST(CommandLineFlagsTest, NoArguments) {
  float some_float = -23.23f;
  const char* argv_strings[] = {"program_name"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  bool parsed_ok = Flags::Parse(
      &argc, argv_strings,
      {Flag::CreateFlag("some_flag", &some_float, "", Flag::kRequired)});

  EXPECT_FALSE(parsed_ok);
  EXPECT_NEAR(-23.23f, some_float, 1e-5f);
  EXPECT_EQ(1, argc);
}
158 
// A positional flag with no positional argument available fails the parse and
// leaves the bound value unchanged.
TEST(CommandLineFlagsTest, NotEnoughArguments) {
  float some_float = -23.23f;
  const char* argv_strings[] = {"program_name"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  bool parsed_ok = Flags::Parse(
      &argc, argv_strings,
      {Flag::CreateFlag("some_flag", &some_float, "", Flag::kPositional)});

  EXPECT_FALSE(parsed_ok);
  EXPECT_NEAR(-23.23f, some_float, 1e-5f);
  EXPECT_EQ(1, argc);
}
171 
// A positional argument that cannot be parsed as the flag's type (float here)
// fails the parse. The bound value is unchanged and argc stays at 2.
TEST(CommandLineFlagsTest, PositionalFlagFailed) {
  float some_float = -23.23f;
  const char* argv_strings[] = {"program_name", "string"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  bool parsed_ok = Flags::Parse(
      &argc, argv_strings,
      {Flag::CreateFlag("some_flag", &some_float, "", Flag::kPositional)});

  EXPECT_FALSE(parsed_ok);
  EXPECT_NEAR(-23.23f, some_float, 1e-5f);
  EXPECT_EQ(2, argc);
}
184 
// Returns whether `str` equals `pat`, except that each run of whitespace in
// `pat` may match zero or more whitespace characters in `str`. Whitespace in
// `str` is accepted only where `pat` has whitespace at the matching position.
static bool MatchWithAnyWhitespace(const std::string& str,
                                   const std::string& pat) {
  // isspace() has undefined behavior for negative values other than EOF, and
  // plain char may be signed, so widen through unsigned char first.
  const auto is_space = [](char c) {
    return isspace(static_cast<unsigned char>(c)) != 0;
  };
  bool matching = true;
  std::string::size_type pat_i = 0;
  for (std::string::size_type str_i = 0; str_i != str.size() && matching;
       ++str_i) {
    if (is_space(str[str_i])) {
      // Whitespace in str must line up with whitespace in pat. pat_i is not
      // advanced, so one pattern space can absorb several spaces in str.
      matching = (pat_i != pat.size() && is_space(pat[pat_i]));
    } else {
      // Pattern whitespace may match nothing: skip it before comparing.
      while (pat_i != pat.size() && is_space(pat[pat_i])) {
        ++pat_i;
      }
      matching = (pat_i != pat.size() && str[str_i] == pat[pat_i++]);
    }
  }
  // Trailing pattern whitespace may also match nothing.
  while (pat_i != pat.size() && is_space(pat[pat_i])) {
    ++pat_i;
  }
  return matching && pat_i == pat.size();
}
206 
// Checks the text produced by Flags::Usage against an expected message. Per
// that expected text, positional flags are listed under "Where:", and named
// flags are listed under "Flags:" with their current value, type, and
// required/optional status, required ones first. Whitespace is compared
// loosely.
TEST(CommandLineFlagsTest, UsageString) {
  int some_int = 10;
  int64_t some_int64 = 21474836470;  // max int32 is 2147483647
  bool some_switch = false;
  std::string some_name = "something";
  int some_int2 = 4;
  // Don't test float in this case, because precision is hard to predict and
  // match against, and we don't want a flaky test.
  const std::string tool_name = "some_tool_name";
  std::string usage = Flags::Usage(
      tool_name,
      {Flag::CreateFlag("some_int", &some_int, "some int"),
       Flag::CreateFlag("some_int64", &some_int64, "some int64"),
       Flag::CreateFlag("some_switch", &some_switch, "some switch"),
       Flag::CreateFlag("some_name", &some_name, "some name", Flag::kRequired),
       Flag::CreateFlag("some_int2", &some_int2, "some int",
                        Flag::kPositional)});
  // Match the usage message, being sloppy about whitespace.
  const char* expected_usage =
      " usage: some_tool_name <some_int2> <flags>\n"
      "Where:\n"
      "some_int2\tint32\trequired\tsome int\n"
      "Flags:\n"
      "--some_name=something\tstring\trequired\tsome name\n"
      "--some_int=10\tint32\toptional\tsome int\n"
      "--some_int64=21474836470\tint64\toptional\tsome int64\n"
      "--some_switch=false\tbool\toptional\tsome switch\n";
  ASSERT_TRUE(MatchWithAnyWhitespace(usage, expected_usage)) << usage;

  // Again but with no flags.
  usage = Flags::Usage(tool_name, {});
  ASSERT_TRUE(MatchWithAnyWhitespace(usage, " usage: some_tool_name\n"))
      << usage;
}
241 
242 // When there are duplicate args, the flag value and the parsing result will be
243 // based on the 1st arg.
TEST(CommandLineFlagsTest,DuplicateArgsParsableValues)244 TEST(CommandLineFlagsTest, DuplicateArgsParsableValues) {
245   int some_int = -23;
246   int argc = 3;
247   const char* argv_strings[] = {"program_name", "--some_int=1", "--some_int=2"};
248   bool parsed_ok =
249       Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
250                    {Flag::CreateFlag("some_int", &some_int, "some int")});
251 
252   EXPECT_TRUE(parsed_ok);
253   EXPECT_EQ(1, some_int);
254   EXPECT_EQ(argc, 2);
255   EXPECT_EQ("--some_int=2", argv_strings[1]);
256 }
257 
// With duplicate args, only the 1st occurrence is parsed. Because it is not a
// valid int, parsing fails and the value is unchanged, even though the 2nd
// occurrence would have been valid. The 2nd occurrence is left in argv.
TEST(CommandLineFlagsTest, DuplicateArgsBadValueAppearFirst) {
  int some_int = -23;
  const char* argv_strings[] = {"program_name", "--some_int=value",
                                "--some_int=1"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  bool parsed_ok =
      Flags::Parse(&argc, argv_strings,
                   {Flag::CreateFlag("some_int", &some_int, "some int")});

  EXPECT_FALSE(parsed_ok);
  EXPECT_EQ(-23, some_int);
  EXPECT_EQ(2, argc);
  // EXPECT_EQ on two C strings would compare pointers, not contents.
  EXPECT_STREQ("--some_int=1", argv_strings[1]);
}
272 
// With duplicate args, only the 1st occurrence is parsed, so a malformed 2nd
// occurrence neither changes the value nor fails the parse. It is simply left
// in argv.
TEST(CommandLineFlagsTest, DuplicateArgsBadValueAppearSecondly) {
  int some_int = -23;
  // Although the 2nd arg has non-parsable int value, the flag 'some_int' value
  // could still be set and the parsing result is ok.
  const char* argv_strings[] = {"program_name", "--some_int=1",
                                "--some_int=value"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  bool parsed_ok =
      Flags::Parse(&argc, argv_strings,
                   {Flag::CreateFlag("some_int", &some_int, "some int")});

  EXPECT_TRUE(parsed_ok);
  EXPECT_EQ(1, some_int);
  EXPECT_EQ(2, argc);
  // EXPECT_EQ on two C strings would compare pointers, not contents.
  EXPECT_STREQ("--some_int=value", argv_strings[1]);
}
289 
290 // When there are duplicate flags, all of them will be checked against the arg
291 // list.
TEST(CommandLineFlagsTest,DuplicateFlags)292 TEST(CommandLineFlagsTest, DuplicateFlags) {
293   int some_int1 = -23;
294   int some_int2 = -23;
295   int argc = 2;
296   const char* argv_strings[] = {"program_name", "--some_int=1"};
297   bool parsed_ok =
298       Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
299                    {Flag::CreateFlag("some_int", &some_int1, "some int1"),
300                     Flag::CreateFlag("some_int", &some_int2, "some int2")});
301 
302   EXPECT_TRUE(parsed_ok);
303   EXPECT_EQ(1, some_int1);
304   EXPECT_EQ(1, some_int2);
305   EXPECT_EQ(argc, 1);
306 }
307 
// Two flags share a name, one optional and one required, and neither appears
// in argv. The missing required one fails the parse. Neither bound value
// changes, and the unrelated argument is left in argv.
TEST(CommandLineFlagsTest, DuplicateFlagsNotFound) {
  int some_int1 = -23;
  int some_int2 = -23;
  const char* argv_strings[] = {"program_name", "--some_float=1.0"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  bool parsed_ok = Flags::Parse(
      &argc, argv_strings,
      {Flag::CreateFlag("some_int", &some_int1, "some int1", Flag::kOptional),
       Flag::CreateFlag("some_int", &some_int2, "some int2", Flag::kRequired)});

  EXPECT_FALSE(parsed_ok);
  EXPECT_EQ(-23, some_int1);
  EXPECT_EQ(-23, some_int2);
  EXPECT_EQ(2, argc);
}
323 
// Two flags share a name but have different types. Each one is checked
// against the same argument, so the int-typed flag picks up the value while
// the bool-typed flag fails to parse it.
TEST(CommandLineFlagsTest, DuplicateFlagNamesButDifferentTypes) {
  int some_int = -23;
  bool some_bool = true;
  const char* argv_strings[] = {"program_name", "--some_val=20"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  // In this case, the 2nd 'some_val' flag of bool type cannot parse "20", which
  // makes the overall parsing result not ok.
  bool parsed_ok =
      Flags::Parse(&argc, argv_strings,
                   {Flag::CreateFlag("some_val", &some_int, "some val-int"),
                    Flag::CreateFlag("some_val", &some_bool, "some val-bool")});

  EXPECT_FALSE(parsed_ok);
  EXPECT_EQ(20, some_int);
  EXPECT_TRUE(some_bool);
  EXPECT_EQ(1, argc);
}
341 
// Combines duplicate flags with duplicate args: both flags bind to the value
// of the 1st arg, and one argument remains unconsumed.
TEST(CommandLineFlagsTest, DuplicateFlagsAndArgs) {
  int some_int1 = -23;
  int some_int2 = -23;
  const char* argv_strings[] = {"program_name", "--some_int=1", "--some_int=2"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  bool parsed_ok = Flags::Parse(
      &argc, argv_strings,
      {Flag::CreateFlag("some_int", &some_int1, "flag1: bind with some_int1"),
       Flag::CreateFlag("some_int", &some_int2, "flag2: bind with some_int2")});

  // Note: when there are duplicate args, the flag value and the parsing result
  // are based on the 1st arg (i.e. --some_int=1). Both duplicate flags (i.e.
  // flag1 and flag2) are checked against it, so their associated values
  // (some_int1 and some_int2) are both set to 1.
  EXPECT_TRUE(parsed_ok);
  EXPECT_EQ(1, some_int1);
  EXPECT_EQ(1, some_int2);
  EXPECT_EQ(2, argc);
}
361 
// ArgsToString is expected to join the arguments after the program name with
// single spaces.
TEST(CommandLineFlagsTest, ArgsToString) {
  const char* argv_strings[] = {"program_name", "--some_int=1", "--some_int=2"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  std::string args = Flags::ArgsToString(argc, argv_strings);
  EXPECT_EQ("--some_int=1 --some_int=2", args);
}
369 
// Callback-based flags receive the argv index at which their argument
// appeared. The callbacks below store both the value and that index in a
// ToolParams, and the test checks each of them.
TEST(CommandLineFlagsTest, ArgvPositions) {
  tools::ToolParams params;
  params.AddParam("some_int", tools::ToolParam::Create<int>(13));
  params.AddParam("some_float", tools::ToolParam::Create<float>(17.0f));
  params.AddParam("some_bool", tools::ToolParam::Create<bool>(true));

  const char* argv_strings[] = {"program_name", "--some_float=42.0",
                                "--some_bool=false", "--some_int=5"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  tools::ToolParams* const params_ptr = &params;
  bool parsed_ok = Flags::Parse(
      &argc, argv_strings,
      {
          Flag(
              "some_int",
              // NOLINT because of needing templating both trivial and complex
              // types for a Flag.
              [params_ptr](const int& val, int argv_position) {  // NOLINT
                params_ptr->Set<int>("some_int", val, argv_position);
              },
              13, "some int", Flag::kOptional),
          Flag(
              "some_float",
              [params_ptr](const float& val, int argv_position) {  // NOLINT
                params_ptr->Set<float>("some_float", val, argv_position);
              },
              17.0f, "some float", Flag::kOptional),
          Flag(
              "some_bool",
              [params_ptr](const bool& val, int argv_position) {  // NOLINT
                params_ptr->Set<bool>("some_bool", val, argv_position);
              },
              true, "some bool", Flag::kOptional),
      });

  EXPECT_TRUE(parsed_ok);
  EXPECT_EQ(5, params.Get<int>("some_int"));
  EXPECT_NEAR(42.0f, params.Get<float>("some_float"), 1e-5f);
  EXPECT_FALSE(params.Get<bool>("some_bool"));

  // The position of a parameter depends on the ordering of the associated flag
  // specified in the argv (i.e. 'argv_strings' above), not on the ordering of
  // the flag in the flag list that's passed to Flags::Parse above.
  EXPECT_EQ(3, params.GetPosition<int>("some_int"));
  EXPECT_EQ(1, params.GetPosition<float>("some_float"));
  EXPECT_EQ(2, params.GetPosition<bool>("some_bool"));

  EXPECT_EQ(1, argc);
}
419 
420 }  // namespace
421 }  // namespace tflite
422