1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/lite/tools/command_line_flags.h"

#include <cctype>
#include <cstdint>
#include <string>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/tools/tool_params.h"
23
24 namespace tflite {
25 namespace {
26
// Parses one positional argument plus optional and required named flags of
// every supported value type, then checks every destination and that argc is
// reduced to just the program name.
TEST(CommandLineFlagsTest, BasicUsage) {
  int some_int32 = 10;
  int some_int1 = 8;  // Never passed on the command line; must keep its value.
  int some_int2 = 9;  // Registered as a required flag.
  int64_t some_int64 = 21474836470;  // Larger than the int32 max, 2147483647.
  bool some_switch = false;
  std::string some_name = "something_a";
  float some_float = -23.23f;
  float float_1 = -23.23f;  // Bound to the positional argument "12.2".
  bool some_bool = false;
  bool some_numeric_bool = true;

  const char* argv_strings[] = {"program_name",
                                "12.2",
                                "--some_int32=20",
                                "--some_int2=5",
                                "--some_int64=214748364700",
                                "--some_switch=true",
                                "--some_name=somethingelse",
                                "--some_float=42.0",
                                "--some_bool=true",
                                "--some_numeric_bool=0"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));

  const bool ok = Flags::Parse(
      &argc, argv_strings,
      {
          Flag::CreateFlag("some_int32", &some_int32, "some int32"),
          Flag::CreateFlag("some_int64", &some_int64, "some int64"),
          Flag::CreateFlag("some_switch", &some_switch, "some switch"),
          Flag::CreateFlag("some_name", &some_name, "some name"),
          Flag::CreateFlag("some_float", &some_float, "some float"),
          Flag::CreateFlag("some_bool", &some_bool, "some bool"),
          Flag::CreateFlag("some_numeric_bool", &some_numeric_bool,
                           "some numeric bool"),
          Flag::CreateFlag("some_int1", &some_int1, "some int"),
          Flag::CreateFlag("some_int2", &some_int2, "some int",
                           Flag::kRequired),
          Flag::CreateFlag("float_1", &float_1, "some float",
                           Flag::kPositional),
      });

  EXPECT_TRUE(ok);
  EXPECT_EQ(20, some_int32);
  EXPECT_EQ(8, some_int1);
  EXPECT_EQ(5, some_int2);
  EXPECT_EQ(214748364700, some_int64);
  EXPECT_TRUE(some_switch);
  EXPECT_EQ("somethingelse", some_name);
  EXPECT_NEAR(42.0f, some_float, 1e-5f);
  EXPECT_NEAR(12.2f, float_1, 1e-5f);
  EXPECT_TRUE(some_bool);
  // "0" is accepted as a boolean false.
  EXPECT_FALSE(some_numeric_bool);
  EXPECT_EQ(argc, 1);
}
80
// "--some_string=" with nothing after '=' sets the string flag to "".
TEST(CommandLineFlagsTest, EmptyStringFlag) {
  std::string some_string = "invalid";
  const char* argv_strings[] = {"program_name", "--some_string="};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));

  const bool ok = Flags::Parse(
      &argc, argv_strings,
      {Flag::CreateFlag("some_string", &some_string, "some string")});

  EXPECT_TRUE(ok);
  EXPECT_EQ(some_string, "");
  EXPECT_EQ(argc, 1);
}
93
// A non-numeric value for an int flag makes Parse fail. The destination keeps
// its value, and argc still drops to just the program name.
TEST(CommandLineFlagsTest, BadIntValue) {
  const char* argv_strings[] = {"program_name", "--some_int=notanumber"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  int some_int = 10;

  const bool ok = Flags::Parse(
      &argc, argv_strings,
      {Flag::CreateFlag("some_int", &some_int, "some int")});

  EXPECT_FALSE(ok);
  EXPECT_EQ(10, some_int);
  EXPECT_EQ(argc, 1);
}
106
// An unrecognized boolean spelling makes Parse fail. The switch keeps its
// value, and argc still drops to just the program name.
TEST(CommandLineFlagsTest, BadBoolValue) {
  const char* argv_strings[] = {"program_name", "--some_switch=notabool"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  bool some_switch = false;

  const bool ok = Flags::Parse(
      &argc, argv_strings,
      {Flag::CreateFlag("some_switch", &some_switch, "some switch")});

  EXPECT_FALSE(ok);
  EXPECT_FALSE(some_switch);
  EXPECT_EQ(argc, 1);
}
119
// A non-numeric value for a float flag makes Parse fail. The destination keeps
// its value, and argc still drops to just the program name.
TEST(CommandLineFlagsTest, BadFloatValue) {
  const char* argv_strings[] = {"program_name", "--some_float=notanumber"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  float some_float = -23.23f;

  const bool ok = Flags::Parse(
      &argc, argv_strings,
      {Flag::CreateFlag("some_float", &some_float, "some float")});

  EXPECT_FALSE(ok);
  EXPECT_NEAR(-23.23f, some_float, 1e-5f);
  EXPECT_EQ(argc, 1);
}
132
// A required flag that is absent from argv makes Parse fail. The unrelated
// "--flag=12" argument is left in place, so argc stays at 2.
TEST(CommandLineFlagsTest, RequiredFlagNotFound) {
  const char* argv_strings[] = {"program_name", "--flag=12"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  float some_float = -23.23f;

  const bool ok = Flags::Parse(
      &argc, argv_strings,
      {Flag::CreateFlag("some_flag", &some_float, "", Flag::kRequired)});

  EXPECT_FALSE(ok);
  EXPECT_NEAR(-23.23f, some_float, 1e-5f);
  EXPECT_EQ(argc, 2);
}
145
// With only the program name in argv, a required flag cannot be satisfied.
TEST(CommandLineFlagsTest, NoArguments) {
  const char* argv_strings[] = {"program_name"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  float some_float = -23.23f;

  const bool ok = Flags::Parse(
      &argc, argv_strings,
      {Flag::CreateFlag("some_flag", &some_float, "", Flag::kRequired)});

  EXPECT_FALSE(ok);
  EXPECT_NEAR(-23.23f, some_float, 1e-5f);
  EXPECT_EQ(argc, 1);
}
158
// A positional flag with no corresponding argument makes Parse fail.
TEST(CommandLineFlagsTest, NotEnoughArguments) {
  const char* argv_strings[] = {"program_name"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  float some_float = -23.23f;

  const bool ok = Flags::Parse(
      &argc, argv_strings,
      {Flag::CreateFlag("some_flag", &some_float, "", Flag::kPositional)});

  EXPECT_FALSE(ok);
  EXPECT_NEAR(-23.23f, some_float, 1e-5f);
  EXPECT_EQ(argc, 1);
}
171
// A positional float whose argument is not a number makes Parse fail. The
// destination is untouched, and the argument stays in argv (argc remains 2).
TEST(CommandLineFlagsTest, PositionalFlagFailed) {
  const char* argv_strings[] = {"program_name", "string"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  float some_float = -23.23f;

  const bool ok = Flags::Parse(
      &argc, argv_strings,
      {Flag::CreateFlag("some_flag", &some_float, "", Flag::kPositional)});

  EXPECT_FALSE(ok);
  EXPECT_NEAR(-23.23f, some_float, 1e-5f);
  EXPECT_EQ(argc, 2);
}
184
// Returns whether `str` equals `pat`, except that whitespace in `pat` may
// match zero or more whitespace characters in `str`. Whitespace in `str` is
// only accepted where `pat` also has whitespace at the current position.
static bool MatchWithAnyWhitespace(const std::string& str,
                                   const std::string& pat) {
  // std::isspace is undefined for negative values other than EOF, so chars
  // (which may be signed) are widened through unsigned char first.
  const auto is_space = [](char c) {
    return std::isspace(static_cast<unsigned char>(c)) != 0;
  };

  bool matching = true;
  std::string::size_type pat_i = 0;
  for (std::string::size_type str_i = 0; str_i != str.size() && matching;
       ++str_i) {
    if (is_space(str[str_i])) {
      // pat_i is intentionally not advanced: one whitespace character in
      // `pat` absorbs a whole run of whitespace in `str`.
      matching = pat_i != pat.size() && is_space(pat[pat_i]);
    } else {
      // Whitespace in `pat` may correspond to nothing in `str`; skip it.
      while (pat_i != pat.size() && is_space(pat[pat_i])) {
        ++pat_i;
      }
      matching = pat_i != pat.size() && str[str_i] == pat[pat_i++];
    }
  }
  // Trailing whitespace in `pat` may also match nothing.
  while (pat_i != pat.size() && is_space(pat[pat_i])) {
    ++pat_i;
  }
  return matching && pat_i == pat.size();
}
206
// Checks the text produced by Flags::Usage, with and without flags. The
// comparison ignores differences in whitespace layout.
TEST(CommandLineFlagsTest, UsageString) {
  int some_int = 10;
  int64_t some_int64 = 21474836470;  // max int32 is 2147483647
  bool some_switch = false;
  std::string some_name = "something";
  int some_int2 = 4;
  // Floats are deliberately not covered here: their printed precision is hard
  // to predict and match against, which would make the test flaky.
  const std::string tool_name = "some_tool_name";
  std::string usage = Flags::Usage(
      tool_name,
      {Flag::CreateFlag("some_int", &some_int, "some int"),
       Flag::CreateFlag("some_int64", &some_int64, "some int64"),
       Flag::CreateFlag("some_switch", &some_switch, "some switch"),
       Flag::CreateFlag("some_name", &some_name, "some name", Flag::kRequired),
       Flag::CreateFlag("some_int2", &some_int2, "some int",
                        Flag::kPositional)});
  // Match the usage message, being sloppy about whitespace.
  const char* expected_usage =
      " usage: some_tool_name <some_int2> <flags>\n"
      "Where:\n"
      "some_int2\tint32\trequired\tsome int\n"
      "Flags:\n"
      "--some_name=something\tstring\trequired\tsome name\n"
      "--some_int=10\tint32\toptional\tsome int\n"
      "--some_int64=21474836470\tint64\toptional\tsome int64\n"
      "--some_switch=false\tbool\toptional\tsome switch\n";
  ASSERT_TRUE(MatchWithAnyWhitespace(usage, expected_usage)) << usage;

  // Again but with no flags.
  usage = Flags::Usage(tool_name, {});
  ASSERT_TRUE(MatchWithAnyWhitespace(usage, " usage: some_tool_name\n"))
      << usage;
}
241
242 // When there are duplicate args, the flag value and the parsing result will be
243 // based on the 1st arg.
// With the same flag given twice, the first occurrence sets the value. The
// second occurrence is left unconsumed at argv[1].
TEST(CommandLineFlagsTest, DuplicateArgsParsableValues) {
  int some_int = -23;
  const char* argv_strings[] = {"program_name", "--some_int=1", "--some_int=2"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  bool parsed_ok =
      Flags::Parse(&argc, argv_strings,
                   {Flag::CreateFlag("some_int", &some_int, "some int")});

  EXPECT_TRUE(parsed_ok);
  EXPECT_EQ(1, some_int);
  EXPECT_EQ(argc, 2);
  // EXPECT_STREQ compares contents. EXPECT_EQ would compare the two char
  // pointers, and would only pass if the compiler pooled identical literals.
  EXPECT_STREQ("--some_int=2", argv_strings[1]);
}
257
// With the same flag given twice, an unparsable first value makes Parse fail
// and leaves the destination untouched. The second occurrence is left
// unconsumed at argv[1].
TEST(CommandLineFlagsTest, DuplicateArgsBadValueAppearFirst) {
  int some_int = -23;
  const char* argv_strings[] = {"program_name", "--some_int=value",
                                "--some_int=1"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  bool parsed_ok =
      Flags::Parse(&argc, argv_strings,
                   {Flag::CreateFlag("some_int", &some_int, "some int")});

  EXPECT_FALSE(parsed_ok);
  EXPECT_EQ(-23, some_int);
  EXPECT_EQ(argc, 2);
  // Compare C-string contents rather than pointer identity.
  EXPECT_STREQ("--some_int=1", argv_strings[1]);
}
272
// With the same flag given twice, only the first occurrence is parsed. An
// unparsable second value therefore does not cause a failure.
TEST(CommandLineFlagsTest, DuplicateArgsBadValueAppearSecondly) {
  int some_int = -23;
  // Although the 2nd arg has a non-parsable int value, the flag 'some_int' is
  // still set from the 1st arg and the parsing result is ok.
  const char* argv_strings[] = {"program_name", "--some_int=1",
                                "--some_int=value"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  bool parsed_ok =
      Flags::Parse(&argc, argv_strings,
                   {Flag::CreateFlag("some_int", &some_int, "some int")});

  EXPECT_TRUE(parsed_ok);
  EXPECT_EQ(1, some_int);
  EXPECT_EQ(argc, 2);
  // Compare C-string contents rather than pointer identity.
  EXPECT_STREQ("--some_int=value", argv_strings[1]);
}
289
290 // When there are duplicate flags, all of them will be checked against the arg
291 // list.
// Two flags registered under the same name both receive the value from a
// single matching argument.
TEST(CommandLineFlagsTest, DuplicateFlags) {
  const char* argv_strings[] = {"program_name", "--some_int=1"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  int some_int1 = -23;
  int some_int2 = -23;

  const bool ok =
      Flags::Parse(&argc, argv_strings,
                   {Flag::CreateFlag("some_int", &some_int1, "some int1"),
                    Flag::CreateFlag("some_int", &some_int2, "some int2")});

  EXPECT_TRUE(ok);
  EXPECT_EQ(1, some_int1);
  EXPECT_EQ(1, some_int2);
  EXPECT_EQ(argc, 1);
}
307
// Two same-named flags, one optional and one required, with no matching
// argument. The required one makes Parse fail, neither destination changes,
// and the unrelated argument remains in argv.
TEST(CommandLineFlagsTest, DuplicateFlagsNotFound) {
  const char* argv_strings[] = {"program_name", "--some_float=1.0"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  int some_int1 = -23;
  int some_int2 = -23;

  const bool ok = Flags::Parse(
      &argc, argv_strings,
      {Flag::CreateFlag("some_int", &some_int1, "some int1", Flag::kOptional),
       Flag::CreateFlag("some_int", &some_int2, "some int2", Flag::kRequired)});

  EXPECT_FALSE(ok);
  EXPECT_EQ(-23, some_int1);
  EXPECT_EQ(-23, some_int2);
  EXPECT_EQ(argc, 2);
}
323
// Two same-named flags of different types share one argument. The int flag
// accepts "20", but the bool flag cannot parse it, so Parse reports failure.
TEST(CommandLineFlagsTest, DuplicateFlagNamesButDifferentTypes) {
  const char* argv_strings[] = {"program_name", "--some_val=20"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  int some_int = -23;
  bool some_bool = true;

  const bool ok =
      Flags::Parse(&argc, argv_strings,
                   {Flag::CreateFlag("some_val", &some_int, "some val-int"),
                    Flag::CreateFlag("some_val", &some_bool, "some val-bool")});

  // The bool-typed 'some_val' flag is what turns the result into a failure.
  EXPECT_FALSE(ok);
  EXPECT_EQ(20, some_int);
  EXPECT_TRUE(some_bool);
  EXPECT_EQ(argc, 1);
}
341
// Two same-named flags combined with the same argument given twice.
TEST(CommandLineFlagsTest, DuplicateFlagsAndArgs) {
  const char* argv_strings[] = {"program_name", "--some_int=1", "--some_int=2"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  int some_int1 = -23;
  int some_int2 = -23;

  const bool ok = Flags::Parse(
      &argc, argv_strings,
      {Flag::CreateFlag("some_int", &some_int1, "flag1: bind with some_int1"),
       Flag::CreateFlag("some_int", &some_int2, "flag2: bind with some_int2")});

  // When an argument is duplicated, both the flag value and the parse result
  // come from its first occurrence (--some_int=1). Both same-named flags
  // (flag1 and flag2) are matched against it, so some_int1 and some_int2
  // both become 1. One argument is left over, so argc is 2.
  EXPECT_TRUE(ok);
  EXPECT_EQ(1, some_int1);
  EXPECT_EQ(1, some_int2);
  EXPECT_EQ(argc, 2);
}
361
// ArgsToString joins every argument after the program name with single spaces.
TEST(CommandLineFlagsTest, ArgsToString) {
  const char* argv_strings[] = {"program_name", "--some_int=1", "--some_int=2"};
  const int argc =
      static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));

  const std::string joined = Flags::ArgsToString(argc, argv_strings);

  EXPECT_EQ("--some_int=1 --some_int=2", joined);
}
369
// Flags built from callbacks receive the argv index of the argument that set
// them. The test forwards that index into ToolParams and checks it.
TEST(CommandLineFlagsTest, ArgvPositions) {
  tools::ToolParams params;
  params.AddParam("some_int", tools::ToolParam::Create<int>(13));
  params.AddParam("some_float", tools::ToolParam::Create<float>(17.0f));
  params.AddParam("some_bool", tools::ToolParam::Create<bool>(true));

  const char* argv_strings[] = {"program_name", "--some_float=42.0",
                                "--some_bool=false", "--some_int=5"};
  int argc = static_cast<int>(sizeof(argv_strings) / sizeof(argv_strings[0]));
  tools::ToolParams* const params_ptr = &params;
  bool parsed_ok = Flags::Parse(
      &argc, argv_strings,
      {
          Flag(
              "some_int",
              // NOLINT because of needing templating both trivial and complex
              // types for a Flag.
              [params_ptr](const int& val, int argv_position) {  // NOLINT
                params_ptr->Set<int>("some_int", val, argv_position);
              },
              13, "some int", Flag::kOptional),
          Flag(
              "some_float",
              [params_ptr](const float& val, int argv_position) {  // NOLINT
                params_ptr->Set<float>("some_float", val, argv_position);
              },
              17.0f, "some float", Flag::kOptional),
          Flag(
              "some_bool",
              [params_ptr](const bool& val, int argv_position) {  // NOLINT
                params_ptr->Set<bool>("some_bool", val, argv_position);
              },
              true, "some bool", Flag::kOptional),
      });

  EXPECT_TRUE(parsed_ok);
  EXPECT_EQ(5, params.Get<int>("some_int"));
  EXPECT_NEAR(42.0f, params.Get<float>("some_float"), 1e-5f);
  EXPECT_FALSE(params.Get<bool>("some_bool"));

  // A parameter's position follows the order in which its flag appears in
  // argv ('argv_strings' above), not the order of the flags in the list
  // passed to Flags::Parse.
  EXPECT_EQ(3, params.GetPosition<int>("some_int"));
  EXPECT_EQ(1, params.GetPosition<float>("some_float"));
  EXPECT_EQ(2, params.GetPosition<bool>("some_bool"));

  EXPECT_EQ(argc, 1);
}
419
420 } // namespace
421 } // namespace tflite
422