1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5 http://www.apache.org/licenses/LICENSE-2.0
6 Unless required by applicable law or agreed to in writing, software
7 distributed under the License is distributed on an "AS IS" BASIS,
8 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 See the License for the specific language governing permissions and
10 limitations under the License.
11 ==============================================================================*/
12
13 #include "tensorflow/lite/tools/command_line_flags.h"
14
15 #include <algorithm>
16 #include <cstring>
17 #include <functional>
18 #include <iomanip>
19 #include <numeric>
20 #include <sstream>
21 #include <string>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25
26 #include "absl/strings/match.h"
27 #include "tensorflow/lite/tools/logging.h"
28
29 namespace tflite {
30 namespace {
31
// Renders `val` as text by streaming it through operator<<.
template <typename T>
std::string ToString(T val) {
  std::ostringstream out;
  out << val;
  return out.str();
}

// Booleans are rendered as the words "true"/"false" rather than as digits.
template <>
std::string ToString<bool>(bool val) {
  if (val) {
    return "true";
  }
  return "false";
}

// Strings are returned unchanged; no stream round-trip is performed.
template <>
std::string ToString<const std::string&>(const std::string& val) {
  return std::string(val);
}
48
ParseFlag(const std::string & arg,int argv_position,const std::string & flag,bool positional,const std::function<bool (const std::string &,int argv_position)> & parse_func,bool * value_parsing_ok)49 bool ParseFlag(const std::string& arg, int argv_position,
50 const std::string& flag, bool positional,
51 const std::function<bool(const std::string&, int argv_position)>&
52 parse_func,
53 bool* value_parsing_ok) {
54 if (positional) {
55 *value_parsing_ok = parse_func(arg, argv_position);
56 return true;
57 }
58 *value_parsing_ok = true;
59 std::string flag_prefix = "--" + flag + "=";
60 if (!absl::StartsWith(arg, flag_prefix)) {
61 return false;
62 }
63 bool has_value = arg.size() >= flag_prefix.size();
64 *value_parsing_ok = has_value;
65 if (has_value) {
66 *value_parsing_ok =
67 parse_func(arg.substr(flag_prefix.size()), argv_position);
68 }
69 return true;
70 }
71
// Converts `flag_value` to a T with stream extraction and, on success, passes
// the converted value and `argv_position` to `hook`.
//
// Returns false, without calling `hook`, when the extraction fails. That
// covers text that does not start with a valid T, text that is empty or only
// whitespace, and numeric values that do not fit in T. Characters left over
// after a successfully extracted value are ignored.
template <typename T>
bool ParseFlag(const std::string& flag_value, int argv_position,
               const std::function<void(const T&, int)>& hook) {
  std::istringstream stream(flag_value);
  T read_value{};
  stream >> read_value;
  // failbit is the authoritative "extraction did not produce a value" signal.
  // Looking at eof()/good() instead would accept inputs such as "" or an
  // out-of-range number, because those set eofbit together with failbit.
  if (stream.fail()) {
    return false;
  }
  hook(read_value, argv_position);
  return true;
}

// Boolean flags accept exactly "true", "false", "1" or "0". Any other text
// (including the empty string) is rejected and `hook` is not called.
template <>
bool ParseFlag<bool>(const std::string& flag_value, int argv_position,
                     const std::function<void(const bool&, int)>& hook) {
  const bool is_true = flag_value == "true" || flag_value == "1";
  const bool is_false = flag_value == "false" || flag_value == "0";
  if (!is_true && !is_false) {
    return false;
  }

  hook(is_true, argv_position);
  return true;
}
96
// Overload for hooks that take a std::string. The flag text is passed to
// `hook` verbatim, with no stream extraction, so embedded whitespace and
// empty values are preserved. This overload never reports a parse failure.
template <typename T>
bool ParseFlag(const std::string& flag_value, int argv_position,
               const std::function<void(const std::string&, int)>& hook) {
  constexpr bool kParsedOk = true;
  hook(flag_value, argv_position);
  return kParsedOk;
}
103 } // namespace
104
105 #define CONSTRUCTOR_IMPLEMENTATION(flag_T, default_value_T, flag_enum_val) \
106 Flag::Flag(const char* name, \
107 const std::function<void(const flag_T& /*flag_val*/, \
108 int /*argv_position*/)>& hook, \
109 default_value_T default_value, const std::string& usage_text, \
110 FlagType flag_type) \
111 : name_(name), \
112 type_(flag_enum_val), \
113 value_hook_([hook](const std::string& flag_value, int argv_position) { \
114 return ParseFlag<flag_T>(flag_value, argv_position, hook); \
115 }), \
116 default_for_display_(ToString<default_value_T>(default_value)), \
117 usage_text_(usage_text), \
118 flag_type_(flag_type) {}
119
CONSTRUCTOR_IMPLEMENTATION(int32_t,int32_t,TYPE_INT32)120 CONSTRUCTOR_IMPLEMENTATION(int32_t, int32_t, TYPE_INT32)
121 CONSTRUCTOR_IMPLEMENTATION(int64_t, int64_t, TYPE_INT64)
122 CONSTRUCTOR_IMPLEMENTATION(float, float, TYPE_FLOAT)
123 CONSTRUCTOR_IMPLEMENTATION(bool, bool, TYPE_BOOL)
124 CONSTRUCTOR_IMPLEMENTATION(std::string, const std::string&, TYPE_STRING)
125
126 #undef CONSTRUCTOR_IMPLEMENTATION
127
128 bool Flag::Parse(const std::string& arg, int argv_position,
129 bool* value_parsing_ok) const {
130 return ParseFlag(
131 arg, argv_position, name_, flag_type_ == kPositional,
132 [&](const std::string& read_value, int argv_position) {
133 return value_hook_(read_value, argv_position);
134 },
135 value_parsing_ok);
136 }
137
GetTypeName() const138 std::string Flag::GetTypeName() const {
139 switch (type_) {
140 case TYPE_INT32:
141 return "int32";
142 case TYPE_INT64:
143 return "int64";
144 case TYPE_FLOAT:
145 return "float";
146 case TYPE_BOOL:
147 return "bool";
148 case TYPE_STRING:
149 return "string";
150 }
151
152 return "unknown";
153 }
154
Parse(int * argc,const char ** argv,const std::vector<Flag> & flag_list)155 /*static*/ bool Flags::Parse(int* argc, const char** argv,
156 const std::vector<Flag>& flag_list) {
157 bool result = true;
158 std::vector<bool> unknown_argvs(*argc, true);
159 // Record the list of flags that have been processed. key is the flag's name
160 // and the value is the corresponding argv index if there's one, or -1 when
161 // the argv list doesn't contain this flag.
162 std::unordered_map<std::string, int> processed_flags;
163
164 // Stores indexes of flag_list in a sorted order.
165 std::vector<int> sorted_idx(flag_list.size());
166 std::iota(std::begin(sorted_idx), std::end(sorted_idx), 0);
167 std::sort(sorted_idx.begin(), sorted_idx.end(), [&flag_list](int a, int b) {
168 return flag_list[a].GetFlagType() < flag_list[b].GetFlagType();
169 });
170 int positional_count = 0;
171
172 for (int idx = 0; idx < sorted_idx.size(); ++idx) {
173 const Flag& flag = flag_list[sorted_idx[idx]];
174
175 const auto it = processed_flags.find(flag.name_);
176 if (it != processed_flags.end()) {
177 #ifndef NDEBUG
178 // Only log this in debug builds.
179 TFLITE_LOG(WARN) << "Duplicate flags: " << flag.name_;
180 #endif
181 if (it->second != -1) {
182 bool value_parsing_ok;
183 flag.Parse(argv[it->second], it->second, &value_parsing_ok);
184 if (!value_parsing_ok) {
185 TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_
186 << "' against argv '" << argv[it->second] << "'";
187 result = false;
188 }
189 continue;
190 } else if (flag.flag_type_ == Flag::kRequired) {
191 TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_;
192 // If the required flag isn't found, we immediately stop the whole flag
193 // parsing.
194 result = false;
195 break;
196 }
197 }
198
199 // Parses positional flags.
200 if (flag.flag_type_ == Flag::kPositional) {
201 if (++positional_count >= *argc) {
202 TFLITE_LOG(ERROR) << "Too few command line arguments.";
203 return false;
204 }
205 bool value_parsing_ok;
206 flag.Parse(argv[positional_count], positional_count, &value_parsing_ok);
207 if (!value_parsing_ok) {
208 TFLITE_LOG(ERROR) << "Failed to parse positional flag: " << flag.name_;
209 return false;
210 }
211 unknown_argvs[positional_count] = false;
212 processed_flags[flag.name_] = positional_count;
213 continue;
214 }
215
216 // Parse other flags.
217 bool was_found = false;
218 for (int i = positional_count + 1; i < *argc; ++i) {
219 if (!unknown_argvs[i]) continue;
220 bool value_parsing_ok;
221 was_found = flag.Parse(argv[i], i, &value_parsing_ok);
222 if (!value_parsing_ok) {
223 TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_
224 << "' against argv '" << argv[i] << "'";
225 result = false;
226 }
227 if (was_found) {
228 unknown_argvs[i] = false;
229 processed_flags[flag.name_] = i;
230 break;
231 }
232 }
233
234 // If the flag is found from the argv (i.e. the flag name appears in argv),
235 // continue to the next flag parsing.
236 if (was_found) continue;
237
238 // The flag isn't found, do some bookkeeping work.
239 processed_flags[flag.name_] = -1;
240 if (flag.flag_type_ == Flag::kRequired) {
241 TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_;
242 result = false;
243 // If the required flag isn't found, we immediately stop the whole flag
244 // parsing by breaking the outer-loop (i.e. the 'sorted_idx'-iteration
245 // loop).
246 break;
247 }
248 }
249
250 int dst = 1; // Skip argv[0]
251 for (int i = 1; i < *argc; ++i) {
252 if (unknown_argvs[i]) {
253 argv[dst++] = argv[i];
254 }
255 }
256 *argc = dst;
257 return result && (*argc < 2 || std::strcmp(argv[1], "--help") != 0);
258 }
259
Usage(const std::string & cmdline,const std::vector<Flag> & flag_list)260 /*static*/ std::string Flags::Usage(const std::string& cmdline,
261 const std::vector<Flag>& flag_list) {
262 // Stores indexes of flag_list in a sorted order.
263 std::vector<int> sorted_idx(flag_list.size());
264 std::iota(std::begin(sorted_idx), std::end(sorted_idx), 0);
265 std::stable_sort(
266 sorted_idx.begin(), sorted_idx.end(), [&flag_list](int a, int b) {
267 return flag_list[a].GetFlagType() < flag_list[b].GetFlagType();
268 });
269 // Counts number of positional flags will be shown.
270 int positional_count = 0;
271 std::ostringstream usage_text;
272 usage_text << "usage: " << cmdline;
273 // Prints usage for positional flag.
274 for (int i = 0; i < sorted_idx.size(); ++i) {
275 const Flag& flag = flag_list[sorted_idx[i]];
276 if (flag.flag_type_ == Flag::kPositional) {
277 positional_count++;
278 usage_text << " <" << flag.name_ << ">";
279 } else {
280 usage_text << " <flags>";
281 break;
282 }
283 }
284 usage_text << "\n";
285
286 // Finds the max number of chars of the name column in the usage message.
287 int max_name_width = 0;
288 std::vector<std::string> name_column(flag_list.size());
289 for (int i = 0; i < sorted_idx.size(); ++i) {
290 const Flag& flag = flag_list[sorted_idx[i]];
291 if (flag.flag_type_ != Flag::kPositional) {
292 name_column[i] += "--";
293 name_column[i] += flag.name_;
294 name_column[i] += "=";
295 name_column[i] += flag.default_for_display_;
296 } else {
297 name_column[i] += flag.name_;
298 }
299 if (name_column[i].size() > max_name_width) {
300 max_name_width = name_column[i].size();
301 }
302 }
303
304 if (positional_count > 0) {
305 usage_text << "Where:\n";
306 }
307 for (int i = 0; i < sorted_idx.size(); ++i) {
308 const Flag& flag = flag_list[sorted_idx[i]];
309 if (i == positional_count) {
310 usage_text << "Flags:\n";
311 }
312 auto type_name = flag.GetTypeName();
313 usage_text << "\t";
314 usage_text << std::left << std::setw(max_name_width) << name_column[i];
315 usage_text << "\t" << type_name << "\t";
316 usage_text << (flag.flag_type_ != Flag::kOptional ? "required"
317 : "optional");
318 usage_text << "\t" << flag.usage_text_ << "\n";
319 }
320 return usage_text.str();
321 }
322
ArgsToString(int argc,const char ** argv)323 /*static*/ std::string Flags::ArgsToString(int argc, const char** argv) {
324 std::string args;
325 for (int i = 1; i < argc; ++i) {
326 args.append(argv[i]);
327 if (i != argc - 1) args.append(" ");
328 }
329 return args;
330 }
331
332 } // namespace tflite
333