1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/model_cmdline_flags.h"
16
17 #include <string>
18 #include <vector>
19
20 #include "absl/strings/numbers.h"
21 #include "absl/strings/str_join.h"
22 #include "absl/strings/str_split.h"
23 #include "absl/strings/string_view.h"
24 #include "absl/strings/strip.h"
25 #include "tensorflow/lite/toco/args.h"
26 #include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
27 #include "tensorflow/lite/toco/toco_port.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/util/command_line_flags.h"
30
31 // "batch" flag only exists internally
32 #ifdef PLATFORM_GOOGLE
33 #include "base/commandlineflags.h"
34 #endif
35
36 namespace toco {
37
ParseModelFlagsFromCommandLineFlags(int * argc,char * argv[],std::string * msg,ParsedModelFlags * parsed_model_flags_ptr)38 bool ParseModelFlagsFromCommandLineFlags(
39 int* argc, char* argv[], std::string* msg,
40 ParsedModelFlags* parsed_model_flags_ptr) {
41 ParsedModelFlags& parsed_flags = *parsed_model_flags_ptr;
42 using tensorflow::Flag;
43 std::vector<tensorflow::Flag> flags = {
44 Flag("input_array", parsed_flags.input_array.bind(),
45 parsed_flags.input_array.default_value(),
46 "Deprecated: use --input_arrays instead. Name of the input array. "
47 "If not specified, will try to read "
48 "that information from the input file."),
49 Flag("input_arrays", parsed_flags.input_arrays.bind(),
50 parsed_flags.input_arrays.default_value(),
51 "Names of the input arrays, comma-separated. If not specified, "
52 "will try to read that information from the input file."),
53 Flag("output_array", parsed_flags.output_array.bind(),
54 parsed_flags.output_array.default_value(),
55 "Deprecated: use --output_arrays instead. Name of the output array, "
56 "when specifying a unique output array. "
57 "If not specified, will try to read that information from the "
58 "input file."),
59 Flag("output_arrays", parsed_flags.output_arrays.bind(),
60 parsed_flags.output_arrays.default_value(),
61 "Names of the output arrays, comma-separated. "
62 "If not specified, will try to read "
63 "that information from the input file."),
64 Flag("input_shape", parsed_flags.input_shape.bind(),
65 parsed_flags.input_shape.default_value(),
66 "Deprecated: use --input_shapes instead. Input array shape. For "
67 "many models the shape takes the form "
68 "batch size, input array height, input array width, input array "
69 "depth."),
70 Flag("input_shapes", parsed_flags.input_shapes.bind(),
71 parsed_flags.input_shapes.default_value(),
72 "Shapes corresponding to --input_arrays, colon-separated. For "
73 "many models each shape takes the form batch size, input array "
74 "height, input array width, input array depth."),
75 Flag("batch_size", parsed_flags.batch_size.bind(),
76 parsed_flags.batch_size.default_value(),
77 "Deprecated. Batch size for the model. Replaces the first dimension "
78 "of an input size array if undefined. Use only with SavedModels "
79 "when --input_shapes flag is not specified. Always use "
80 "--input_shapes flag with frozen graphs."),
81 Flag("input_data_type", parsed_flags.input_data_type.bind(),
82 parsed_flags.input_data_type.default_value(),
83 "Deprecated: use --input_data_types instead. Input array type, if "
84 "not already provided in the graph. "
85 "Typically needs to be specified when passing arbitrary arrays "
86 "to --input_arrays."),
87 Flag("input_data_types", parsed_flags.input_data_types.bind(),
88 parsed_flags.input_data_types.default_value(),
89 "Input arrays types, comma-separated, if not already provided in "
90 "the graph. "
91 "Typically needs to be specified when passing arbitrary arrays "
92 "to --input_arrays."),
93 Flag("mean_value", parsed_flags.mean_value.bind(),
94 parsed_flags.mean_value.default_value(),
95 "Deprecated: use --mean_values instead. mean_value parameter for "
96 "image models, used to compute input "
97 "activations from input pixel data."),
98 Flag("mean_values", parsed_flags.mean_values.bind(),
99 parsed_flags.mean_values.default_value(),
100 "mean_values parameter for image models, comma-separated list of "
101 "doubles, used to compute input activations from input pixel "
102 "data. Each entry in the list should match an entry in "
103 "--input_arrays."),
104 Flag("std_value", parsed_flags.std_value.bind(),
105 parsed_flags.std_value.default_value(),
106 "Deprecated: use --std_values instead. std_value parameter for "
107 "image models, used to compute input "
108 "activations from input pixel data."),
109 Flag("std_values", parsed_flags.std_values.bind(),
110 parsed_flags.std_values.default_value(),
111 "std_value parameter for image models, comma-separated list of "
112 "doubles, used to compute input activations from input pixel "
113 "data. Each entry in the list should match an entry in "
114 "--input_arrays."),
115 Flag("variable_batch", parsed_flags.variable_batch.bind(),
116 parsed_flags.variable_batch.default_value(),
117 "If true, the model accepts an arbitrary batch size. Mutually "
118 "exclusive "
119 "with the 'batch' field: at most one of these two fields can be "
120 "set."),
121 Flag("rnn_states", parsed_flags.rnn_states.bind(),
122 parsed_flags.rnn_states.default_value(), ""),
123 Flag("model_checks", parsed_flags.model_checks.bind(),
124 parsed_flags.model_checks.default_value(),
125 "A list of model checks to be applied to verify the form of the "
126 "model. Applied after the graph transformations after import."),
127 Flag("dump_graphviz", parsed_flags.dump_graphviz.bind(),
128 parsed_flags.dump_graphviz.default_value(),
129 "Dump graphviz during LogDump call. If string is non-empty then "
130 "it defines path to dump, otherwise will skip dumping."),
131 Flag("dump_graphviz_video", parsed_flags.dump_graphviz_video.bind(),
132 parsed_flags.dump_graphviz_video.default_value(),
133 "If true, will dump graphviz at each "
134 "graph transformation, which may be used to generate a video."),
135 Flag("conversion_summary_dir", parsed_flags.conversion_summary_dir.bind(),
136 parsed_flags.conversion_summary_dir.default_value(),
137 "Local file directory to store the conversion logs."),
138 Flag("allow_nonexistent_arrays",
139 parsed_flags.allow_nonexistent_arrays.bind(),
140 parsed_flags.allow_nonexistent_arrays.default_value(),
141 "If true, will allow passing inexistent arrays in --input_arrays "
142 "and --output_arrays. This makes little sense, is only useful to "
143 "more easily get graph visualizations."),
144 Flag("allow_nonascii_arrays", parsed_flags.allow_nonascii_arrays.bind(),
145 parsed_flags.allow_nonascii_arrays.default_value(),
146 "If true, will allow passing non-ascii-printable characters in "
147 "--input_arrays and --output_arrays. By default (if false), only "
148 "ascii printable characters are allowed, i.e. character codes "
149 "ranging from 32 to 127. This is disallowed by default so as to "
150 "catch common copy-and-paste issues where invisible unicode "
151 "characters are unwittingly added to these strings."),
152 Flag(
153 "arrays_extra_info_file", parsed_flags.arrays_extra_info_file.bind(),
154 parsed_flags.arrays_extra_info_file.default_value(),
155 "Path to an optional file containing a serialized ArraysExtraInfo "
156 "proto allowing to pass extra information about arrays not specified "
157 "in the input model file, such as extra MinMax information."),
158 Flag("model_flags_file", parsed_flags.model_flags_file.bind(),
159 parsed_flags.model_flags_file.default_value(),
160 "Path to an optional file containing a serialized ModelFlags proto. "
161 "Options specified on the command line will override the values in "
162 "the proto."),
163 Flag("change_concat_input_ranges",
164 parsed_flags.change_concat_input_ranges.bind(),
165 parsed_flags.change_concat_input_ranges.default_value(),
166 "Boolean to change the behavior of min/max ranges for inputs and"
167 " output of the concat operators."),
168 };
169 bool asked_for_help =
170 *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
171 if (asked_for_help) {
172 *msg += tensorflow::Flags::Usage(argv[0], flags);
173 return false;
174 } else {
175 if (!tensorflow::Flags::Parse(argc, argv, flags)) return false;
176 }
177 auto& dump_options = *GraphVizDumpOptions::singleton();
178 dump_options.dump_graphviz_video = parsed_flags.dump_graphviz_video.value();
179 dump_options.dump_graphviz = parsed_flags.dump_graphviz.value();
180
181 return true;
182 }
183
ReadModelFlagsFromCommandLineFlags(const ParsedModelFlags & parsed_model_flags,ModelFlags * model_flags)184 void ReadModelFlagsFromCommandLineFlags(
185 const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags) {
186 toco::port::CheckInitGoogleIsDone("InitGoogle is not done yet");
187
188 // Load proto containing the initial model flags.
189 // Additional flags specified on the command line will overwrite the values.
190 if (parsed_model_flags.model_flags_file.specified()) {
191 std::string model_flags_file_contents;
192 QCHECK(port::file::GetContents(parsed_model_flags.model_flags_file.value(),
193 &model_flags_file_contents,
194 port::file::Defaults())
195 .ok())
196 << "Specified --model_flags_file="
197 << parsed_model_flags.model_flags_file.value()
198 << " was not found or could not be read";
199 QCHECK(ParseFromStringEitherTextOrBinary(model_flags_file_contents,
200 model_flags))
201 << "Specified --model_flags_file="
202 << parsed_model_flags.model_flags_file.value()
203 << " could not be parsed";
204 }
205
206 #ifdef PLATFORM_GOOGLE
207 CHECK(!((base::WasPresentOnCommandLine("batch") &&
208 parsed_model_flags.variable_batch.specified())))
209 << "The --batch and --variable_batch flags are mutually exclusive.";
210 #endif
211 CHECK(!(parsed_model_flags.output_array.specified() &&
212 parsed_model_flags.output_arrays.specified()))
213 << "The --output_array and --vs flags are mutually exclusive.";
214
215 if (parsed_model_flags.output_array.specified()) {
216 model_flags->add_output_arrays(parsed_model_flags.output_array.value());
217 }
218
219 if (parsed_model_flags.output_arrays.specified()) {
220 std::vector<std::string> output_arrays =
221 absl::StrSplit(parsed_model_flags.output_arrays.value(), ',');
222 for (const std::string& output_array : output_arrays) {
223 model_flags->add_output_arrays(output_array);
224 }
225 }
226
227 const bool uses_single_input_flags =
228 parsed_model_flags.input_array.specified() ||
229 parsed_model_flags.mean_value.specified() ||
230 parsed_model_flags.std_value.specified() ||
231 parsed_model_flags.input_shape.specified();
232
233 const bool uses_multi_input_flags =
234 parsed_model_flags.input_arrays.specified() ||
235 parsed_model_flags.mean_values.specified() ||
236 parsed_model_flags.std_values.specified() ||
237 parsed_model_flags.input_shapes.specified();
238
239 QCHECK(!(uses_single_input_flags && uses_multi_input_flags))
240 << "Use either the singular-form input flags (--input_array, "
241 "--input_shape, --mean_value, --std_value) or the plural form input "
242 "flags (--input_arrays, --input_shapes, --mean_values, --std_values), "
243 "but not both forms within the same command line.";
244
245 if (parsed_model_flags.input_array.specified()) {
246 QCHECK(uses_single_input_flags);
247 model_flags->add_input_arrays()->set_name(
248 parsed_model_flags.input_array.value());
249 }
250 if (parsed_model_flags.input_arrays.specified()) {
251 QCHECK(uses_multi_input_flags);
252 for (const auto& input_array :
253 absl::StrSplit(parsed_model_flags.input_arrays.value(), ',')) {
254 model_flags->add_input_arrays()->set_name(std::string(input_array));
255 }
256 }
257 if (parsed_model_flags.mean_value.specified()) {
258 QCHECK(uses_single_input_flags);
259 model_flags->mutable_input_arrays(0)->set_mean_value(
260 parsed_model_flags.mean_value.value());
261 }
262 if (parsed_model_flags.mean_values.specified()) {
263 QCHECK(uses_multi_input_flags);
264 std::vector<std::string> mean_values =
265 absl::StrSplit(parsed_model_flags.mean_values.value(), ',');
266 QCHECK(static_cast<int>(mean_values.size()) ==
267 model_flags->input_arrays_size());
268 for (size_t i = 0; i < mean_values.size(); ++i) {
269 char* last = nullptr;
270 model_flags->mutable_input_arrays(i)->set_mean_value(
271 strtod(mean_values[i].data(), &last));
272 CHECK(last != mean_values[i].data());
273 }
274 }
275 if (parsed_model_flags.std_value.specified()) {
276 QCHECK(uses_single_input_flags);
277 model_flags->mutable_input_arrays(0)->set_std_value(
278 parsed_model_flags.std_value.value());
279 }
280 if (parsed_model_flags.std_values.specified()) {
281 QCHECK(uses_multi_input_flags);
282 std::vector<std::string> std_values =
283 absl::StrSplit(parsed_model_flags.std_values.value(), ',');
284 QCHECK(static_cast<int>(std_values.size()) ==
285 model_flags->input_arrays_size());
286 for (size_t i = 0; i < std_values.size(); ++i) {
287 char* last = nullptr;
288 model_flags->mutable_input_arrays(i)->set_std_value(
289 strtod(std_values[i].data(), &last));
290 CHECK(last != std_values[i].data());
291 }
292 }
293 if (parsed_model_flags.input_data_type.specified()) {
294 QCHECK(uses_single_input_flags);
295 IODataType type;
296 QCHECK(IODataType_Parse(parsed_model_flags.input_data_type.value(), &type));
297 model_flags->mutable_input_arrays(0)->set_data_type(type);
298 }
299 if (parsed_model_flags.input_data_types.specified()) {
300 QCHECK(uses_multi_input_flags);
301 std::vector<std::string> input_data_types =
302 absl::StrSplit(parsed_model_flags.input_data_types.value(), ',');
303 QCHECK(static_cast<int>(input_data_types.size()) ==
304 model_flags->input_arrays_size());
305 for (size_t i = 0; i < input_data_types.size(); ++i) {
306 IODataType type;
307 QCHECK(IODataType_Parse(input_data_types[i], &type));
308 model_flags->mutable_input_arrays(i)->set_data_type(type);
309 }
310 }
311 if (parsed_model_flags.input_shape.specified()) {
312 QCHECK(uses_single_input_flags);
313 if (model_flags->input_arrays().empty()) {
314 model_flags->add_input_arrays();
315 }
316 auto* shape = model_flags->mutable_input_arrays(0)->mutable_shape();
317 shape->clear_dims();
318 const IntList& list = parsed_model_flags.input_shape.value();
319 for (auto& dim : list.elements) {
320 shape->add_dims(dim);
321 }
322 }
323 if (parsed_model_flags.input_shapes.specified()) {
324 QCHECK(uses_multi_input_flags);
325 std::vector<std::string> input_shapes =
326 absl::StrSplit(parsed_model_flags.input_shapes.value(), ':');
327 QCHECK(static_cast<int>(input_shapes.size()) ==
328 model_flags->input_arrays_size());
329 for (size_t i = 0; i < input_shapes.size(); ++i) {
330 auto* shape = model_flags->mutable_input_arrays(i)->mutable_shape();
331 shape->clear_dims();
332 // Treat an empty input shape as a scalar.
333 if (input_shapes[i].empty()) {
334 continue;
335 }
336 for (const auto& dim_str : absl::StrSplit(input_shapes[i], ',')) {
337 int size;
338 CHECK(absl::SimpleAtoi(dim_str, &size))
339 << "Failed to parse input_shape: " << input_shapes[i];
340 shape->add_dims(size);
341 }
342 }
343 }
344
345 #define READ_MODEL_FLAG(name) \
346 do { \
347 if (parsed_model_flags.name.specified()) { \
348 model_flags->set_##name(parsed_model_flags.name.value()); \
349 } \
350 } while (false)
351
352 READ_MODEL_FLAG(variable_batch);
353
354 #undef READ_MODEL_FLAG
355
356 for (const auto& element : parsed_model_flags.rnn_states.value().elements) {
357 auto* rnn_state_proto = model_flags->add_rnn_states();
358 for (const auto& kv_pair : element) {
359 const std::string& key = kv_pair.first;
360 const std::string& value = kv_pair.second;
361 if (key == "state_array") {
362 rnn_state_proto->set_state_array(value);
363 } else if (key == "back_edge_source_array") {
364 rnn_state_proto->set_back_edge_source_array(value);
365 } else if (key == "size") {
366 int32 size = 0;
367 CHECK(absl::SimpleAtoi(value, &size));
368 CHECK_GT(size, 0);
369 rnn_state_proto->set_size(size);
370 } else {
371 LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states";
372 }
373 }
374 CHECK(rnn_state_proto->has_state_array() &&
375 rnn_state_proto->has_back_edge_source_array() &&
376 rnn_state_proto->has_size())
377 << "--rnn_states must include state_array, back_edge_source_array and "
378 "size.";
379 }
380
381 for (const auto& element : parsed_model_flags.model_checks.value().elements) {
382 auto* model_check_proto = model_flags->add_model_checks();
383 for (const auto& kv_pair : element) {
384 const std::string& key = kv_pair.first;
385 const std::string& value = kv_pair.second;
386 if (key == "count_type") {
387 model_check_proto->set_count_type(value);
388 } else if (key == "count_min") {
389 int32 count = 0;
390 CHECK(absl::SimpleAtoi(value, &count));
391 CHECK_GE(count, -1);
392 model_check_proto->set_count_min(count);
393 } else if (key == "count_max") {
394 int32 count = 0;
395 CHECK(absl::SimpleAtoi(value, &count));
396 CHECK_GE(count, -1);
397 model_check_proto->set_count_max(count);
398 } else {
399 LOG(FATAL) << "Unknown key '" << key << "' in --model_checks";
400 }
401 }
402 }
403
404 if (!model_flags->has_allow_nonascii_arrays()) {
405 model_flags->set_allow_nonascii_arrays(
406 parsed_model_flags.allow_nonascii_arrays.value());
407 }
408 if (!model_flags->has_allow_nonexistent_arrays()) {
409 model_flags->set_allow_nonexistent_arrays(
410 parsed_model_flags.allow_nonexistent_arrays.value());
411 }
412 if (!model_flags->has_change_concat_input_ranges()) {
413 model_flags->set_change_concat_input_ranges(
414 parsed_model_flags.change_concat_input_ranges.value());
415 }
416
417 if (parsed_model_flags.arrays_extra_info_file.specified()) {
418 std::string arrays_extra_info_file_contents;
419 CHECK(port::file::GetContents(
420 parsed_model_flags.arrays_extra_info_file.value(),
421 &arrays_extra_info_file_contents, port::file::Defaults())
422 .ok());
423 ParseFromStringEitherTextOrBinary(arrays_extra_info_file_contents,
424 model_flags->mutable_arrays_extra_info());
425 }
426 }
427
UncheckedGlobalParsedModelFlags(bool must_already_exist)428 ParsedModelFlags* UncheckedGlobalParsedModelFlags(bool must_already_exist) {
429 static auto* flags = [must_already_exist]() {
430 if (must_already_exist) {
431 fprintf(stderr, __FILE__
432 ":"
433 "GlobalParsedModelFlags() used without initialization\n");
434 fflush(stderr);
435 abort();
436 }
437 return new toco::ParsedModelFlags;
438 }();
439 return flags;
440 }
441
GlobalParsedModelFlags()442 ParsedModelFlags* GlobalParsedModelFlags() {
443 return UncheckedGlobalParsedModelFlags(true);
444 }
445
ParseModelFlagsOrDie(int * argc,char * argv[])446 void ParseModelFlagsOrDie(int* argc, char* argv[]) {
447 // TODO(aselle): in the future allow Google version to use
448 // flags, and only use this mechanism for open source
449 auto* flags = UncheckedGlobalParsedModelFlags(false);
450 std::string msg;
451 bool model_success =
452 toco::ParseModelFlagsFromCommandLineFlags(argc, argv, &msg, flags);
453 if (!model_success || !msg.empty()) {
454 // Log in non-standard way since this happens pre InitGoogle.
455 fprintf(stderr, "%s", msg.c_str());
456 fflush(stderr);
457 abort();
458 }
459 }
460
461 } // namespace toco
462