1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <cerrno>
#include <cinttypes>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <string>
#include <vector>

#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
26
27 namespace tensorflow {
28 namespace {
29
ParseStringFlag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,const std::function<bool (string)> & hook,bool * value_parsing_ok)30 bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
31 const std::function<bool(string)>& hook,
32 bool* value_parsing_ok) {
33 *value_parsing_ok = true;
34 if (absl::ConsumePrefix(&arg, "--") && absl::ConsumePrefix(&arg, flag) &&
35 absl::ConsumePrefix(&arg, "=")) {
36 *value_parsing_ok = hook(string(arg));
37 return true;
38 }
39
40 return false;
41 }
42
ParseInt32Flag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,const std::function<bool (int32)> & hook,bool * value_parsing_ok)43 bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
44 const std::function<bool(int32)>& hook,
45 bool* value_parsing_ok) {
46 *value_parsing_ok = true;
47 if (absl::ConsumePrefix(&arg, "--") && absl::ConsumePrefix(&arg, flag) &&
48 absl::ConsumePrefix(&arg, "=")) {
49 char extra;
50 int32 parsed_int32;
51 if (sscanf(arg.data(), "%d%c", &parsed_int32, &extra) != 1) {
52 LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
53 << ".";
54 *value_parsing_ok = false;
55 } else {
56 *value_parsing_ok = hook(parsed_int32);
57 }
58 return true;
59 }
60
61 return false;
62 }
63
ParseInt64Flag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,const std::function<bool (int64)> & hook,bool * value_parsing_ok)64 bool ParseInt64Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
65 const std::function<bool(int64)>& hook,
66 bool* value_parsing_ok) {
67 *value_parsing_ok = true;
68 if (absl::ConsumePrefix(&arg, "--") && absl::ConsumePrefix(&arg, flag) &&
69 absl::ConsumePrefix(&arg, "=")) {
70 char extra;
71 int64_t parsed_int64;
72 if (sscanf(arg.data(), "%" SCNd64 "%c", &parsed_int64, &extra) != 1) {
73 LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
74 << ".";
75 *value_parsing_ok = false;
76 } else {
77 *value_parsing_ok = hook(parsed_int64);
78 }
79 return true;
80 }
81
82 return false;
83 }
84
ParseBoolFlag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,const std::function<bool (bool)> & hook,bool * value_parsing_ok)85 bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
86 const std::function<bool(bool)>& hook,
87 bool* value_parsing_ok) {
88 *value_parsing_ok = true;
89 if (absl::ConsumePrefix(&arg, "--") && absl::ConsumePrefix(&arg, flag)) {
90 if (arg.empty()) {
91 *value_parsing_ok = hook(true);
92 return true;
93 }
94
95 if (arg == "=true") {
96 *value_parsing_ok = hook(true);
97 return true;
98 } else if (arg == "=false") {
99 *value_parsing_ok = hook(false);
100 return true;
101 } else {
102 LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
103 << ".";
104 *value_parsing_ok = false;
105 return true;
106 }
107 }
108
109 return false;
110 }
111
ParseFloatFlag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,const std::function<bool (float)> & hook,bool * value_parsing_ok)112 bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
113 const std::function<bool(float)>& hook,
114 bool* value_parsing_ok) {
115 *value_parsing_ok = true;
116 if (absl::ConsumePrefix(&arg, "--") && absl::ConsumePrefix(&arg, flag) &&
117 absl::ConsumePrefix(&arg, "=")) {
118 char extra;
119 float parsed_float;
120 if (sscanf(arg.data(), "%f%c", &parsed_float, &extra) != 1) {
121 LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
122 << ".";
123 *value_parsing_ok = false;
124 } else {
125 *value_parsing_ok = hook(parsed_float);
126 }
127 return true;
128 }
129
130 return false;
131 }
132
133 } // namespace
134
Flag(const char * name,tensorflow::int32 * dst,const string & usage_text,bool * dst_updated)135 Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text,
136 bool* dst_updated)
137 : name_(name),
138 type_(TYPE_INT32),
139 int32_hook_([dst, dst_updated](int32 value) {
140 *dst = value;
141 if (dst_updated) *dst_updated = true;
142 return true;
143 }),
144 int32_default_for_display_(*dst),
145 usage_text_(usage_text) {}
146
Flag(const char * name,tensorflow::int64 * dst,const string & usage_text,bool * dst_updated)147 Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text,
148 bool* dst_updated)
149 : name_(name),
150 type_(TYPE_INT64),
151 int64_hook_([dst, dst_updated](int64 value) {
152 *dst = value;
153 if (dst_updated) *dst_updated = true;
154 return true;
155 }),
156 int64_default_for_display_(*dst),
157 usage_text_(usage_text) {}
158
Flag(const char * name,float * dst,const string & usage_text,bool * dst_updated)159 Flag::Flag(const char* name, float* dst, const string& usage_text,
160 bool* dst_updated)
161 : name_(name),
162 type_(TYPE_FLOAT),
163 float_hook_([dst, dst_updated](float value) {
164 *dst = value;
165 if (dst_updated) *dst_updated = true;
166 return true;
167 }),
168 float_default_for_display_(*dst),
169 usage_text_(usage_text) {}
170
Flag(const char * name,bool * dst,const string & usage_text,bool * dst_updated)171 Flag::Flag(const char* name, bool* dst, const string& usage_text,
172 bool* dst_updated)
173 : name_(name),
174 type_(TYPE_BOOL),
175 bool_hook_([dst, dst_updated](bool value) {
176 *dst = value;
177 if (dst_updated) *dst_updated = true;
178 return true;
179 }),
180 bool_default_for_display_(*dst),
181 usage_text_(usage_text) {}
182
Flag(const char * name,string * dst,const string & usage_text,bool * dst_updated)183 Flag::Flag(const char* name, string* dst, const string& usage_text,
184 bool* dst_updated)
185 : name_(name),
186 type_(TYPE_STRING),
187 string_hook_([dst, dst_updated](string value) {
188 *dst = std::move(value);
189 if (dst_updated) *dst_updated = true;
190 return true;
191 }),
192 string_default_for_display_(*dst),
193 usage_text_(usage_text) {}
194
Flag(const char * name,std::function<bool (int32)> int32_hook,int32 default_value_for_display,const string & usage_text)195 Flag::Flag(const char* name, std::function<bool(int32)> int32_hook,
196 int32 default_value_for_display, const string& usage_text)
197 : name_(name),
198 type_(TYPE_INT32),
199 int32_hook_(std::move(int32_hook)),
200 int32_default_for_display_(default_value_for_display),
201 usage_text_(usage_text) {}
202
Flag(const char * name,std::function<bool (int64)> int64_hook,int64 default_value_for_display,const string & usage_text)203 Flag::Flag(const char* name, std::function<bool(int64)> int64_hook,
204 int64 default_value_for_display, const string& usage_text)
205 : name_(name),
206 type_(TYPE_INT64),
207 int64_hook_(std::move(int64_hook)),
208 int64_default_for_display_(default_value_for_display),
209 usage_text_(usage_text) {}
210
Flag(const char * name,std::function<bool (float)> float_hook,float default_value_for_display,const string & usage_text)211 Flag::Flag(const char* name, std::function<bool(float)> float_hook,
212 float default_value_for_display, const string& usage_text)
213 : name_(name),
214 type_(TYPE_FLOAT),
215 float_hook_(std::move(float_hook)),
216 float_default_for_display_(default_value_for_display),
217 usage_text_(usage_text) {}
218
Flag(const char * name,std::function<bool (bool)> bool_hook,bool default_value_for_display,const string & usage_text)219 Flag::Flag(const char* name, std::function<bool(bool)> bool_hook,
220 bool default_value_for_display, const string& usage_text)
221 : name_(name),
222 type_(TYPE_BOOL),
223 bool_hook_(std::move(bool_hook)),
224 bool_default_for_display_(default_value_for_display),
225 usage_text_(usage_text) {}
226
Flag(const char * name,std::function<bool (string)> string_hook,string default_value_for_display,const string & usage_text)227 Flag::Flag(const char* name, std::function<bool(string)> string_hook,
228 string default_value_for_display, const string& usage_text)
229 : name_(name),
230 type_(TYPE_STRING),
231 string_hook_(std::move(string_hook)),
232 string_default_for_display_(std::move(default_value_for_display)),
233 usage_text_(usage_text) {}
234
Parse(string arg,bool * value_parsing_ok) const235 bool Flag::Parse(string arg, bool* value_parsing_ok) const {
236 bool result = false;
237 if (type_ == TYPE_INT32) {
238 result = ParseInt32Flag(arg, name_, int32_hook_, value_parsing_ok);
239 } else if (type_ == TYPE_INT64) {
240 result = ParseInt64Flag(arg, name_, int64_hook_, value_parsing_ok);
241 } else if (type_ == TYPE_BOOL) {
242 result = ParseBoolFlag(arg, name_, bool_hook_, value_parsing_ok);
243 } else if (type_ == TYPE_STRING) {
244 result = ParseStringFlag(arg, name_, string_hook_, value_parsing_ok);
245 } else if (type_ == TYPE_FLOAT) {
246 result = ParseFloatFlag(arg, name_, float_hook_, value_parsing_ok);
247 }
248 return result;
249 }
250
Parse(int * argc,char ** argv,const std::vector<Flag> & flag_list)251 /*static*/ bool Flags::Parse(int* argc, char** argv,
252 const std::vector<Flag>& flag_list) {
253 bool result = true;
254 std::vector<char*> unknown_flags;
255 for (int i = 1; i < *argc; ++i) {
256 if (string(argv[i]) == "--") {
257 while (i < *argc) {
258 unknown_flags.push_back(argv[i]);
259 ++i;
260 }
261 break;
262 }
263
264 bool was_found = false;
265 for (const Flag& flag : flag_list) {
266 bool value_parsing_ok;
267 was_found = flag.Parse(argv[i], &value_parsing_ok);
268 if (!value_parsing_ok) {
269 result = false;
270 }
271 if (was_found) {
272 break;
273 }
274 }
275 if (!was_found) {
276 unknown_flags.push_back(argv[i]);
277 }
278 }
279 // Passthrough any extra flags.
280 int dst = 1; // Skip argv[0]
281 for (char* f : unknown_flags) {
282 argv[dst++] = f;
283 }
284 argv[dst++] = nullptr;
285 *argc = unknown_flags.size() + 1;
286 return result && (*argc < 2 || strcmp(argv[1], "--help") != 0);
287 }
288
// Builds human-readable help text. The first line is "usage: <cmdline>".
// When flags exist, it is followed by a "Flags:" heading and one
// tab-separated row per flag: "--name=<default>", the type name, and the
// flag's usage text.
/*static*/ string Flags::Usage(const string& cmdline,
                               const std::vector<Flag>& flag_list) {
  string usage_text;
  if (flag_list.empty()) {
    strings::Appendf(&usage_text, "usage: %s\n", cmdline.c_str());
  } else {
    strings::Appendf(&usage_text, "usage: %s\nFlags:\n", cmdline.c_str());
  }

  for (const Flag& entry : flag_list) {
    const char* const flag_name = entry.name_.c_str();
    const char* kind = "";
    string example;  // first column: "--name=<default value>"
    switch (entry.type_) {
      case Flag::TYPE_INT32:
        kind = "int32";
        example = strings::Printf("--%s=%d", flag_name,
                                  entry.int32_default_for_display_);
        break;
      case Flag::TYPE_INT64:
        kind = "int64";
        // %lld needs a long long regardless of how int64 is typedef'd.
        example = strings::Printf(
            "--%s=%lld", flag_name,
            static_cast<long long>(entry.int64_default_for_display_));
        break;
      case Flag::TYPE_BOOL:
        kind = "bool";
        example = strings::Printf(
            "--%s=%s", flag_name,
            entry.bool_default_for_display_ ? "true" : "false");
        break;
      case Flag::TYPE_STRING:
        kind = "string";
        example = strings::Printf("--%s=\"%s\"", flag_name,
                                  entry.string_default_for_display_.c_str());
        break;
      case Flag::TYPE_FLOAT:
        kind = "float";
        example = strings::Printf("--%s=%f", flag_name,
                                  entry.float_default_for_display_);
        break;
      default:
        break;
    }
    strings::Appendf(&usage_text, "\t%-33s\t%s\t%s\n", example.c_str(), kind,
                     entry.usage_text_.c_str());
  }
  return usage_text;
}
328
329 } // namespace tensorflow
330