• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/model_cmdline_flags.h"
16 
17 #include <string>
18 #include <vector>
19 
20 #include "absl/strings/numbers.h"
21 #include "absl/strings/str_join.h"
22 #include "absl/strings/str_split.h"
23 #include "absl/strings/string_view.h"
24 #include "absl/strings/strip.h"
25 #include "tensorflow/lite/toco/args.h"
26 #include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
27 #include "tensorflow/lite/toco/toco_port.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/util/command_line_flags.h"
30 
31 // "batch" flag only exists internally
32 #ifdef PLATFORM_GOOGLE
33 #include "base/commandlineflags.h"
34 #endif
35 
36 namespace toco {
37 
ParseModelFlagsFromCommandLineFlags(int * argc,char * argv[],string * msg,ParsedModelFlags * parsed_model_flags_ptr)38 bool ParseModelFlagsFromCommandLineFlags(
39     int* argc, char* argv[], string* msg,
40     ParsedModelFlags* parsed_model_flags_ptr) {
41   ParsedModelFlags& parsed_flags = *parsed_model_flags_ptr;
42   using tensorflow::Flag;
43   std::vector<tensorflow::Flag> flags = {
44       Flag("input_array", parsed_flags.input_array.bind(),
45            parsed_flags.input_array.default_value(),
46            "Deprecated: use --input_arrays instead. Name of the input array. "
47            "If not specified, will try to read "
48            "that information from the input file."),
49       Flag("input_arrays", parsed_flags.input_arrays.bind(),
50            parsed_flags.input_arrays.default_value(),
51            "Names of the input arrays, comma-separated. If not specified, "
52            "will try to read that information from the input file."),
53       Flag("output_array", parsed_flags.output_array.bind(),
54            parsed_flags.output_array.default_value(),
55            "Deprecated: use --output_arrays instead. Name of the output array, "
56            "when specifying a unique output array. "
57            "If not specified, will try to read that information from the "
58            "input file."),
59       Flag("output_arrays", parsed_flags.output_arrays.bind(),
60            parsed_flags.output_arrays.default_value(),
61            "Names of the output arrays, comma-separated. "
62            "If not specified, will try to read "
63            "that information from the input file."),
64       Flag("input_shape", parsed_flags.input_shape.bind(),
65            parsed_flags.input_shape.default_value(),
66            "Deprecated: use --input_shapes instead. Input array shape. For "
67            "many models the shape takes the form "
68            "batch size, input array height, input array width, input array "
69            "depth."),
70       Flag("input_shapes", parsed_flags.input_shapes.bind(),
71            parsed_flags.input_shapes.default_value(),
72            "Shapes corresponding to --input_arrays, colon-separated. For "
73            "many models each shape takes the form batch size, input array "
74            "height, input array width, input array depth."),
75       Flag("batch_size", parsed_flags.batch_size.bind(),
76            parsed_flags.batch_size.default_value(),
77            "Deprecated. Batch size for the model. Replaces the first dimension "
78            "of an input size array if undefined. Use only with SavedModels "
79            "when --input_shapes flag is not specified. Always use "
80            "--input_shapes flag with frozen graphs."),
81       Flag("input_data_type", parsed_flags.input_data_type.bind(),
82            parsed_flags.input_data_type.default_value(),
83            "Deprecated: use --input_data_types instead. Input array type, if "
84            "not already provided in the graph. "
85            "Typically needs to be specified when passing arbitrary arrays "
86            "to --input_arrays."),
87       Flag("input_data_types", parsed_flags.input_data_types.bind(),
88            parsed_flags.input_data_types.default_value(),
89            "Input arrays types, comma-separated, if not already provided in "
90            "the graph. "
91            "Typically needs to be specified when passing arbitrary arrays "
92            "to --input_arrays."),
93       Flag("mean_value", parsed_flags.mean_value.bind(),
94            parsed_flags.mean_value.default_value(),
95            "Deprecated: use --mean_values instead. mean_value parameter for "
96            "image models, used to compute input "
97            "activations from input pixel data."),
98       Flag("mean_values", parsed_flags.mean_values.bind(),
99            parsed_flags.mean_values.default_value(),
100            "mean_values parameter for image models, comma-separated list of "
101            "doubles, used to compute input activations from input pixel "
102            "data. Each entry in the list should match an entry in "
103            "--input_arrays."),
104       Flag("std_value", parsed_flags.std_value.bind(),
105            parsed_flags.std_value.default_value(),
106            "Deprecated: use --std_values instead. std_value parameter for "
107            "image models, used to compute input "
108            "activations from input pixel data."),
109       Flag("std_values", parsed_flags.std_values.bind(),
110            parsed_flags.std_values.default_value(),
111            "std_value parameter for image models, comma-separated list of "
112            "doubles, used to compute input activations from input pixel "
113            "data. Each entry in the list should match an entry in "
114            "--input_arrays."),
115       Flag("variable_batch", parsed_flags.variable_batch.bind(),
116            parsed_flags.variable_batch.default_value(),
117            "If true, the model accepts an arbitrary batch size. Mutually "
118            "exclusive "
119            "with the 'batch' field: at most one of these two fields can be "
120            "set."),
121       Flag("rnn_states", parsed_flags.rnn_states.bind(),
122            parsed_flags.rnn_states.default_value(), ""),
123       Flag("model_checks", parsed_flags.model_checks.bind(),
124            parsed_flags.model_checks.default_value(),
125            "A list of model checks to be applied to verify the form of the "
126            "model.  Applied after the graph transformations after import."),
127       Flag("dump_graphviz", parsed_flags.dump_graphviz.bind(),
128            parsed_flags.dump_graphviz.default_value(),
129            "Dump graphviz during LogDump call. If string is non-empty then "
130            "it defines path to dump, otherwise will skip dumping."),
131       Flag("dump_graphviz_video", parsed_flags.dump_graphviz_video.bind(),
132            parsed_flags.dump_graphviz_video.default_value(),
133            "If true, will dump graphviz at each "
134            "graph transformation, which may be used to generate a video."),
135       Flag("conversion_summary_dir", parsed_flags.conversion_summary_dir.bind(),
136            parsed_flags.conversion_summary_dir.default_value(),
137            "Local file directory to store the conversion logs."),
138       Flag("allow_nonexistent_arrays",
139            parsed_flags.allow_nonexistent_arrays.bind(),
140            parsed_flags.allow_nonexistent_arrays.default_value(),
141            "If true, will allow passing inexistent arrays in --input_arrays "
142            "and --output_arrays. This makes little sense, is only useful to "
143            "more easily get graph visualizations."),
144       Flag("allow_nonascii_arrays", parsed_flags.allow_nonascii_arrays.bind(),
145            parsed_flags.allow_nonascii_arrays.default_value(),
146            "If true, will allow passing non-ascii-printable characters in "
147            "--input_arrays and --output_arrays. By default (if false), only "
148            "ascii printable characters are allowed, i.e. character codes "
149            "ranging from 32 to 127. This is disallowed by default so as to "
150            "catch common copy-and-paste issues where invisible unicode "
151            "characters are unwittingly added to these strings."),
152       Flag(
153           "arrays_extra_info_file", parsed_flags.arrays_extra_info_file.bind(),
154           parsed_flags.arrays_extra_info_file.default_value(),
155           "Path to an optional file containing a serialized ArraysExtraInfo "
156           "proto allowing to pass extra information about arrays not specified "
157           "in the input model file, such as extra MinMax information."),
158       Flag("model_flags_file", parsed_flags.model_flags_file.bind(),
159            parsed_flags.model_flags_file.default_value(),
160            "Path to an optional file containing a serialized ModelFlags proto. "
161            "Options specified on the command line will override the values in "
162            "the proto."),
163       Flag("change_concat_input_ranges",
164            parsed_flags.change_concat_input_ranges.bind(),
165            parsed_flags.change_concat_input_ranges.default_value(),
166            "Boolean to change the behavior of min/max ranges for inputs and"
167            " output of the concat operators."),
168   };
169   bool asked_for_help =
170       *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
171   if (asked_for_help) {
172     *msg += tensorflow::Flags::Usage(argv[0], flags);
173     return false;
174   } else {
175     if (!tensorflow::Flags::Parse(argc, argv, flags)) return false;
176   }
177   auto& dump_options = *GraphVizDumpOptions::singleton();
178   dump_options.dump_graphviz_video = parsed_flags.dump_graphviz_video.value();
179   dump_options.dump_graphviz = parsed_flags.dump_graphviz.value();
180 
181   return true;
182 }
183 
ReadModelFlagsFromCommandLineFlags(const ParsedModelFlags & parsed_model_flags,ModelFlags * model_flags)184 void ReadModelFlagsFromCommandLineFlags(
185     const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags) {
186   toco::port::CheckInitGoogleIsDone("InitGoogle is not done yet");
187 
188   // Load proto containing the initial model flags.
189   // Additional flags specified on the command line will overwrite the values.
190   if (parsed_model_flags.model_flags_file.specified()) {
191     string model_flags_file_contents;
192     QCHECK(port::file::GetContents(parsed_model_flags.model_flags_file.value(),
193                                    &model_flags_file_contents,
194                                    port::file::Defaults())
195                .ok())
196         << "Specified --model_flags_file="
197         << parsed_model_flags.model_flags_file.value()
198         << " was not found or could not be read";
199     QCHECK(ParseFromStringEitherTextOrBinary(model_flags_file_contents,
200                                              model_flags))
201         << "Specified --model_flags_file="
202         << parsed_model_flags.model_flags_file.value()
203         << " could not be parsed";
204   }
205 
206 #ifdef PLATFORM_GOOGLE
207   CHECK(!((base::SpecifiedOnCommandLine("batch") &&
208            parsed_model_flags.variable_batch.specified())))
209       << "The --batch and --variable_batch flags are mutually exclusive.";
210 #endif
211   CHECK(!(parsed_model_flags.output_array.specified() &&
212           parsed_model_flags.output_arrays.specified()))
213       << "The --output_array and --vs flags are mutually exclusive.";
214 
215   if (parsed_model_flags.output_array.specified()) {
216     model_flags->add_output_arrays(parsed_model_flags.output_array.value());
217   }
218 
219   if (parsed_model_flags.output_arrays.specified()) {
220     std::vector<string> output_arrays =
221         absl::StrSplit(parsed_model_flags.output_arrays.value(), ',');
222     for (const string& output_array : output_arrays) {
223       model_flags->add_output_arrays(output_array);
224     }
225   }
226 
227   const bool uses_single_input_flags =
228       parsed_model_flags.input_array.specified() ||
229       parsed_model_flags.mean_value.specified() ||
230       parsed_model_flags.std_value.specified() ||
231       parsed_model_flags.input_shape.specified();
232 
233   const bool uses_multi_input_flags =
234       parsed_model_flags.input_arrays.specified() ||
235       parsed_model_flags.mean_values.specified() ||
236       parsed_model_flags.std_values.specified() ||
237       parsed_model_flags.input_shapes.specified();
238 
239   QCHECK(!(uses_single_input_flags && uses_multi_input_flags))
240       << "Use either the singular-form input flags (--input_array, "
241          "--input_shape, --mean_value, --std_value) or the plural form input "
242          "flags (--input_arrays, --input_shapes, --mean_values, --std_values), "
243          "but not both forms within the same command line.";
244 
245   if (parsed_model_flags.input_array.specified()) {
246     QCHECK(uses_single_input_flags);
247     model_flags->add_input_arrays()->set_name(
248         parsed_model_flags.input_array.value());
249   }
250   if (parsed_model_flags.input_arrays.specified()) {
251     QCHECK(uses_multi_input_flags);
252     for (const auto& input_array :
253          absl::StrSplit(parsed_model_flags.input_arrays.value(), ',')) {
254       model_flags->add_input_arrays()->set_name(string(input_array));
255     }
256   }
257   if (parsed_model_flags.mean_value.specified()) {
258     QCHECK(uses_single_input_flags);
259     model_flags->mutable_input_arrays(0)->set_mean_value(
260         parsed_model_flags.mean_value.value());
261   }
262   if (parsed_model_flags.mean_values.specified()) {
263     QCHECK(uses_multi_input_flags);
264     std::vector<string> mean_values =
265         absl::StrSplit(parsed_model_flags.mean_values.value(), ',');
266     QCHECK(mean_values.size() == model_flags->input_arrays_size());
267     for (size_t i = 0; i < mean_values.size(); ++i) {
268       char* last = nullptr;
269       model_flags->mutable_input_arrays(i)->set_mean_value(
270           strtod(mean_values[i].data(), &last));
271       CHECK(last != mean_values[i].data());
272     }
273   }
274   if (parsed_model_flags.std_value.specified()) {
275     QCHECK(uses_single_input_flags);
276     model_flags->mutable_input_arrays(0)->set_std_value(
277         parsed_model_flags.std_value.value());
278   }
279   if (parsed_model_flags.std_values.specified()) {
280     QCHECK(uses_multi_input_flags);
281     std::vector<string> std_values =
282         absl::StrSplit(parsed_model_flags.std_values.value(), ',');
283     QCHECK(std_values.size() == model_flags->input_arrays_size());
284     for (size_t i = 0; i < std_values.size(); ++i) {
285       char* last = nullptr;
286       model_flags->mutable_input_arrays(i)->set_std_value(
287           strtod(std_values[i].data(), &last));
288       CHECK(last != std_values[i].data());
289     }
290   }
291   if (parsed_model_flags.input_data_type.specified()) {
292     QCHECK(uses_single_input_flags);
293     IODataType type;
294     QCHECK(IODataType_Parse(parsed_model_flags.input_data_type.value(), &type));
295     model_flags->mutable_input_arrays(0)->set_data_type(type);
296   }
297   if (parsed_model_flags.input_data_types.specified()) {
298     QCHECK(uses_multi_input_flags);
299     std::vector<string> input_data_types =
300         absl::StrSplit(parsed_model_flags.input_data_types.value(), ',');
301     QCHECK(input_data_types.size() == model_flags->input_arrays_size());
302     for (size_t i = 0; i < input_data_types.size(); ++i) {
303       IODataType type;
304       QCHECK(IODataType_Parse(input_data_types[i], &type));
305       model_flags->mutable_input_arrays(i)->set_data_type(type);
306     }
307   }
308   if (parsed_model_flags.input_shape.specified()) {
309     QCHECK(uses_single_input_flags);
310     if (model_flags->input_arrays().empty()) {
311       model_flags->add_input_arrays();
312     }
313     auto* shape = model_flags->mutable_input_arrays(0)->mutable_shape();
314     shape->clear_dims();
315     const IntList& list = parsed_model_flags.input_shape.value();
316     for (auto& dim : list.elements) {
317       shape->add_dims(dim);
318     }
319   }
320   if (parsed_model_flags.input_shapes.specified()) {
321     QCHECK(uses_multi_input_flags);
322     std::vector<string> input_shapes =
323         absl::StrSplit(parsed_model_flags.input_shapes.value(), ':');
324     QCHECK(input_shapes.size() == model_flags->input_arrays_size());
325     for (size_t i = 0; i < input_shapes.size(); ++i) {
326       auto* shape = model_flags->mutable_input_arrays(i)->mutable_shape();
327       shape->clear_dims();
328       // Treat an empty input shape as a scalar.
329       if (input_shapes[i].empty()) {
330         continue;
331       }
332       for (const auto& dim_str : absl::StrSplit(input_shapes[i], ',')) {
333         int size;
334         CHECK(absl::SimpleAtoi(dim_str, &size))
335             << "Failed to parse input_shape: " << input_shapes[i];
336         shape->add_dims(size);
337       }
338     }
339   }
340 
341 #define READ_MODEL_FLAG(name)                                   \
342   do {                                                          \
343     if (parsed_model_flags.name.specified()) {                  \
344       model_flags->set_##name(parsed_model_flags.name.value()); \
345     }                                                           \
346   } while (false)
347 
348   READ_MODEL_FLAG(variable_batch);
349 
350 #undef READ_MODEL_FLAG
351 
352   for (const auto& element : parsed_model_flags.rnn_states.value().elements) {
353     auto* rnn_state_proto = model_flags->add_rnn_states();
354     for (const auto& kv_pair : element) {
355       const string& key = kv_pair.first;
356       const string& value = kv_pair.second;
357       if (key == "state_array") {
358         rnn_state_proto->set_state_array(value);
359       } else if (key == "back_edge_source_array") {
360         rnn_state_proto->set_back_edge_source_array(value);
361       } else if (key == "size") {
362         int32 size = 0;
363         CHECK(absl::SimpleAtoi(value, &size));
364         CHECK_GT(size, 0);
365         rnn_state_proto->set_size(size);
366       } else {
367         LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states";
368       }
369     }
370     CHECK(rnn_state_proto->has_state_array() &&
371           rnn_state_proto->has_back_edge_source_array() &&
372           rnn_state_proto->has_size())
373         << "--rnn_states must include state_array, back_edge_source_array and "
374            "size.";
375   }
376 
377   for (const auto& element : parsed_model_flags.model_checks.value().elements) {
378     auto* model_check_proto = model_flags->add_model_checks();
379     for (const auto& kv_pair : element) {
380       const string& key = kv_pair.first;
381       const string& value = kv_pair.second;
382       if (key == "count_type") {
383         model_check_proto->set_count_type(value);
384       } else if (key == "count_min") {
385         int32 count = 0;
386         CHECK(absl::SimpleAtoi(value, &count));
387         CHECK_GE(count, -1);
388         model_check_proto->set_count_min(count);
389       } else if (key == "count_max") {
390         int32 count = 0;
391         CHECK(absl::SimpleAtoi(value, &count));
392         CHECK_GE(count, -1);
393         model_check_proto->set_count_max(count);
394       } else {
395         LOG(FATAL) << "Unknown key '" << key << "' in --model_checks";
396       }
397     }
398   }
399 
400   if (!model_flags->has_allow_nonascii_arrays()) {
401     model_flags->set_allow_nonascii_arrays(
402         parsed_model_flags.allow_nonascii_arrays.value());
403   }
404   if (!model_flags->has_allow_nonexistent_arrays()) {
405     model_flags->set_allow_nonexistent_arrays(
406         parsed_model_flags.allow_nonexistent_arrays.value());
407   }
408   if (!model_flags->has_change_concat_input_ranges()) {
409     model_flags->set_change_concat_input_ranges(
410         parsed_model_flags.change_concat_input_ranges.value());
411   }
412 
413   if (parsed_model_flags.arrays_extra_info_file.specified()) {
414     string arrays_extra_info_file_contents;
415     CHECK(port::file::GetContents(
416               parsed_model_flags.arrays_extra_info_file.value(),
417               &arrays_extra_info_file_contents, port::file::Defaults())
418               .ok());
419     ParseFromStringEitherTextOrBinary(arrays_extra_info_file_contents,
420                                       model_flags->mutable_arrays_extra_info());
421   }
422 }
423 
UncheckedGlobalParsedModelFlags(bool must_already_exist)424 ParsedModelFlags* UncheckedGlobalParsedModelFlags(bool must_already_exist) {
425   static auto* flags = [must_already_exist]() {
426     if (must_already_exist) {
427       fprintf(stderr, __FILE__
428               ":"
429               "GlobalParsedModelFlags() used without initialization\n");
430       fflush(stderr);
431       abort();
432     }
433     return new toco::ParsedModelFlags;
434   }();
435   return flags;
436 }
437 
GlobalParsedModelFlags()438 ParsedModelFlags* GlobalParsedModelFlags() {
439   return UncheckedGlobalParsedModelFlags(true);
440 }
441 
ParseModelFlagsOrDie(int * argc,char * argv[])442 void ParseModelFlagsOrDie(int* argc, char* argv[]) {
443   // TODO(aselle): in the future allow Google version to use
444   // flags, and only use this mechanism for open source
445   auto* flags = UncheckedGlobalParsedModelFlags(false);
446   string msg;
447   bool model_success =
448       toco::ParseModelFlagsFromCommandLineFlags(argc, argv, &msg, flags);
449   if (!model_success || !msg.empty()) {
450     // Log in non-standard way since this happens pre InitGoogle.
451     fprintf(stderr, "%s", msg.c_str());
452     fflush(stderr);
453     abort();
454   }
455 }
456 
457 }  // namespace toco
458