• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <cerrno>
#include <cinttypes>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <string>
#include <vector>
20 
21 #include "tensorflow/core/lib/core/stringpiece.h"
22 #include "tensorflow/core/lib/strings/str_util.h"
23 #include "tensorflow/core/lib/strings/stringprintf.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/util/command_line_flags.h"
26 
27 namespace tensorflow {
28 namespace {
29 
ParseStringFlag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,const std::function<bool (string)> & hook,bool * value_parsing_ok)30 bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
31                      const std::function<bool(string)>& hook,
32                      bool* value_parsing_ok) {
33   *value_parsing_ok = true;
34   if (absl::ConsumePrefix(&arg, "--") && absl::ConsumePrefix(&arg, flag) &&
35       absl::ConsumePrefix(&arg, "=")) {
36     *value_parsing_ok = hook(string(arg));
37     return true;
38   }
39 
40   return false;
41 }
42 
ParseInt32Flag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,const std::function<bool (int32)> & hook,bool * value_parsing_ok)43 bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
44                     const std::function<bool(int32)>& hook,
45                     bool* value_parsing_ok) {
46   *value_parsing_ok = true;
47   if (absl::ConsumePrefix(&arg, "--") && absl::ConsumePrefix(&arg, flag) &&
48       absl::ConsumePrefix(&arg, "=")) {
49     char extra;
50     int32 parsed_int32;
51     if (sscanf(arg.data(), "%d%c", &parsed_int32, &extra) != 1) {
52       LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
53                  << ".";
54       *value_parsing_ok = false;
55     } else {
56       *value_parsing_ok = hook(parsed_int32);
57     }
58     return true;
59   }
60 
61   return false;
62 }
63 
ParseInt64Flag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,const std::function<bool (int64)> & hook,bool * value_parsing_ok)64 bool ParseInt64Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
65                     const std::function<bool(int64)>& hook,
66                     bool* value_parsing_ok) {
67   *value_parsing_ok = true;
68   if (absl::ConsumePrefix(&arg, "--") && absl::ConsumePrefix(&arg, flag) &&
69       absl::ConsumePrefix(&arg, "=")) {
70     char extra;
71     int64_t parsed_int64;
72     if (sscanf(arg.data(), "%" SCNd64 "%c", &parsed_int64, &extra) != 1) {
73       LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
74                  << ".";
75       *value_parsing_ok = false;
76     } else {
77       *value_parsing_ok = hook(parsed_int64);
78     }
79     return true;
80   }
81 
82   return false;
83 }
84 
ParseBoolFlag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,const std::function<bool (bool)> & hook,bool * value_parsing_ok)85 bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
86                    const std::function<bool(bool)>& hook,
87                    bool* value_parsing_ok) {
88   *value_parsing_ok = true;
89   if (absl::ConsumePrefix(&arg, "--") && absl::ConsumePrefix(&arg, flag)) {
90     if (arg.empty()) {
91       *value_parsing_ok = hook(true);
92       return true;
93     }
94 
95     if (arg == "=true") {
96       *value_parsing_ok = hook(true);
97       return true;
98     } else if (arg == "=false") {
99       *value_parsing_ok = hook(false);
100       return true;
101     } else {
102       LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
103                  << ".";
104       *value_parsing_ok = false;
105       return true;
106     }
107   }
108 
109   return false;
110 }
111 
ParseFloatFlag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,const std::function<bool (float)> & hook,bool * value_parsing_ok)112 bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
113                     const std::function<bool(float)>& hook,
114                     bool* value_parsing_ok) {
115   *value_parsing_ok = true;
116   if (absl::ConsumePrefix(&arg, "--") && absl::ConsumePrefix(&arg, flag) &&
117       absl::ConsumePrefix(&arg, "=")) {
118     char extra;
119     float parsed_float;
120     if (sscanf(arg.data(), "%f%c", &parsed_float, &extra) != 1) {
121       LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
122                  << ".";
123       *value_parsing_ok = false;
124     } else {
125       *value_parsing_ok = hook(parsed_float);
126     }
127     return true;
128   }
129 
130   return false;
131 }
132 
133 }  // namespace
134 
Flag(const char * name,tensorflow::int32 * dst,const string & usage_text,bool * dst_updated)135 Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text,
136            bool* dst_updated)
137     : name_(name),
138       type_(TYPE_INT32),
139       int32_hook_([dst, dst_updated](int32 value) {
140         *dst = value;
141         if (dst_updated) *dst_updated = true;
142         return true;
143       }),
144       int32_default_for_display_(*dst),
145       usage_text_(usage_text) {}
146 
Flag(const char * name,tensorflow::int64 * dst,const string & usage_text,bool * dst_updated)147 Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text,
148            bool* dst_updated)
149     : name_(name),
150       type_(TYPE_INT64),
151       int64_hook_([dst, dst_updated](int64 value) {
152         *dst = value;
153         if (dst_updated) *dst_updated = true;
154         return true;
155       }),
156       int64_default_for_display_(*dst),
157       usage_text_(usage_text) {}
158 
Flag(const char * name,float * dst,const string & usage_text,bool * dst_updated)159 Flag::Flag(const char* name, float* dst, const string& usage_text,
160            bool* dst_updated)
161     : name_(name),
162       type_(TYPE_FLOAT),
163       float_hook_([dst, dst_updated](float value) {
164         *dst = value;
165         if (dst_updated) *dst_updated = true;
166         return true;
167       }),
168       float_default_for_display_(*dst),
169       usage_text_(usage_text) {}
170 
Flag(const char * name,bool * dst,const string & usage_text,bool * dst_updated)171 Flag::Flag(const char* name, bool* dst, const string& usage_text,
172            bool* dst_updated)
173     : name_(name),
174       type_(TYPE_BOOL),
175       bool_hook_([dst, dst_updated](bool value) {
176         *dst = value;
177         if (dst_updated) *dst_updated = true;
178         return true;
179       }),
180       bool_default_for_display_(*dst),
181       usage_text_(usage_text) {}
182 
Flag(const char * name,string * dst,const string & usage_text,bool * dst_updated)183 Flag::Flag(const char* name, string* dst, const string& usage_text,
184            bool* dst_updated)
185     : name_(name),
186       type_(TYPE_STRING),
187       string_hook_([dst, dst_updated](string value) {
188         *dst = std::move(value);
189         if (dst_updated) *dst_updated = true;
190         return true;
191       }),
192       string_default_for_display_(*dst),
193       usage_text_(usage_text) {}
194 
Flag(const char * name,std::function<bool (int32)> int32_hook,int32 default_value_for_display,const string & usage_text)195 Flag::Flag(const char* name, std::function<bool(int32)> int32_hook,
196            int32 default_value_for_display, const string& usage_text)
197     : name_(name),
198       type_(TYPE_INT32),
199       int32_hook_(std::move(int32_hook)),
200       int32_default_for_display_(default_value_for_display),
201       usage_text_(usage_text) {}
202 
Flag(const char * name,std::function<bool (int64)> int64_hook,int64 default_value_for_display,const string & usage_text)203 Flag::Flag(const char* name, std::function<bool(int64)> int64_hook,
204            int64 default_value_for_display, const string& usage_text)
205     : name_(name),
206       type_(TYPE_INT64),
207       int64_hook_(std::move(int64_hook)),
208       int64_default_for_display_(default_value_for_display),
209       usage_text_(usage_text) {}
210 
Flag(const char * name,std::function<bool (float)> float_hook,float default_value_for_display,const string & usage_text)211 Flag::Flag(const char* name, std::function<bool(float)> float_hook,
212            float default_value_for_display, const string& usage_text)
213     : name_(name),
214       type_(TYPE_FLOAT),
215       float_hook_(std::move(float_hook)),
216       float_default_for_display_(default_value_for_display),
217       usage_text_(usage_text) {}
218 
Flag(const char * name,std::function<bool (bool)> bool_hook,bool default_value_for_display,const string & usage_text)219 Flag::Flag(const char* name, std::function<bool(bool)> bool_hook,
220            bool default_value_for_display, const string& usage_text)
221     : name_(name),
222       type_(TYPE_BOOL),
223       bool_hook_(std::move(bool_hook)),
224       bool_default_for_display_(default_value_for_display),
225       usage_text_(usage_text) {}
226 
Flag(const char * name,std::function<bool (string)> string_hook,string default_value_for_display,const string & usage_text)227 Flag::Flag(const char* name, std::function<bool(string)> string_hook,
228            string default_value_for_display, const string& usage_text)
229     : name_(name),
230       type_(TYPE_STRING),
231       string_hook_(std::move(string_hook)),
232       string_default_for_display_(std::move(default_value_for_display)),
233       usage_text_(usage_text) {}
234 
Parse(string arg,bool * value_parsing_ok) const235 bool Flag::Parse(string arg, bool* value_parsing_ok) const {
236   bool result = false;
237   if (type_ == TYPE_INT32) {
238     result = ParseInt32Flag(arg, name_, int32_hook_, value_parsing_ok);
239   } else if (type_ == TYPE_INT64) {
240     result = ParseInt64Flag(arg, name_, int64_hook_, value_parsing_ok);
241   } else if (type_ == TYPE_BOOL) {
242     result = ParseBoolFlag(arg, name_, bool_hook_, value_parsing_ok);
243   } else if (type_ == TYPE_STRING) {
244     result = ParseStringFlag(arg, name_, string_hook_, value_parsing_ok);
245   } else if (type_ == TYPE_FLOAT) {
246     result = ParseFloatFlag(arg, name_, float_hook_, value_parsing_ok);
247   }
248   return result;
249 }
250 
Parse(int * argc,char ** argv,const std::vector<Flag> & flag_list)251 /*static*/ bool Flags::Parse(int* argc, char** argv,
252                              const std::vector<Flag>& flag_list) {
253   bool result = true;
254   std::vector<char*> unknown_flags;
255   for (int i = 1; i < *argc; ++i) {
256     if (string(argv[i]) == "--") {
257       while (i < *argc) {
258         unknown_flags.push_back(argv[i]);
259         ++i;
260       }
261       break;
262     }
263 
264     bool was_found = false;
265     for (const Flag& flag : flag_list) {
266       bool value_parsing_ok;
267       was_found = flag.Parse(argv[i], &value_parsing_ok);
268       if (!value_parsing_ok) {
269         result = false;
270       }
271       if (was_found) {
272         break;
273       }
274     }
275     if (!was_found) {
276       unknown_flags.push_back(argv[i]);
277     }
278   }
279   // Passthrough any extra flags.
280   int dst = 1;  // Skip argv[0]
281   for (char* f : unknown_flags) {
282     argv[dst++] = f;
283   }
284   argv[dst++] = nullptr;
285   *argc = unknown_flags.size() + 1;
286   return result && (*argc < 2 || strcmp(argv[1], "--help") != 0);
287 }
288 
// Builds a human-readable usage message for `cmdline`.
//
// The message starts with a "usage:" line. When `flag_list` is non-empty,
// a "Flags:" header follows, then one tab-separated line per flag with its
// "--name=default" form, its type name and its usage text.
/*static*/ string Flags::Usage(const string& cmdline,
                               const std::vector<Flag>& flag_list) {
  string usage_text;
  if (flag_list.empty()) {
    strings::Appendf(&usage_text, "usage: %s\n", cmdline.c_str());
  } else {
    strings::Appendf(&usage_text, "usage: %s\nFlags:\n", cmdline.c_str());
  }

  for (const Flag& flag : flag_list) {
    const char* const flag_name = flag.name_.c_str();
    const char* type_name = "";
    string flag_string;
    switch (flag.type_) {
      case Flag::TYPE_INT32:
        type_name = "int32";
        flag_string = strings::Printf("--%s=%d", flag_name,
                                      flag.int32_default_for_display_);
        break;
      case Flag::TYPE_INT64:
        type_name = "int64";
        flag_string = strings::Printf(
            "--%s=%lld", flag_name,
            static_cast<long long>(flag.int64_default_for_display_));
        break;
      case Flag::TYPE_BOOL:
        type_name = "bool";
        flag_string =
            strings::Printf("--%s=%s", flag_name,
                            flag.bool_default_for_display_ ? "true" : "false");
        break;
      case Flag::TYPE_STRING:
        type_name = "string";
        flag_string = strings::Printf("--%s=\"%s\"", flag_name,
                                      flag.string_default_for_display_.c_str());
        break;
      case Flag::TYPE_FLOAT:
        type_name = "float";
        flag_string = strings::Printf("--%s=%f", flag_name,
                                      flag.float_default_for_display_);
        break;
      default:
        // Unknown type: an empty flag string and type name are emitted.
        break;
    }
    strings::Appendf(&usage_text, "\t%-33s\t%s\t%s\n", flag_string.c_str(),
                     type_name, flag.usage_text_.c_str());
  }
  return usage_text;
}
328 
329 }  // namespace tensorflow
330