• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5     http://www.apache.org/licenses/LICENSE-2.0
6 Unless required by applicable law or agreed to in writing, software
7 distributed under the License is distributed on an "AS IS" BASIS,
8 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 See the License for the specific language governing permissions and
10 limitations under the License.
11 ==============================================================================*/
12 
13 #include "tensorflow/lite/tools/command_line_flags.h"
14 
15 #include <algorithm>
16 #include <cstring>
17 #include <functional>
18 #include <iomanip>
19 #include <numeric>
20 #include <sstream>
21 #include <string>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/strings/match.h"
27 #include "tensorflow/lite/tools/logging.h"
28 
29 namespace tflite {
30 namespace {
31 
// Renders `val` as text by streaming it through `operator<<`.
template <typename T>
std::string ToString(T val) {
  std::ostringstream formatted;
  formatted << val;
  return formatted.str();
}

// Booleans are rendered as the words "true" / "false" rather than 1 / 0.
template <>
std::string ToString(bool val) {
  if (val) {
    return "true";
  }
  return "false";
}

// Strings are returned unchanged, without a round trip through a stream.
template <>
std::string ToString(const std::string& val) {
  return std::string(val);
}
48 
ParseFlag(const std::string & arg,int argv_position,const std::string & flag,bool positional,const std::function<bool (const std::string &,int argv_position)> & parse_func,bool * value_parsing_ok)49 bool ParseFlag(const std::string& arg, int argv_position,
50                const std::string& flag, bool positional,
51                const std::function<bool(const std::string&, int argv_position)>&
52                    parse_func,
53                bool* value_parsing_ok) {
54   if (positional) {
55     *value_parsing_ok = parse_func(arg, argv_position);
56     return true;
57   }
58   *value_parsing_ok = true;
59   std::string flag_prefix = "--" + flag + "=";
60   if (!absl::StartsWith(arg, flag_prefix)) {
61     return false;
62   }
63   bool has_value = arg.size() >= flag_prefix.size();
64   *value_parsing_ok = has_value;
65   if (has_value) {
66     *value_parsing_ok =
67         parse_func(arg.substr(flag_prefix.size()), argv_position);
68   }
69   return true;
70 }
71 
// Parses `flag_value` into a `T` using stream extraction and, on success,
// passes the parsed value and `argv_position` to `hook`.
//
// Returns false without calling `hook` when extraction fails, e.g. when the
// text is empty, is not a number, or is out of range for `T`. Characters that
// follow a successfully extracted value are not inspected.
template <typename T>
bool ParseFlag(const std::string& flag_value, int argv_position,
               const std::function<void(const T&, int)>& hook) {
  std::istringstream stream(flag_value);
  T read_value{};
  stream >> read_value;
  // Check the failure bits directly: a failed extraction can also reach
  // end-of-input (empty text, overflowing literal), so the EOF state must not
  // be allowed to mask the failure.
  if (stream.fail()) {
    return false;
  }
  hook(read_value, argv_position);
  return true;
}

// Boolean values accept exactly "true", "false", "1" and "0"; anything else
// is rejected without calling `hook`.
template <>
bool ParseFlag<bool>(const std::string& flag_value, int argv_position,
                     const std::function<void(const bool&, int)>& hook) {
  const bool is_true = flag_value == "true" || flag_value == "1";
  const bool is_false = flag_value == "false" || flag_value == "0";
  if (!is_true && !is_false) {
    return false;
  }

  hook(is_true, argv_position);
  return true;
}

// String values are passed to `hook` unchanged (including the empty string),
// so parsing always succeeds.
template <>
bool ParseFlag<std::string>(
    const std::string& flag_value, int argv_position,
    const std::function<void(const std::string&, int)>& hook) {
  hook(flag_value, argv_position);
  return true;
}
103 }  // namespace
104 
105 #define CONSTRUCTOR_IMPLEMENTATION(flag_T, default_value_T, flag_enum_val)     \
106   Flag::Flag(const char* name,                                                 \
107              const std::function<void(const flag_T& /*flag_val*/,              \
108                                       int /*argv_position*/)>& hook,           \
109              default_value_T default_value, const std::string& usage_text,     \
110              FlagType flag_type)                                               \
111       : name_(name),                                                           \
112         type_(flag_enum_val),                                                  \
113         value_hook_([hook](const std::string& flag_value, int argv_position) { \
114           return ParseFlag<flag_T>(flag_value, argv_position, hook);           \
115         }),                                                                    \
116         default_for_display_(ToString<default_value_T>(default_value)),        \
117         usage_text_(usage_text),                                               \
118         flag_type_(flag_type) {}
119 
CONSTRUCTOR_IMPLEMENTATION(int32_t,int32_t,TYPE_INT32)120 CONSTRUCTOR_IMPLEMENTATION(int32_t, int32_t, TYPE_INT32)
121 CONSTRUCTOR_IMPLEMENTATION(int64_t, int64_t, TYPE_INT64)
122 CONSTRUCTOR_IMPLEMENTATION(float, float, TYPE_FLOAT)
123 CONSTRUCTOR_IMPLEMENTATION(bool, bool, TYPE_BOOL)
124 CONSTRUCTOR_IMPLEMENTATION(std::string, const std::string&, TYPE_STRING)
125 
126 #undef CONSTRUCTOR_IMPLEMENTATION
127 
128 bool Flag::Parse(const std::string& arg, int argv_position,
129                  bool* value_parsing_ok) const {
130   return ParseFlag(
131       arg, argv_position, name_, flag_type_ == kPositional,
132       [&](const std::string& read_value, int argv_position) {
133         return value_hook_(read_value, argv_position);
134       },
135       value_parsing_ok);
136 }
137 
GetTypeName() const138 std::string Flag::GetTypeName() const {
139   switch (type_) {
140     case TYPE_INT32:
141       return "int32";
142     case TYPE_INT64:
143       return "int64";
144     case TYPE_FLOAT:
145       return "float";
146     case TYPE_BOOL:
147       return "bool";
148     case TYPE_STRING:
149       return "string";
150   }
151 
152   return "unknown";
153 }
154 
Parse(int * argc,const char ** argv,const std::vector<Flag> & flag_list)155 /*static*/ bool Flags::Parse(int* argc, const char** argv,
156                              const std::vector<Flag>& flag_list) {
157   bool result = true;
158   std::vector<bool> unknown_argvs(*argc, true);
159   // Record the list of flags that have been processed. key is the flag's name
160   // and the value is the corresponding argv index if there's one, or -1 when
161   // the argv list doesn't contain this flag.
162   std::unordered_map<std::string, int> processed_flags;
163 
164   // Stores indexes of flag_list in a sorted order.
165   std::vector<int> sorted_idx(flag_list.size());
166   std::iota(std::begin(sorted_idx), std::end(sorted_idx), 0);
167   std::sort(sorted_idx.begin(), sorted_idx.end(), [&flag_list](int a, int b) {
168     return flag_list[a].GetFlagType() < flag_list[b].GetFlagType();
169   });
170   int positional_count = 0;
171 
172   for (int idx = 0; idx < sorted_idx.size(); ++idx) {
173     const Flag& flag = flag_list[sorted_idx[idx]];
174 
175     const auto it = processed_flags.find(flag.name_);
176     if (it != processed_flags.end()) {
177 #ifndef NDEBUG
178       // Only log this in debug builds.
179       TFLITE_LOG(WARN) << "Duplicate flags: " << flag.name_;
180 #endif
181       if (it->second != -1) {
182         bool value_parsing_ok;
183         flag.Parse(argv[it->second], it->second, &value_parsing_ok);
184         if (!value_parsing_ok) {
185           TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_
186                             << "' against argv '" << argv[it->second] << "'";
187           result = false;
188         }
189         continue;
190       } else if (flag.flag_type_ == Flag::kRequired) {
191         TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_;
192         // If the required flag isn't found, we immediately stop the whole flag
193         // parsing.
194         result = false;
195         break;
196       }
197     }
198 
199     // Parses positional flags.
200     if (flag.flag_type_ == Flag::kPositional) {
201       if (++positional_count >= *argc) {
202         TFLITE_LOG(ERROR) << "Too few command line arguments.";
203         return false;
204       }
205       bool value_parsing_ok;
206       flag.Parse(argv[positional_count], positional_count, &value_parsing_ok);
207       if (!value_parsing_ok) {
208         TFLITE_LOG(ERROR) << "Failed to parse positional flag: " << flag.name_;
209         return false;
210       }
211       unknown_argvs[positional_count] = false;
212       processed_flags[flag.name_] = positional_count;
213       continue;
214     }
215 
216     // Parse other flags.
217     bool was_found = false;
218     for (int i = positional_count + 1; i < *argc; ++i) {
219       if (!unknown_argvs[i]) continue;
220       bool value_parsing_ok;
221       was_found = flag.Parse(argv[i], i, &value_parsing_ok);
222       if (!value_parsing_ok) {
223         TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_
224                           << "' against argv '" << argv[i] << "'";
225         result = false;
226       }
227       if (was_found) {
228         unknown_argvs[i] = false;
229         processed_flags[flag.name_] = i;
230         break;
231       }
232     }
233 
234     // If the flag is found from the argv (i.e. the flag name appears in argv),
235     // continue to the next flag parsing.
236     if (was_found) continue;
237 
238     // The flag isn't found, do some bookkeeping work.
239     processed_flags[flag.name_] = -1;
240     if (flag.flag_type_ == Flag::kRequired) {
241       TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_;
242       result = false;
243       // If the required flag isn't found, we immediately stop the whole flag
244       // parsing by breaking the outer-loop (i.e. the 'sorted_idx'-iteration
245       // loop).
246       break;
247     }
248   }
249 
250   int dst = 1;  // Skip argv[0]
251   for (int i = 1; i < *argc; ++i) {
252     if (unknown_argvs[i]) {
253       argv[dst++] = argv[i];
254     }
255   }
256   *argc = dst;
257   return result && (*argc < 2 || std::strcmp(argv[1], "--help") != 0);
258 }
259 
Usage(const std::string & cmdline,const std::vector<Flag> & flag_list)260 /*static*/ std::string Flags::Usage(const std::string& cmdline,
261                                     const std::vector<Flag>& flag_list) {
262   // Stores indexes of flag_list in a sorted order.
263   std::vector<int> sorted_idx(flag_list.size());
264   std::iota(std::begin(sorted_idx), std::end(sorted_idx), 0);
265   std::stable_sort(
266       sorted_idx.begin(), sorted_idx.end(), [&flag_list](int a, int b) {
267         return flag_list[a].GetFlagType() < flag_list[b].GetFlagType();
268       });
269   // Counts number of positional flags will be shown.
270   int positional_count = 0;
271   std::ostringstream usage_text;
272   usage_text << "usage: " << cmdline;
273   // Prints usage for positional flag.
274   for (int i = 0; i < sorted_idx.size(); ++i) {
275     const Flag& flag = flag_list[sorted_idx[i]];
276     if (flag.flag_type_ == Flag::kPositional) {
277       positional_count++;
278       usage_text << " <" << flag.name_ << ">";
279     } else {
280       usage_text << " <flags>";
281       break;
282     }
283   }
284   usage_text << "\n";
285 
286   // Finds the max number of chars of the name column in the usage message.
287   int max_name_width = 0;
288   std::vector<std::string> name_column(flag_list.size());
289   for (int i = 0; i < sorted_idx.size(); ++i) {
290     const Flag& flag = flag_list[sorted_idx[i]];
291     if (flag.flag_type_ != Flag::kPositional) {
292       name_column[i] += "--";
293       name_column[i] += flag.name_;
294       name_column[i] += "=";
295       name_column[i] += flag.default_for_display_;
296     } else {
297       name_column[i] += flag.name_;
298     }
299     if (name_column[i].size() > max_name_width) {
300       max_name_width = name_column[i].size();
301     }
302   }
303 
304   if (positional_count > 0) {
305     usage_text << "Where:\n";
306   }
307   for (int i = 0; i < sorted_idx.size(); ++i) {
308     const Flag& flag = flag_list[sorted_idx[i]];
309     if (i == positional_count) {
310       usage_text << "Flags:\n";
311     }
312     auto type_name = flag.GetTypeName();
313     usage_text << "\t";
314     usage_text << std::left << std::setw(max_name_width) << name_column[i];
315     usage_text << "\t" << type_name << "\t";
316     usage_text << (flag.flag_type_ != Flag::kOptional ? "required"
317                                                       : "optional");
318     usage_text << "\t" << flag.usage_text_ << "\n";
319   }
320   return usage_text.str();
321 }
322 
ArgsToString(int argc,const char ** argv)323 /*static*/ std::string Flags::ArgsToString(int argc, const char** argv) {
324   std::string args;
325   for (int i = 1; i < argc; ++i) {
326     args.append(argv[i]);
327     if (i != argc - 1) args.append(" ");
328   }
329   return args;
330 }
331 
332 }  // namespace tflite
333