1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5 http://www.apache.org/licenses/LICENSE-2.0
6 Unless required by applicable law or agreed to in writing, software
7 distributed under the License is distributed on an "AS IS" BASIS,
8 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 See the License for the specific language governing permissions and
10 limitations under the License.
11 ==============================================================================*/
12
13 #include "tensorflow/lite/tools/command_line_flags.h"
14
15 #include <algorithm>
16 #include <cstring>
17 #include <iomanip>
18 #include <numeric>
19 #include <sstream>
20 #include <string>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24
25 #include "tensorflow/lite/tools/logging.h"
26
27 namespace tflite {
28 namespace {
29
// Renders `val` through a default-configured std::ostringstream. Used to turn
// default flag values into the text shown in the usage message.
template <typename T>
std::string ToString(T val) {
  std::ostringstream formatted;
  formatted << val;
  return formatted.str();
}
36
// Tries to match a single command-line argument `arg` against the flag named
// `flag`.
//
// For a positional flag, `arg` is handed verbatim to `parse_func` and the
// argument always counts as a match. Otherwise `arg` matches only when it
// starts with "--<flag>="; everything after the '=' (possibly the empty
// string) is handed to `parse_func`.
//
// Returns true iff `arg` matched this flag. `*value_parsing_ok` receives the
// result of `parse_func` on a match, and is set to true when `arg` did not
// match (nothing was parsed, so nothing failed).
bool ParseFlag(const std::string& arg, const std::string& flag, bool positional,
               const std::function<bool(const std::string&)>& parse_func,
               bool* value_parsing_ok) {
  if (positional) {
    *value_parsing_ok = parse_func(arg);
    return true;
  }
  const std::string flag_prefix = "--" + flag + "=";
  // Anchored prefix comparison: unlike find(), it never scans past the prefix.
  if (arg.compare(0, flag_prefix.size(), flag_prefix) != 0) {
    *value_parsing_ok = true;
    return false;
  }
  // A matching `arg` is always at least as long as the prefix, so a value is
  // always present (possibly ""). Whether an empty value is acceptable is up
  // to `parse_func`.
  *value_parsing_ok = parse_func(arg.substr(flag_prefix.size()));
  return true;
}
56
// Parses `flag_value` as a T (instantiated with int32_t, int64_t and float in
// this file) and, on success, passes the value to `hook`.
//
// The whole string must be a single value. Leading and trailing whitespace is
// tolerated. Empty input, malformed input, trailing characters (e.g. "12abc",
// or "1.5" for an integer flag) and out-of-range values are rejected.
// Returns false without calling `hook` on failure.
template <typename T>
bool ParseFlag(const std::string& flag_value,
               const std::function<void(const T&)>& hook) {
  std::istringstream stream(flag_value);
  T read_value{};
  // Arithmetic extraction sets failbit on empty/malformed input and on
  // overflow, so none of those can reach `hook`.
  if (!(stream >> read_value)) {
    return false;
  }
  // Any non-whitespace character left after the value means it was only
  // partially consumed.
  char trailing = '\0';
  if (stream >> trailing) {
    return false;
  }
  hook(read_value);
  return true;
}
69
// Accepts exactly "true"/"1" (passed to `hook` as true) and "false"/"0"
// (passed as false). Any other spelling, including "True" or the empty string,
// is rejected and `hook` is not called.
bool ParseBoolFlag(const std::string& flag_value,
                   const std::function<void(const bool&)>& hook) {
  const bool is_true = flag_value == "true" || flag_value == "1";
  const bool is_false = flag_value == "false" || flag_value == "0";
  if (!(is_true || is_false)) {
    return false;
  }
  hook(is_true);
  return true;
}
80 } // namespace
81
Flag(const char * name,const std::function<void (const int32_t &)> & hook,int32_t default_value,const std::string & usage_text,FlagType flag_type)82 Flag::Flag(const char* name, const std::function<void(const int32_t&)>& hook,
83 int32_t default_value, const std::string& usage_text,
84 FlagType flag_type)
85 : name_(name),
86 type_(TYPE_INT32),
87 value_hook_([hook](const std::string& flag_value) {
88 return ParseFlag<int32_t>(flag_value, hook);
89 }),
90 default_for_display_(ToString(default_value)),
91 usage_text_(usage_text),
92 flag_type_(flag_type) {}
93
Flag(const char * name,const std::function<void (const int64_t &)> & hook,int64_t default_value,const std::string & usage_text,FlagType flag_type)94 Flag::Flag(const char* name, const std::function<void(const int64_t&)>& hook,
95 int64_t default_value, const std::string& usage_text,
96 FlagType flag_type)
97 : name_(name),
98 type_(TYPE_INT64),
99 value_hook_([hook](const std::string& flag_value) {
100 return ParseFlag<int64_t>(flag_value, hook);
101 }),
102 default_for_display_(ToString(default_value)),
103 usage_text_(usage_text),
104 flag_type_(flag_type) {}
105
Flag(const char * name,const std::function<void (const float &)> & hook,float default_value,const std::string & usage_text,FlagType flag_type)106 Flag::Flag(const char* name, const std::function<void(const float&)>& hook,
107 float default_value, const std::string& usage_text,
108 FlagType flag_type)
109 : name_(name),
110 type_(TYPE_FLOAT),
111 value_hook_([hook](const std::string& flag_value) {
112 return ParseFlag<float>(flag_value, hook);
113 }),
114 default_for_display_(ToString(default_value)),
115 usage_text_(usage_text),
116 flag_type_(flag_type) {}
117
Flag(const char * name,const std::function<void (const bool &)> & hook,bool default_value,const std::string & usage_text,FlagType flag_type)118 Flag::Flag(const char* name, const std::function<void(const bool&)>& hook,
119 bool default_value, const std::string& usage_text,
120 FlagType flag_type)
121 : name_(name),
122 type_(TYPE_BOOL),
123 value_hook_([hook](const std::string& flag_value) {
124 return ParseBoolFlag(flag_value, hook);
125 }),
126 default_for_display_(default_value ? "true" : "false"),
127 usage_text_(usage_text),
128 flag_type_(flag_type) {}
129
Flag(const char * name,const std::function<void (const std::string &)> & hook,const std::string & default_value,const std::string & usage_text,FlagType flag_type)130 Flag::Flag(const char* name,
131 const std::function<void(const std::string&)>& hook,
132 const std::string& default_value, const std::string& usage_text,
133 FlagType flag_type)
134 : name_(name),
135 type_(TYPE_STRING),
136 value_hook_([hook](const std::string& flag_value) {
137 hook(flag_value);
138 return true;
139 }),
140 default_for_display_(default_value),
141 usage_text_(usage_text),
142 flag_type_(flag_type) {}
143
Parse(const std::string & arg,bool * value_parsing_ok) const144 bool Flag::Parse(const std::string& arg, bool* value_parsing_ok) const {
145 return ParseFlag(arg, name_, flag_type_ == kPositional, value_hook_,
146 value_parsing_ok);
147 }
148
GetTypeName() const149 std::string Flag::GetTypeName() const {
150 switch (type_) {
151 case TYPE_INT32:
152 return "int32";
153 case TYPE_INT64:
154 return "int64";
155 case TYPE_FLOAT:
156 return "float";
157 case TYPE_BOOL:
158 return "bool";
159 case TYPE_STRING:
160 return "string";
161 }
162
163 return "unknown";
164 }
165
Parse(int * argc,const char ** argv,const std::vector<Flag> & flag_list)166 /*static*/ bool Flags::Parse(int* argc, const char** argv,
167 const std::vector<Flag>& flag_list) {
168 bool result = true;
169 std::vector<bool> unknown_argvs(*argc, true);
170 // Record the list of flags that have been processed. key is the flag's name
171 // and the value is the corresponding argv index if there's one, or -1 when
172 // the argv list doesn't contain this flag.
173 std::unordered_map<std::string, int> processed_flags;
174
175 // Stores indexes of flag_list in a sorted order.
176 std::vector<int> sorted_idx(flag_list.size());
177 std::iota(std::begin(sorted_idx), std::end(sorted_idx), 0);
178 std::sort(sorted_idx.begin(), sorted_idx.end(), [&flag_list](int a, int b) {
179 return flag_list[a].GetFlagType() < flag_list[b].GetFlagType();
180 });
181 int positional_count = 0;
182
183 for (int idx = 0; idx < sorted_idx.size(); ++idx) {
184 const Flag& flag = flag_list[sorted_idx[idx]];
185
186 const auto it = processed_flags.find(flag.name_);
187 if (it != processed_flags.end()) {
188 #ifndef NDEBUG
189 // Only log this in debug builds.
190 TFLITE_LOG(WARN) << "Duplicate flags: " << flag.name_;
191 #endif
192 if (it->second != -1) {
193 bool value_parsing_ok;
194 flag.Parse(argv[it->second], &value_parsing_ok);
195 if (!value_parsing_ok) {
196 TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_
197 << "' against argv '" << argv[it->second] << "'";
198 result = false;
199 }
200 continue;
201 } else if (flag.flag_type_ == Flag::kRequired) {
202 TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_;
203 // If the required flag isn't found, we immediately stop the whole flag
204 // parsing.
205 result = false;
206 break;
207 }
208 }
209
210 // Parses positional flags.
211 if (flag.flag_type_ == Flag::kPositional) {
212 if (++positional_count >= *argc) {
213 TFLITE_LOG(ERROR) << "Too few command line arguments.";
214 return false;
215 }
216 bool value_parsing_ok;
217 flag.Parse(argv[positional_count], &value_parsing_ok);
218 if (!value_parsing_ok) {
219 TFLITE_LOG(ERROR) << "Failed to parse positional flag: " << flag.name_;
220 return false;
221 }
222 unknown_argvs[positional_count] = false;
223 processed_flags[flag.name_] = positional_count;
224 continue;
225 }
226
227 // Parse other flags.
228 bool was_found = false;
229 for (int i = positional_count + 1; i < *argc; ++i) {
230 if (!unknown_argvs[i]) continue;
231 bool value_parsing_ok;
232 was_found = flag.Parse(argv[i], &value_parsing_ok);
233 if (!value_parsing_ok) {
234 TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_
235 << "' against argv '" << argv[i] << "'";
236 result = false;
237 }
238 if (was_found) {
239 unknown_argvs[i] = false;
240 processed_flags[flag.name_] = i;
241 break;
242 }
243 }
244
245 // If the flag is found from the argv (i.e. the flag name appears in argv),
246 // continue to the next flag parsing.
247 if (was_found) continue;
248
249 // The flag isn't found, do some bookkeeping work.
250 processed_flags[flag.name_] = -1;
251 if (flag.flag_type_ == Flag::kRequired) {
252 TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_;
253 result = false;
254 // If the required flag isn't found, we immediately stop the whole flag
255 // parsing by breaking the outer-loop (i.e. the 'sorted_idx'-iteration
256 // loop).
257 break;
258 }
259 }
260
261 int dst = 1; // Skip argv[0]
262 for (int i = 1; i < *argc; ++i) {
263 if (unknown_argvs[i]) {
264 argv[dst++] = argv[i];
265 }
266 }
267 *argc = dst;
268 return result && (*argc < 2 || std::strcmp(argv[1], "--help") != 0);
269 }
270
Usage(const std::string & cmdline,const std::vector<Flag> & flag_list)271 /*static*/ std::string Flags::Usage(const std::string& cmdline,
272 const std::vector<Flag>& flag_list) {
273 // Stores indexes of flag_list in a sorted order.
274 std::vector<int> sorted_idx(flag_list.size());
275 std::iota(std::begin(sorted_idx), std::end(sorted_idx), 0);
276 std::sort(sorted_idx.begin(), sorted_idx.end(), [&flag_list](int a, int b) {
277 return flag_list[a].GetFlagType() < flag_list[b].GetFlagType();
278 });
279 // Counts number of positional flags will be shown.
280 int positional_count = 0;
281 std::ostringstream usage_text;
282 usage_text << "usage: " << cmdline;
283 // Prints usage for positional flag.
284 for (int i = 0; i < sorted_idx.size(); ++i) {
285 const Flag& flag = flag_list[sorted_idx[i]];
286 if (flag.flag_type_ == Flag::kPositional) {
287 positional_count++;
288 usage_text << " <" << flag.name_ << ">";
289 } else {
290 usage_text << " <flags>";
291 break;
292 }
293 }
294 usage_text << "\n";
295
296 // Finds the max number of chars of the name column in the usage message.
297 int max_name_width = 0;
298 std::vector<std::string> name_column(flag_list.size());
299 for (int i = 0; i < sorted_idx.size(); ++i) {
300 const Flag& flag = flag_list[sorted_idx[i]];
301 if (flag.flag_type_ != Flag::kPositional) {
302 name_column[i] += "--";
303 name_column[i] += flag.name_;
304 name_column[i] += "=";
305 name_column[i] += flag.default_for_display_;
306 } else {
307 name_column[i] += flag.name_;
308 }
309 if (name_column[i].size() > max_name_width) {
310 max_name_width = name_column[i].size();
311 }
312 }
313
314 if (positional_count > 0) {
315 usage_text << "Where:\n";
316 }
317 for (int i = 0; i < sorted_idx.size(); ++i) {
318 const Flag& flag = flag_list[sorted_idx[i]];
319 if (i == positional_count) {
320 usage_text << "Flags:\n";
321 }
322 auto type_name = flag.GetTypeName();
323 usage_text << "\t";
324 usage_text << std::left << std::setw(max_name_width) << name_column[i];
325 usage_text << "\t" << type_name << "\t";
326 usage_text << (flag.flag_type_ != Flag::kOptional ? "required"
327 : "optional");
328 usage_text << "\t" << flag.usage_text_ << "\n";
329 }
330 return usage_text.str();
331 }
332
ArgsToString(int argc,const char ** argv)333 /*static*/ std::string Flags::ArgsToString(int argc, const char** argv) {
334 std::string args;
335 for (int i = 1; i < argc; ++i) {
336 args.append(argv[i]);
337 if (i != argc - 1) args.append(" ");
338 }
339 return args;
340 }
341
342 } // namespace tflite
343