/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#include <cstring>
#include <string>
#include <vector>

#include "absl/strings/numbers.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "absl/types/optional.h"
#include "tensorflow/lite/toco/toco_cmdline_flags.h"
#include "tensorflow/lite/toco/toco_port.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
28 
29 namespace toco {
30 
ParseTocoFlagsFromCommandLineFlags(int * argc,char * argv[],string * msg,ParsedTocoFlags * parsed_toco_flags_ptr)31 bool ParseTocoFlagsFromCommandLineFlags(
32     int* argc, char* argv[], string* msg,
33     ParsedTocoFlags* parsed_toco_flags_ptr) {
34   using tensorflow::Flag;
35   ParsedTocoFlags& parsed_flags = *parsed_toco_flags_ptr;
36   std::vector<tensorflow::Flag> flags = {
37       Flag("input_file", parsed_flags.input_file.bind(),
38            parsed_flags.input_file.default_value(),
39            "Input file (model of any supported format). For Protobuf "
40            "formats, both text and binary are supported regardless of file "
41            "extension."),
42       Flag("savedmodel_directory", parsed_flags.savedmodel_directory.bind(),
43            parsed_flags.savedmodel_directory.default_value(),
44            "Deprecated. Full path to the directory containing the SavedModel."),
45       Flag("output_file", parsed_flags.output_file.bind(),
46            parsed_flags.output_file.default_value(),
47            "Output file. "
48            "For Protobuf formats, the binary format will be used."),
49       Flag("input_format", parsed_flags.input_format.bind(),
50            parsed_flags.input_format.default_value(),
51            "Input file format. One of: TENSORFLOW_GRAPHDEF, TFLITE."),
52       Flag("output_format", parsed_flags.output_format.bind(),
53            parsed_flags.output_format.default_value(),
54            "Output file format. "
55            "One of TENSORFLOW_GRAPHDEF, TFLITE, GRAPHVIZ_DOT."),
56       Flag("savedmodel_tagset", parsed_flags.savedmodel_tagset.bind(),
57            parsed_flags.savedmodel_tagset.default_value(),
58            "Deprecated. Comma-separated set of tags identifying the "
59            "MetaGraphDef within the SavedModel to analyze. All tags in the tag "
60            "set must be specified."),
61       Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(),
62            parsed_flags.default_ranges_min.default_value(),
63            "If defined, will be used as the default value for the min bound "
64            "of min/max ranges used for quantization of uint8 arrays."),
65       Flag("default_ranges_max", parsed_flags.default_ranges_max.bind(),
66            parsed_flags.default_ranges_max.default_value(),
67            "If defined, will be used as the default value for the max bound "
68            "of min/max ranges used for quantization of uint8 arrays."),
69       Flag("default_int16_ranges_min",
70            parsed_flags.default_int16_ranges_min.bind(),
71            parsed_flags.default_int16_ranges_min.default_value(),
72            "If defined, will be used as the default value for the min bound "
73            "of min/max ranges used for quantization of int16 arrays."),
74       Flag("default_int16_ranges_max",
75            parsed_flags.default_int16_ranges_max.bind(),
76            parsed_flags.default_int16_ranges_max.default_value(),
77            "If defined, will be used as the default value for the max bound "
78            "of min/max ranges used for quantization of int16 arrays."),
79       Flag("inference_type", parsed_flags.inference_type.bind(),
80            parsed_flags.inference_type.default_value(),
81            "Target data type of arrays in the output file (for input_arrays, "
82            "this may be overridden by inference_input_type). "
83            "One of FLOAT, QUANTIZED_UINT8."),
84       Flag("inference_input_type", parsed_flags.inference_input_type.bind(),
85            parsed_flags.inference_input_type.default_value(),
86            "Target data type of input arrays. "
87            "If not specified, inference_type is used. "
88            "One of FLOAT, QUANTIZED_UINT8."),
89       Flag("input_type", parsed_flags.input_type.bind(),
90            parsed_flags.input_type.default_value(),
91            "Deprecated ambiguous flag that set both --input_data_types and "
92            "--inference_input_type."),
93       Flag("input_types", parsed_flags.input_types.bind(),
94            parsed_flags.input_types.default_value(),
95            "Deprecated ambiguous flag that set both --input_data_types and "
96            "--inference_input_type. Was meant to be a "
97            "comma-separated list, but this was deprecated before "
98            "multiple-input-types was ever properly supported."),
99 
100       Flag("drop_fake_quant", parsed_flags.drop_fake_quant.bind(),
101            parsed_flags.drop_fake_quant.default_value(),
102            "Ignore and discard FakeQuant nodes. For instance, to "
103            "generate plain float code without fake-quantization from a "
104            "quantized graph."),
105       Flag(
106           "reorder_across_fake_quant",
107           parsed_flags.reorder_across_fake_quant.bind(),
108           parsed_flags.reorder_across_fake_quant.default_value(),
109           "Normally, FakeQuant nodes must be strict boundaries for graph "
110           "transformations, in order to ensure that quantized inference has "
111           "the exact same arithmetic behavior as quantized training --- which "
112           "is the whole point of quantized training and of FakeQuant nodes in "
113           "the first place. "
114           "However, that entails subtle requirements on where exactly "
115           "FakeQuant nodes must be placed in the graph. Some quantized graphs "
116           "have FakeQuant nodes at unexpected locations, that prevent graph "
117           "transformations that are necessary in order to generate inference "
118           "code for these graphs. Such graphs should be fixed, but as a "
119           "temporary work-around, setting this reorder_across_fake_quant flag "
120           "allows TOCO to perform necessary graph transformaitons on them, "
121           "at the cost of no longer faithfully matching inference and training "
122           "arithmetic."),
123       Flag("allow_custom_ops", parsed_flags.allow_custom_ops.bind(),
124            parsed_flags.allow_custom_ops.default_value(),
125            "If true, allow TOCO to create TF Lite Custom operators for all the "
126            "unsupported TensorFlow ops."),
127       Flag("custom_opdefs", parsed_flags.custom_opdefs.bind(),
128            parsed_flags.custom_opdefs.default_value(),
129            "List of strings representing custom ops OpDefs that are included "
130            "in the GraphDef."),
131       Flag("allow_dynamic_tensors", parsed_flags.allow_dynamic_tensors.bind(),
132            parsed_flags.allow_dynamic_tensors.default_value(),
133            "Boolean flag indicating whether the converter should allow models "
134            "with dynamic Tensor shape. When set to False, the converter will "
135            "generate runtime memory offsets for activation Tensors (with 128 "
136            "bits alignment) and error out on models with undetermined Tensor "
137            "shape. (Default: True)"),
138       Flag(
139           "drop_control_dependency",
140           parsed_flags.drop_control_dependency.bind(),
141           parsed_flags.drop_control_dependency.default_value(),
142           "If true, ignore control dependency requirements in input TensorFlow "
143           "GraphDef. Otherwise an error will be raised upon control dependency "
144           "inputs."),
145       Flag("debug_disable_recurrent_cell_fusion",
146            parsed_flags.debug_disable_recurrent_cell_fusion.bind(),
147            parsed_flags.debug_disable_recurrent_cell_fusion.default_value(),
148            "If true, disable fusion of known identifiable cell subgraphs into "
149            "cells. This includes, for example, specific forms of LSTM cell."),
150       Flag("propagate_fake_quant_num_bits",
151            parsed_flags.propagate_fake_quant_num_bits.bind(),
152            parsed_flags.propagate_fake_quant_num_bits.default_value(),
153            "If true, use FakeQuant* operator num_bits attributes to adjust "
154            "array data_types."),
155       Flag("allow_nudging_weights_to_use_fast_gemm_kernel",
156            parsed_flags.allow_nudging_weights_to_use_fast_gemm_kernel.bind(),
157            parsed_flags.allow_nudging_weights_to_use_fast_gemm_kernel
158                .default_value(),
159            "Some fast uint8 GEMM kernels require uint8 weights to avoid the "
160            "value 0. This flag allows nudging them to 1 to allow proceeding, "
161            "with moderate inaccuracy."),
162       Flag("dedupe_array_min_size_bytes",
163            parsed_flags.dedupe_array_min_size_bytes.bind(),
164            parsed_flags.dedupe_array_min_size_bytes.default_value(),
165            "Minimum size of constant arrays to deduplicate; arrays smaller "
166            "will not be deduplicated."),
167       Flag("split_tflite_lstm_inputs",
168            parsed_flags.split_tflite_lstm_inputs.bind(),
169            parsed_flags.split_tflite_lstm_inputs.default_value(),
170            "Split the LSTM inputs from 5 tensors to 18 tensors for TFLite. "
171            "Ignored if the output format is not TFLite."),
172       Flag("quantize_to_float16", parsed_flags.quantize_to_float16.bind(),
173            parsed_flags.quantize_to_float16.default_value(),
174            "Used in conjuction with post_training_quantize. Specifies that "
175            "the weights should be quantized to fp16 instead of the default "
176            "(int8)"),
177       Flag("quantize_weights", parsed_flags.quantize_weights.bind(),
178            parsed_flags.quantize_weights.default_value(),
179            "Deprecated. Please use --post_training_quantize instead."),
180       Flag("post_training_quantize", parsed_flags.post_training_quantize.bind(),
181            parsed_flags.post_training_quantize.default_value(),
182            "Boolean indicating whether to quantize the weights of the "
183            "converted float model. Model size will be reduced and there will "
184            "be latency improvements (at the cost of accuracy)."),
185       // TODO(b/118822804): Unify the argument definition with `tflite_convert`.
186       // WARNING: Experimental interface, subject to change
187       Flag("enable_select_tf_ops", parsed_flags.enable_select_tf_ops.bind(),
188            parsed_flags.enable_select_tf_ops.default_value(), ""),
189       // WARNING: Experimental interface, subject to change
190       Flag("force_select_tf_ops", parsed_flags.force_select_tf_ops.bind(),
191            parsed_flags.force_select_tf_ops.default_value(), "")};
192   bool asked_for_help =
193       *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
194   if (asked_for_help) {
195     *msg += tensorflow::Flags::Usage(argv[0], flags);
196     return false;
197   } else {
198     return tensorflow::Flags::Parse(argc, argv, flags);
199   }
200 }
201 
202 namespace {
203 
204 // Defines the requirements for a given flag. kUseDefault means the default
205 // should be used in cases where the value isn't specified by the user.
206 enum class FlagRequirement {
207   kNone,
208   kMustBeSpecified,
209   kMustNotBeSpecified,
210   kUseDefault,
211 };
212 
213 // Enforces the FlagRequirements are met for a given flag.
214 template <typename T>
EnforceFlagRequirement(const T & flag,const string & flag_name,FlagRequirement requirement)215 void EnforceFlagRequirement(const T& flag, const string& flag_name,
216                             FlagRequirement requirement) {
217   if (requirement == FlagRequirement::kMustBeSpecified) {
218     QCHECK(flag.specified()) << "Missing required flag " << flag_name;
219   }
220   if (requirement == FlagRequirement::kMustNotBeSpecified) {
221     QCHECK(!flag.specified())
222         << "Given other flags, this flag should not have been specified: "
223         << flag_name;
224   }
225 }
226 
227 // Gets the value from the flag if specified. Returns default if the
228 // FlagRequirement is kUseDefault.
229 template <typename T>
GetFlagValue(const Arg<T> & flag,FlagRequirement requirement)230 absl::optional<T> GetFlagValue(const Arg<T>& flag,
231                                FlagRequirement requirement) {
232   if (flag.specified()) return flag.value();
233   if (requirement == FlagRequirement::kUseDefault) return flag.default_value();
234   return absl::optional<T>();
235 }
236 
237 }  // namespace
238 
ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags & parsed_toco_flags,TocoFlags * toco_flags)239 void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
240                                        TocoFlags* toco_flags) {
241   namespace port = toco::port;
242   port::CheckInitGoogleIsDone("InitGoogle is not done yet");
243 
244 #define READ_TOCO_FLAG(name, requirement)                                \
245   do {                                                                   \
246     EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement);  \
247     auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \
248     if (flag_value.has_value()) {                                        \
249       toco_flags->set_##name(flag_value.value());                        \
250     }                                                                    \
251   } while (false)
252 
253 #define PARSE_TOCO_FLAG(Type, name, requirement)                         \
254   do {                                                                   \
255     EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement);  \
256     auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \
257     if (flag_value.has_value()) {                                        \
258       Type x;                                                            \
259       QCHECK(Type##_Parse(flag_value.value(), &x))                       \
260           << "Unrecognized " << #Type << " value "                       \
261           << parsed_toco_flags.name.value();                             \
262       toco_flags->set_##name(x);                                         \
263     }                                                                    \
264   } while (false)
265 
266   PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kUseDefault);
267   PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kUseDefault);
268   PARSE_TOCO_FLAG(IODataType, inference_type, FlagRequirement::kNone);
269   PARSE_TOCO_FLAG(IODataType, inference_input_type, FlagRequirement::kNone);
270   READ_TOCO_FLAG(default_ranges_min, FlagRequirement::kNone);
271   READ_TOCO_FLAG(default_ranges_max, FlagRequirement::kNone);
272   READ_TOCO_FLAG(default_int16_ranges_min, FlagRequirement::kNone);
273   READ_TOCO_FLAG(default_int16_ranges_max, FlagRequirement::kNone);
274   READ_TOCO_FLAG(drop_fake_quant, FlagRequirement::kNone);
275   READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone);
276   READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone);
277   READ_TOCO_FLAG(drop_control_dependency, FlagRequirement::kNone);
278   READ_TOCO_FLAG(debug_disable_recurrent_cell_fusion, FlagRequirement::kNone);
279   READ_TOCO_FLAG(propagate_fake_quant_num_bits, FlagRequirement::kNone);
280   READ_TOCO_FLAG(allow_nudging_weights_to_use_fast_gemm_kernel,
281                  FlagRequirement::kNone);
282   READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone);
283   READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone);
284   READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone);
285   READ_TOCO_FLAG(quantize_to_float16, FlagRequirement::kNone);
286   READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone);
287   READ_TOCO_FLAG(enable_select_tf_ops, FlagRequirement::kNone);
288   READ_TOCO_FLAG(force_select_tf_ops, FlagRequirement::kNone);
289 
290   if (parsed_toco_flags.force_select_tf_ops.value() &&
291       !parsed_toco_flags.enable_select_tf_ops.value()) {
292     // TODO(ycling): Consider to enforce `enable_select_tf_ops` when
293     // `force_select_tf_ops` is true.
294     LOG(WARNING) << "--force_select_tf_ops should always be used with "
295                     "--enable_select_tf_ops.";
296   }
297 
298   // Deprecated flag handling.
299   if (parsed_toco_flags.input_type.specified()) {
300     LOG(WARNING)
301         << "--input_type is deprecated. It was an ambiguous flag that set both "
302            "--input_data_types and --inference_input_type. If you are trying "
303            "to complement the input file with information about the type of "
304            "input arrays, use --input_data_type. If you are trying to control "
305            "the quantization/dequantization of real-numbers input arrays in "
306            "the output file, use --inference_input_type.";
307     toco::IODataType input_type;
308     QCHECK(toco::IODataType_Parse(parsed_toco_flags.input_type.value(),
309                                   &input_type));
310     toco_flags->set_inference_input_type(input_type);
311   }
312   if (parsed_toco_flags.input_types.specified()) {
313     LOG(WARNING)
314         << "--input_types is deprecated. It was an ambiguous flag that set "
315            "both --input_data_types and --inference_input_type. If you are "
316            "trying to complement the input file with information about the "
317            "type of input arrays, use --input_data_type. If you are trying to "
318            "control the quantization/dequantization of real-numbers input "
319            "arrays in the output file, use --inference_input_type.";
320     std::vector<string> input_types =
321         absl::StrSplit(parsed_toco_flags.input_types.value(), ',');
322     QCHECK(!input_types.empty());
323     for (int i = 1; i < input_types.size(); i++) {
324       QCHECK_EQ(input_types[i], input_types[0]);
325     }
326     toco::IODataType input_type;
327     QCHECK(toco::IODataType_Parse(input_types[0], &input_type));
328     toco_flags->set_inference_input_type(input_type);
329   }
330   if (parsed_toco_flags.quantize_weights.value()) {
331     LOG(WARNING)
332         << "--quantize_weights is deprecated. Falling back to "
333            "--post_training_quantize. Please switch --post_training_quantize.";
334     toco_flags->set_post_training_quantize(
335         parsed_toco_flags.quantize_weights.value());
336   }
337   if (parsed_toco_flags.quantize_weights.value()) {
338     if (toco_flags->inference_type() == IODataType::QUANTIZED_UINT8) {
339       LOG(WARNING)
340           << "--post_training_quantize quantizes a graph of inference_type "
341              "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.";
342       toco_flags->set_inference_type(IODataType::FLOAT);
343     }
344   }
345 
346 #undef READ_TOCO_FLAG
347 #undef PARSE_TOCO_FLAG
348 }
349 }  // namespace toco
350