1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/model_cmdline_flags.h"
16
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>
19
20 #include "absl/strings/numbers.h"
21 #include "absl/strings/str_join.h"
22 #include "absl/strings/str_split.h"
23 #include "absl/strings/string_view.h"
24 #include "absl/strings/strip.h"
25 #include "tensorflow/lite/toco/args.h"
26 #include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
27 #include "tensorflow/lite/toco/toco_port.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/util/command_line_flags.h"
30
31 // "batch" flag only exists internally
32 #ifdef PLATFORM_GOOGLE
33 #include "base/commandlineflags.h"
34 #endif
35
36 namespace toco {
37
ParseModelFlagsFromCommandLineFlags(int * argc,char * argv[],string * msg,ParsedModelFlags * parsed_model_flags_ptr)38 bool ParseModelFlagsFromCommandLineFlags(
39 int* argc, char* argv[], string* msg,
40 ParsedModelFlags* parsed_model_flags_ptr) {
41 ParsedModelFlags& parsed_flags = *parsed_model_flags_ptr;
42 using tensorflow::Flag;
43 std::vector<tensorflow::Flag> flags = {
44 Flag("input_array", parsed_flags.input_array.bind(),
45 parsed_flags.input_array.default_value(),
46 "Deprecated: use --input_arrays instead. Name of the input array. "
47 "If not specified, will try to read "
48 "that information from the input file."),
49 Flag("input_arrays", parsed_flags.input_arrays.bind(),
50 parsed_flags.input_arrays.default_value(),
51 "Names of the input arrays, comma-separated. If not specified, "
52 "will try to read that information from the input file."),
53 Flag("output_array", parsed_flags.output_array.bind(),
54 parsed_flags.output_array.default_value(),
55 "Deprecated: use --output_arrays instead. Name of the output array, "
56 "when specifying a unique output array. "
57 "If not specified, will try to read that information from the "
58 "input file."),
59 Flag("output_arrays", parsed_flags.output_arrays.bind(),
60 parsed_flags.output_arrays.default_value(),
61 "Names of the output arrays, comma-separated. "
62 "If not specified, will try to read "
63 "that information from the input file."),
64 Flag("input_shape", parsed_flags.input_shape.bind(),
65 parsed_flags.input_shape.default_value(),
66 "Deprecated: use --input_shapes instead. Input array shape. For "
67 "many models the shape takes the form "
68 "batch size, input array height, input array width, input array "
69 "depth."),
70 Flag("input_shapes", parsed_flags.input_shapes.bind(),
71 parsed_flags.input_shapes.default_value(),
72 "Shapes corresponding to --input_arrays, colon-separated. For "
73 "many models each shape takes the form batch size, input array "
74 "height, input array width, input array depth."),
75 Flag("batch_size", parsed_flags.batch_size.bind(),
76 parsed_flags.batch_size.default_value(),
77 "Deprecated. Batch size for the model. Replaces the first dimension "
78 "of an input size array if undefined. Use only with SavedModels "
79 "when --input_shapes flag is not specified. Always use "
80 "--input_shapes flag with frozen graphs."),
81 Flag("input_data_type", parsed_flags.input_data_type.bind(),
82 parsed_flags.input_data_type.default_value(),
83 "Deprecated: use --input_data_types instead. Input array type, if "
84 "not already provided in the graph. "
85 "Typically needs to be specified when passing arbitrary arrays "
86 "to --input_arrays."),
87 Flag("input_data_types", parsed_flags.input_data_types.bind(),
88 parsed_flags.input_data_types.default_value(),
89 "Input arrays types, comma-separated, if not already provided in "
90 "the graph. "
91 "Typically needs to be specified when passing arbitrary arrays "
92 "to --input_arrays."),
93 Flag("mean_value", parsed_flags.mean_value.bind(),
94 parsed_flags.mean_value.default_value(),
95 "Deprecated: use --mean_values instead. mean_value parameter for "
96 "image models, used to compute input "
97 "activations from input pixel data."),
98 Flag("mean_values", parsed_flags.mean_values.bind(),
99 parsed_flags.mean_values.default_value(),
100 "mean_values parameter for image models, comma-separated list of "
101 "doubles, used to compute input activations from input pixel "
102 "data. Each entry in the list should match an entry in "
103 "--input_arrays."),
104 Flag("std_value", parsed_flags.std_value.bind(),
105 parsed_flags.std_value.default_value(),
106 "Deprecated: use --std_values instead. std_value parameter for "
107 "image models, used to compute input "
108 "activations from input pixel data."),
109 Flag("std_values", parsed_flags.std_values.bind(),
110 parsed_flags.std_values.default_value(),
111 "std_value parameter for image models, comma-separated list of "
112 "doubles, used to compute input activations from input pixel "
113 "data. Each entry in the list should match an entry in "
114 "--input_arrays."),
115 Flag("variable_batch", parsed_flags.variable_batch.bind(),
116 parsed_flags.variable_batch.default_value(),
117 "If true, the model accepts an arbitrary batch size. Mutually "
118 "exclusive "
119 "with the 'batch' field: at most one of these two fields can be "
120 "set."),
121 Flag("rnn_states", parsed_flags.rnn_states.bind(),
122 parsed_flags.rnn_states.default_value(), ""),
123 Flag("model_checks", parsed_flags.model_checks.bind(),
124 parsed_flags.model_checks.default_value(),
125 "A list of model checks to be applied to verify the form of the "
126 "model. Applied after the graph transformations after import."),
127 Flag("dump_graphviz", parsed_flags.dump_graphviz.bind(),
128 parsed_flags.dump_graphviz.default_value(),
129 "Dump graphviz during LogDump call. If string is non-empty then "
130 "it defines path to dump, otherwise will skip dumping."),
131 Flag("dump_graphviz_video", parsed_flags.dump_graphviz_video.bind(),
132 parsed_flags.dump_graphviz_video.default_value(),
133 "If true, will dump graphviz at each "
134 "graph transformation, which may be used to generate a video."),
135 Flag("conversion_summary_dir", parsed_flags.conversion_summary_dir.bind(),
136 parsed_flags.conversion_summary_dir.default_value(),
137 "Local file directory to store the conversion logs."),
138 Flag("allow_nonexistent_arrays",
139 parsed_flags.allow_nonexistent_arrays.bind(),
140 parsed_flags.allow_nonexistent_arrays.default_value(),
141 "If true, will allow passing inexistent arrays in --input_arrays "
142 "and --output_arrays. This makes little sense, is only useful to "
143 "more easily get graph visualizations."),
144 Flag("allow_nonascii_arrays", parsed_flags.allow_nonascii_arrays.bind(),
145 parsed_flags.allow_nonascii_arrays.default_value(),
146 "If true, will allow passing non-ascii-printable characters in "
147 "--input_arrays and --output_arrays. By default (if false), only "
148 "ascii printable characters are allowed, i.e. character codes "
149 "ranging from 32 to 127. This is disallowed by default so as to "
150 "catch common copy-and-paste issues where invisible unicode "
151 "characters are unwittingly added to these strings."),
152 Flag(
153 "arrays_extra_info_file", parsed_flags.arrays_extra_info_file.bind(),
154 parsed_flags.arrays_extra_info_file.default_value(),
155 "Path to an optional file containing a serialized ArraysExtraInfo "
156 "proto allowing to pass extra information about arrays not specified "
157 "in the input model file, such as extra MinMax information."),
158 Flag("model_flags_file", parsed_flags.model_flags_file.bind(),
159 parsed_flags.model_flags_file.default_value(),
160 "Path to an optional file containing a serialized ModelFlags proto. "
161 "Options specified on the command line will override the values in "
162 "the proto."),
163 Flag("change_concat_input_ranges",
164 parsed_flags.change_concat_input_ranges.bind(),
165 parsed_flags.change_concat_input_ranges.default_value(),
166 "Boolean to change the behavior of min/max ranges for inputs and"
167 " output of the concat operators."),
168 };
169 bool asked_for_help =
170 *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
171 if (asked_for_help) {
172 *msg += tensorflow::Flags::Usage(argv[0], flags);
173 return false;
174 } else {
175 if (!tensorflow::Flags::Parse(argc, argv, flags)) return false;
176 }
177 auto& dump_options = *GraphVizDumpOptions::singleton();
178 dump_options.dump_graphviz_video = parsed_flags.dump_graphviz_video.value();
179 dump_options.dump_graphviz = parsed_flags.dump_graphviz.value();
180
181 return true;
182 }
183
ReadModelFlagsFromCommandLineFlags(const ParsedModelFlags & parsed_model_flags,ModelFlags * model_flags)184 void ReadModelFlagsFromCommandLineFlags(
185 const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags) {
186 toco::port::CheckInitGoogleIsDone("InitGoogle is not done yet");
187
188 // Load proto containing the initial model flags.
189 // Additional flags specified on the command line will overwrite the values.
190 if (parsed_model_flags.model_flags_file.specified()) {
191 string model_flags_file_contents;
192 QCHECK(port::file::GetContents(parsed_model_flags.model_flags_file.value(),
193 &model_flags_file_contents,
194 port::file::Defaults())
195 .ok())
196 << "Specified --model_flags_file="
197 << parsed_model_flags.model_flags_file.value()
198 << " was not found or could not be read";
199 QCHECK(ParseFromStringEitherTextOrBinary(model_flags_file_contents,
200 model_flags))
201 << "Specified --model_flags_file="
202 << parsed_model_flags.model_flags_file.value()
203 << " could not be parsed";
204 }
205
206 #ifdef PLATFORM_GOOGLE
207 CHECK(!((base::SpecifiedOnCommandLine("batch") &&
208 parsed_model_flags.variable_batch.specified())))
209 << "The --batch and --variable_batch flags are mutually exclusive.";
210 #endif
211 CHECK(!(parsed_model_flags.output_array.specified() &&
212 parsed_model_flags.output_arrays.specified()))
213 << "The --output_array and --vs flags are mutually exclusive.";
214
215 if (parsed_model_flags.output_array.specified()) {
216 model_flags->add_output_arrays(parsed_model_flags.output_array.value());
217 }
218
219 if (parsed_model_flags.output_arrays.specified()) {
220 std::vector<string> output_arrays =
221 absl::StrSplit(parsed_model_flags.output_arrays.value(), ',');
222 for (const string& output_array : output_arrays) {
223 model_flags->add_output_arrays(output_array);
224 }
225 }
226
227 const bool uses_single_input_flags =
228 parsed_model_flags.input_array.specified() ||
229 parsed_model_flags.mean_value.specified() ||
230 parsed_model_flags.std_value.specified() ||
231 parsed_model_flags.input_shape.specified();
232
233 const bool uses_multi_input_flags =
234 parsed_model_flags.input_arrays.specified() ||
235 parsed_model_flags.mean_values.specified() ||
236 parsed_model_flags.std_values.specified() ||
237 parsed_model_flags.input_shapes.specified();
238
239 QCHECK(!(uses_single_input_flags && uses_multi_input_flags))
240 << "Use either the singular-form input flags (--input_array, "
241 "--input_shape, --mean_value, --std_value) or the plural form input "
242 "flags (--input_arrays, --input_shapes, --mean_values, --std_values), "
243 "but not both forms within the same command line.";
244
245 if (parsed_model_flags.input_array.specified()) {
246 QCHECK(uses_single_input_flags);
247 model_flags->add_input_arrays()->set_name(
248 parsed_model_flags.input_array.value());
249 }
250 if (parsed_model_flags.input_arrays.specified()) {
251 QCHECK(uses_multi_input_flags);
252 for (const auto& input_array :
253 absl::StrSplit(parsed_model_flags.input_arrays.value(), ',')) {
254 model_flags->add_input_arrays()->set_name(string(input_array));
255 }
256 }
257 if (parsed_model_flags.mean_value.specified()) {
258 QCHECK(uses_single_input_flags);
259 model_flags->mutable_input_arrays(0)->set_mean_value(
260 parsed_model_flags.mean_value.value());
261 }
262 if (parsed_model_flags.mean_values.specified()) {
263 QCHECK(uses_multi_input_flags);
264 std::vector<string> mean_values =
265 absl::StrSplit(parsed_model_flags.mean_values.value(), ',');
266 QCHECK(mean_values.size() == model_flags->input_arrays_size());
267 for (size_t i = 0; i < mean_values.size(); ++i) {
268 char* last = nullptr;
269 model_flags->mutable_input_arrays(i)->set_mean_value(
270 strtod(mean_values[i].data(), &last));
271 CHECK(last != mean_values[i].data());
272 }
273 }
274 if (parsed_model_flags.std_value.specified()) {
275 QCHECK(uses_single_input_flags);
276 model_flags->mutable_input_arrays(0)->set_std_value(
277 parsed_model_flags.std_value.value());
278 }
279 if (parsed_model_flags.std_values.specified()) {
280 QCHECK(uses_multi_input_flags);
281 std::vector<string> std_values =
282 absl::StrSplit(parsed_model_flags.std_values.value(), ',');
283 QCHECK(std_values.size() == model_flags->input_arrays_size());
284 for (size_t i = 0; i < std_values.size(); ++i) {
285 char* last = nullptr;
286 model_flags->mutable_input_arrays(i)->set_std_value(
287 strtod(std_values[i].data(), &last));
288 CHECK(last != std_values[i].data());
289 }
290 }
291 if (parsed_model_flags.input_data_type.specified()) {
292 QCHECK(uses_single_input_flags);
293 IODataType type;
294 QCHECK(IODataType_Parse(parsed_model_flags.input_data_type.value(), &type));
295 model_flags->mutable_input_arrays(0)->set_data_type(type);
296 }
297 if (parsed_model_flags.input_data_types.specified()) {
298 QCHECK(uses_multi_input_flags);
299 std::vector<string> input_data_types =
300 absl::StrSplit(parsed_model_flags.input_data_types.value(), ',');
301 QCHECK(input_data_types.size() == model_flags->input_arrays_size());
302 for (size_t i = 0; i < input_data_types.size(); ++i) {
303 IODataType type;
304 QCHECK(IODataType_Parse(input_data_types[i], &type));
305 model_flags->mutable_input_arrays(i)->set_data_type(type);
306 }
307 }
308 if (parsed_model_flags.input_shape.specified()) {
309 QCHECK(uses_single_input_flags);
310 if (model_flags->input_arrays().empty()) {
311 model_flags->add_input_arrays();
312 }
313 auto* shape = model_flags->mutable_input_arrays(0)->mutable_shape();
314 shape->clear_dims();
315 const IntList& list = parsed_model_flags.input_shape.value();
316 for (auto& dim : list.elements) {
317 shape->add_dims(dim);
318 }
319 }
320 if (parsed_model_flags.input_shapes.specified()) {
321 QCHECK(uses_multi_input_flags);
322 std::vector<string> input_shapes =
323 absl::StrSplit(parsed_model_flags.input_shapes.value(), ':');
324 QCHECK(input_shapes.size() == model_flags->input_arrays_size());
325 for (size_t i = 0; i < input_shapes.size(); ++i) {
326 auto* shape = model_flags->mutable_input_arrays(i)->mutable_shape();
327 shape->clear_dims();
328 // Treat an empty input shape as a scalar.
329 if (input_shapes[i].empty()) {
330 continue;
331 }
332 for (const auto& dim_str : absl::StrSplit(input_shapes[i], ',')) {
333 int size;
334 CHECK(absl::SimpleAtoi(dim_str, &size))
335 << "Failed to parse input_shape: " << input_shapes[i];
336 shape->add_dims(size);
337 }
338 }
339 }
340
341 #define READ_MODEL_FLAG(name) \
342 do { \
343 if (parsed_model_flags.name.specified()) { \
344 model_flags->set_##name(parsed_model_flags.name.value()); \
345 } \
346 } while (false)
347
348 READ_MODEL_FLAG(variable_batch);
349
350 #undef READ_MODEL_FLAG
351
352 for (const auto& element : parsed_model_flags.rnn_states.value().elements) {
353 auto* rnn_state_proto = model_flags->add_rnn_states();
354 for (const auto& kv_pair : element) {
355 const string& key = kv_pair.first;
356 const string& value = kv_pair.second;
357 if (key == "state_array") {
358 rnn_state_proto->set_state_array(value);
359 } else if (key == "back_edge_source_array") {
360 rnn_state_proto->set_back_edge_source_array(value);
361 } else if (key == "size") {
362 int32 size = 0;
363 CHECK(absl::SimpleAtoi(value, &size));
364 CHECK_GT(size, 0);
365 rnn_state_proto->set_size(size);
366 } else {
367 LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states";
368 }
369 }
370 CHECK(rnn_state_proto->has_state_array() &&
371 rnn_state_proto->has_back_edge_source_array() &&
372 rnn_state_proto->has_size())
373 << "--rnn_states must include state_array, back_edge_source_array and "
374 "size.";
375 }
376
377 for (const auto& element : parsed_model_flags.model_checks.value().elements) {
378 auto* model_check_proto = model_flags->add_model_checks();
379 for (const auto& kv_pair : element) {
380 const string& key = kv_pair.first;
381 const string& value = kv_pair.second;
382 if (key == "count_type") {
383 model_check_proto->set_count_type(value);
384 } else if (key == "count_min") {
385 int32 count = 0;
386 CHECK(absl::SimpleAtoi(value, &count));
387 CHECK_GE(count, -1);
388 model_check_proto->set_count_min(count);
389 } else if (key == "count_max") {
390 int32 count = 0;
391 CHECK(absl::SimpleAtoi(value, &count));
392 CHECK_GE(count, -1);
393 model_check_proto->set_count_max(count);
394 } else {
395 LOG(FATAL) << "Unknown key '" << key << "' in --model_checks";
396 }
397 }
398 }
399
400 if (!model_flags->has_allow_nonascii_arrays()) {
401 model_flags->set_allow_nonascii_arrays(
402 parsed_model_flags.allow_nonascii_arrays.value());
403 }
404 if (!model_flags->has_allow_nonexistent_arrays()) {
405 model_flags->set_allow_nonexistent_arrays(
406 parsed_model_flags.allow_nonexistent_arrays.value());
407 }
408 if (!model_flags->has_change_concat_input_ranges()) {
409 model_flags->set_change_concat_input_ranges(
410 parsed_model_flags.change_concat_input_ranges.value());
411 }
412
413 if (parsed_model_flags.arrays_extra_info_file.specified()) {
414 string arrays_extra_info_file_contents;
415 CHECK(port::file::GetContents(
416 parsed_model_flags.arrays_extra_info_file.value(),
417 &arrays_extra_info_file_contents, port::file::Defaults())
418 .ok());
419 ParseFromStringEitherTextOrBinary(arrays_extra_info_file_contents,
420 model_flags->mutable_arrays_extra_info());
421 }
422 }
423
UncheckedGlobalParsedModelFlags(bool must_already_exist)424 ParsedModelFlags* UncheckedGlobalParsedModelFlags(bool must_already_exist) {
425 static auto* flags = [must_already_exist]() {
426 if (must_already_exist) {
427 fprintf(stderr, __FILE__
428 ":"
429 "GlobalParsedModelFlags() used without initialization\n");
430 fflush(stderr);
431 abort();
432 }
433 return new toco::ParsedModelFlags;
434 }();
435 return flags;
436 }
437
GlobalParsedModelFlags()438 ParsedModelFlags* GlobalParsedModelFlags() {
439 return UncheckedGlobalParsedModelFlags(true);
440 }
441
ParseModelFlagsOrDie(int * argc,char * argv[])442 void ParseModelFlagsOrDie(int* argc, char* argv[]) {
443 // TODO(aselle): in the future allow Google version to use
444 // flags, and only use this mechanism for open source
445 auto* flags = UncheckedGlobalParsedModelFlags(false);
446 string msg;
447 bool model_success =
448 toco::ParseModelFlagsFromCommandLineFlags(argc, argv, &msg, flags);
449 if (!model_success || !msg.empty()) {
450 // Log in non-standard way since this happens pre InitGoogle.
451 fprintf(stderr, "%s", msg.c_str());
452 fflush(stderr);
453 abort();
454 }
455 }
456
457 } // namespace toco
458