1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <cerrno>
#include <cinttypes>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <string>
#include <vector>

#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
26
27 namespace tensorflow {
28 namespace {
29
ParseStringFlag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,const std::function<bool (string)> & hook,bool * value_parsing_ok)30 bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
31 const std::function<bool(string)>& hook,
32 bool* value_parsing_ok) {
33 *value_parsing_ok = true;
34 if (str_util::ConsumePrefix(&arg, "--") &&
35 str_util::ConsumePrefix(&arg, flag) &&
36 str_util::ConsumePrefix(&arg, "=")) {
37 *value_parsing_ok = hook(string(arg));
38 return true;
39 }
40
41 return false;
42 }
43
ParseInt32Flag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,const std::function<bool (int32)> & hook,bool * value_parsing_ok)44 bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
45 const std::function<bool(int32)>& hook,
46 bool* value_parsing_ok) {
47 *value_parsing_ok = true;
48 if (str_util::ConsumePrefix(&arg, "--") &&
49 str_util::ConsumePrefix(&arg, flag) &&
50 str_util::ConsumePrefix(&arg, "=")) {
51 char extra;
52 int32 parsed_int32;
53 if (sscanf(arg.data(), "%d%c", &parsed_int32, &extra) != 1) {
54 LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
55 << ".";
56 *value_parsing_ok = false;
57 } else {
58 *value_parsing_ok = hook(parsed_int32);
59 }
60 return true;
61 }
62
63 return false;
64 }
65
ParseInt64Flag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,const std::function<bool (int64)> & hook,bool * value_parsing_ok)66 bool ParseInt64Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
67 const std::function<bool(int64)>& hook,
68 bool* value_parsing_ok) {
69 *value_parsing_ok = true;
70 if (str_util::ConsumePrefix(&arg, "--") &&
71 str_util::ConsumePrefix(&arg, flag) &&
72 str_util::ConsumePrefix(&arg, "=")) {
73 char extra;
74 int64_t parsed_int64;
75 if (sscanf(arg.data(), "%" SCNd64 "%c", &parsed_int64, &extra) != 1) {
76 LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
77 << ".";
78 *value_parsing_ok = false;
79 } else {
80 *value_parsing_ok = hook(parsed_int64);
81 }
82 return true;
83 }
84
85 return false;
86 }
87
ParseBoolFlag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,const std::function<bool (bool)> & hook,bool * value_parsing_ok)88 bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
89 const std::function<bool(bool)>& hook,
90 bool* value_parsing_ok) {
91 *value_parsing_ok = true;
92 if (str_util::ConsumePrefix(&arg, "--") &&
93 str_util::ConsumePrefix(&arg, flag)) {
94 if (arg.empty()) {
95 *value_parsing_ok = hook(true);
96 return true;
97 }
98
99 if (arg == "=true") {
100 *value_parsing_ok = hook(true);
101 return true;
102 } else if (arg == "=false") {
103 *value_parsing_ok = hook(false);
104 return true;
105 } else {
106 LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
107 << ".";
108 *value_parsing_ok = false;
109 return true;
110 }
111 }
112
113 return false;
114 }
115
ParseFloatFlag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,const std::function<bool (float)> & hook,bool * value_parsing_ok)116 bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
117 const std::function<bool(float)>& hook,
118 bool* value_parsing_ok) {
119 *value_parsing_ok = true;
120 if (str_util::ConsumePrefix(&arg, "--") &&
121 str_util::ConsumePrefix(&arg, flag) &&
122 str_util::ConsumePrefix(&arg, "=")) {
123 char extra;
124 float parsed_float;
125 if (sscanf(arg.data(), "%f%c", &parsed_float, &extra) != 1) {
126 LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
127 << ".";
128 *value_parsing_ok = false;
129 } else {
130 *value_parsing_ok = hook(parsed_float);
131 }
132 return true;
133 }
134
135 return false;
136 }
137
138 } // namespace
139
Flag(const char * name,tensorflow::int32 * dst,const string & usage_text)140 Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text)
141 : name_(name),
142 type_(TYPE_INT32),
143 int32_hook_([dst](int32 value) {
144 *dst = value;
145 return true;
146 }),
147 int32_default_for_display_(*dst),
148 usage_text_(usage_text) {}
149
Flag(const char * name,tensorflow::int64 * dst,const string & usage_text)150 Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text)
151 : name_(name),
152 type_(TYPE_INT64),
153 int64_hook_([dst](int64 value) {
154 *dst = value;
155 return true;
156 }),
157 int64_default_for_display_(*dst),
158 usage_text_(usage_text) {}
159
Flag(const char * name,float * dst,const string & usage_text)160 Flag::Flag(const char* name, float* dst, const string& usage_text)
161 : name_(name),
162 type_(TYPE_FLOAT),
163 float_hook_([dst](float value) {
164 *dst = value;
165 return true;
166 }),
167 float_default_for_display_(*dst),
168 usage_text_(usage_text) {}
169
Flag(const char * name,bool * dst,const string & usage_text)170 Flag::Flag(const char* name, bool* dst, const string& usage_text)
171 : name_(name),
172 type_(TYPE_BOOL),
173 bool_hook_([dst](bool value) {
174 *dst = value;
175 return true;
176 }),
177 bool_default_for_display_(*dst),
178 usage_text_(usage_text) {}
179
Flag(const char * name,string * dst,const string & usage_text)180 Flag::Flag(const char* name, string* dst, const string& usage_text)
181 : name_(name),
182 type_(TYPE_STRING),
183 string_hook_([dst](string value) {
184 *dst = std::move(value);
185 return true;
186 }),
187 string_default_for_display_(*dst),
188 usage_text_(usage_text) {}
189
Flag(const char * name,std::function<bool (int32)> int32_hook,int32 default_value_for_display,const string & usage_text)190 Flag::Flag(const char* name, std::function<bool(int32)> int32_hook,
191 int32 default_value_for_display, const string& usage_text)
192 : name_(name),
193 type_(TYPE_INT32),
194 int32_hook_(std::move(int32_hook)),
195 int32_default_for_display_(default_value_for_display),
196 usage_text_(usage_text) {}
197
Flag(const char * name,std::function<bool (int64)> int64_hook,int64 default_value_for_display,const string & usage_text)198 Flag::Flag(const char* name, std::function<bool(int64)> int64_hook,
199 int64 default_value_for_display, const string& usage_text)
200 : name_(name),
201 type_(TYPE_INT64),
202 int64_hook_(std::move(int64_hook)),
203 int64_default_for_display_(default_value_for_display),
204 usage_text_(usage_text) {}
205
Flag(const char * name,std::function<bool (float)> float_hook,float default_value_for_display,const string & usage_text)206 Flag::Flag(const char* name, std::function<bool(float)> float_hook,
207 float default_value_for_display, const string& usage_text)
208 : name_(name),
209 type_(TYPE_FLOAT),
210 float_hook_(std::move(float_hook)),
211 float_default_for_display_(default_value_for_display),
212 usage_text_(usage_text) {}
213
Flag(const char * name,std::function<bool (bool)> bool_hook,bool default_value_for_display,const string & usage_text)214 Flag::Flag(const char* name, std::function<bool(bool)> bool_hook,
215 bool default_value_for_display, const string& usage_text)
216 : name_(name),
217 type_(TYPE_BOOL),
218 bool_hook_(std::move(bool_hook)),
219 bool_default_for_display_(default_value_for_display),
220 usage_text_(usage_text) {}
221
Flag(const char * name,std::function<bool (string)> string_hook,string default_value_for_display,const string & usage_text)222 Flag::Flag(const char* name, std::function<bool(string)> string_hook,
223 string default_value_for_display, const string& usage_text)
224 : name_(name),
225 type_(TYPE_STRING),
226 string_hook_(std::move(string_hook)),
227 string_default_for_display_(std::move(default_value_for_display)),
228 usage_text_(usage_text) {}
229
Parse(string arg,bool * value_parsing_ok) const230 bool Flag::Parse(string arg, bool* value_parsing_ok) const {
231 bool result = false;
232 if (type_ == TYPE_INT32) {
233 result = ParseInt32Flag(arg, name_, int32_hook_, value_parsing_ok);
234 } else if (type_ == TYPE_INT64) {
235 result = ParseInt64Flag(arg, name_, int64_hook_, value_parsing_ok);
236 } else if (type_ == TYPE_BOOL) {
237 result = ParseBoolFlag(arg, name_, bool_hook_, value_parsing_ok);
238 } else if (type_ == TYPE_STRING) {
239 result = ParseStringFlag(arg, name_, string_hook_, value_parsing_ok);
240 } else if (type_ == TYPE_FLOAT) {
241 result = ParseFloatFlag(arg, name_, float_hook_, value_parsing_ok);
242 }
243 return result;
244 }
245
Parse(int * argc,char ** argv,const std::vector<Flag> & flag_list)246 /*static*/ bool Flags::Parse(int* argc, char** argv,
247 const std::vector<Flag>& flag_list) {
248 bool result = true;
249 std::vector<char*> unknown_flags;
250 for (int i = 1; i < *argc; ++i) {
251 if (string(argv[i]) == "--") {
252 while (i < *argc) {
253 unknown_flags.push_back(argv[i]);
254 ++i;
255 }
256 break;
257 }
258
259 bool was_found = false;
260 for (const Flag& flag : flag_list) {
261 bool value_parsing_ok;
262 was_found = flag.Parse(argv[i], &value_parsing_ok);
263 if (!value_parsing_ok) {
264 result = false;
265 }
266 if (was_found) {
267 break;
268 }
269 }
270 if (!was_found) {
271 unknown_flags.push_back(argv[i]);
272 }
273 }
274 // Passthrough any extra flags.
275 int dst = 1; // Skip argv[0]
276 for (char* f : unknown_flags) {
277 argv[dst++] = f;
278 }
279 argv[dst++] = nullptr;
280 *argc = unknown_flags.size() + 1;
281 return result && (*argc < 2 || strcmp(argv[1], "--help") != 0);
282 }
283
// Builds help text for `cmdline`.
// - It starts with a "usage:" line, followed by a "Flags:" heading when
//   `flag_list` is non-empty.
// - Each flag then gets one tab-separated line: its "--name=default" form
//   (left-aligned to 33 columns), its type name, and its usage text.
/*static*/ string Flags::Usage(const string& cmdline,
                               const std::vector<Flag>& flag_list) {
  string usage_text;
  strings::Appendf(&usage_text,
                   flag_list.empty() ? "usage: %s\n" : "usage: %s\nFlags:\n",
                   cmdline.c_str());

  for (const Flag& flag : flag_list) {
    const char* const flag_name = flag.name_.c_str();
    const char* type_name = "";
    string flag_string;
    switch (flag.type_) {
      case Flag::TYPE_INT32:
        type_name = "int32";
        flag_string = strings::Printf("--%s=%d", flag_name,
                                      flag.int32_default_for_display_);
        break;
      case Flag::TYPE_INT64:
        type_name = "int64";
        flag_string = strings::Printf(
            "--%s=%lld", flag_name,
            static_cast<long long>(flag.int64_default_for_display_));
        break;
      case Flag::TYPE_BOOL:
        type_name = "bool";
        flag_string =
            strings::Printf("--%s=%s", flag_name,
                            flag.bool_default_for_display_ ? "true" : "false");
        break;
      case Flag::TYPE_STRING:
        type_name = "string";
        flag_string = strings::Printf("--%s=\"%s\"", flag_name,
                                      flag.string_default_for_display_.c_str());
        break;
      case Flag::TYPE_FLOAT:
        type_name = "float";
        flag_string = strings::Printf("--%s=%f", flag_name,
                                      flag.float_default_for_display_);
        break;
      default:
        // Unknown type: emit an empty flag string and type name.
        break;
    }
    strings::Appendf(&usage_text, "\t%-33s\t%s\t%s\n", flag_string.c_str(),
                     type_name, flag.usage_text_.c_str());
  }
  return usage_text;
}
323
324 } // namespace tensorflow
325