1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <string>
17 #include <vector>
18
19 #include "absl/strings/numbers.h"
20 #include "absl/strings/str_join.h"
21 #include "absl/strings/str_split.h"
22 #include "absl/strings/strip.h"
23 #include "absl/types/optional.h"
24 #include "tensorflow/lite/toco/toco_cmdline_flags.h"
25 #include "tensorflow/lite/toco/toco_port.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/util/command_line_flags.h"
28
29 namespace toco {
30
ParseTocoFlagsFromCommandLineFlags(int * argc,char * argv[],string * msg,ParsedTocoFlags * parsed_toco_flags_ptr)31 bool ParseTocoFlagsFromCommandLineFlags(
32 int* argc, char* argv[], string* msg,
33 ParsedTocoFlags* parsed_toco_flags_ptr) {
34 using tensorflow::Flag;
35 ParsedTocoFlags& parsed_flags = *parsed_toco_flags_ptr;
36 std::vector<tensorflow::Flag> flags = {
37 Flag("input_file", parsed_flags.input_file.bind(),
38 parsed_flags.input_file.default_value(),
39 "Input file (model of any supported format). For Protobuf "
40 "formats, both text and binary are supported regardless of file "
41 "extension."),
42 Flag("savedmodel_directory", parsed_flags.savedmodel_directory.bind(),
43 parsed_flags.savedmodel_directory.default_value(),
44 "Deprecated. Full path to the directory containing the SavedModel."),
45 Flag("output_file", parsed_flags.output_file.bind(),
46 parsed_flags.output_file.default_value(),
47 "Output file. "
48 "For Protobuf formats, the binary format will be used."),
49 Flag("input_format", parsed_flags.input_format.bind(),
50 parsed_flags.input_format.default_value(),
51 "Input file format. One of: TENSORFLOW_GRAPHDEF, TFLITE."),
52 Flag("output_format", parsed_flags.output_format.bind(),
53 parsed_flags.output_format.default_value(),
54 "Output file format. "
55 "One of TENSORFLOW_GRAPHDEF, TFLITE, GRAPHVIZ_DOT."),
56 Flag("savedmodel_tagset", parsed_flags.savedmodel_tagset.bind(),
57 parsed_flags.savedmodel_tagset.default_value(),
58 "Deprecated. Comma-separated set of tags identifying the "
59 "MetaGraphDef within the SavedModel to analyze. All tags in the tag "
60 "set must be specified."),
61 Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(),
62 parsed_flags.default_ranges_min.default_value(),
63 "If defined, will be used as the default value for the min bound "
64 "of min/max ranges used for quantization of uint8 arrays."),
65 Flag("default_ranges_max", parsed_flags.default_ranges_max.bind(),
66 parsed_flags.default_ranges_max.default_value(),
67 "If defined, will be used as the default value for the max bound "
68 "of min/max ranges used for quantization of uint8 arrays."),
69 Flag("default_int16_ranges_min",
70 parsed_flags.default_int16_ranges_min.bind(),
71 parsed_flags.default_int16_ranges_min.default_value(),
72 "If defined, will be used as the default value for the min bound "
73 "of min/max ranges used for quantization of int16 arrays."),
74 Flag("default_int16_ranges_max",
75 parsed_flags.default_int16_ranges_max.bind(),
76 parsed_flags.default_int16_ranges_max.default_value(),
77 "If defined, will be used as the default value for the max bound "
78 "of min/max ranges used for quantization of int16 arrays."),
79 Flag("inference_type", parsed_flags.inference_type.bind(),
80 parsed_flags.inference_type.default_value(),
81 "Target data type of arrays in the output file (for input_arrays, "
82 "this may be overridden by inference_input_type). "
83 "One of FLOAT, QUANTIZED_UINT8."),
84 Flag("inference_input_type", parsed_flags.inference_input_type.bind(),
85 parsed_flags.inference_input_type.default_value(),
86 "Target data type of input arrays. "
87 "If not specified, inference_type is used. "
88 "One of FLOAT, QUANTIZED_UINT8."),
89 Flag("input_type", parsed_flags.input_type.bind(),
90 parsed_flags.input_type.default_value(),
91 "Deprecated ambiguous flag that set both --input_data_types and "
92 "--inference_input_type."),
93 Flag("input_types", parsed_flags.input_types.bind(),
94 parsed_flags.input_types.default_value(),
95 "Deprecated ambiguous flag that set both --input_data_types and "
96 "--inference_input_type. Was meant to be a "
97 "comma-separated list, but this was deprecated before "
98 "multiple-input-types was ever properly supported."),
99
100 Flag("drop_fake_quant", parsed_flags.drop_fake_quant.bind(),
101 parsed_flags.drop_fake_quant.default_value(),
102 "Ignore and discard FakeQuant nodes. For instance, to "
103 "generate plain float code without fake-quantization from a "
104 "quantized graph."),
105 Flag(
106 "reorder_across_fake_quant",
107 parsed_flags.reorder_across_fake_quant.bind(),
108 parsed_flags.reorder_across_fake_quant.default_value(),
109 "Normally, FakeQuant nodes must be strict boundaries for graph "
110 "transformations, in order to ensure that quantized inference has "
111 "the exact same arithmetic behavior as quantized training --- which "
112 "is the whole point of quantized training and of FakeQuant nodes in "
113 "the first place. "
114 "However, that entails subtle requirements on where exactly "
115 "FakeQuant nodes must be placed in the graph. Some quantized graphs "
116 "have FakeQuant nodes at unexpected locations, that prevent graph "
117 "transformations that are necessary in order to generate inference "
118 "code for these graphs. Such graphs should be fixed, but as a "
119 "temporary work-around, setting this reorder_across_fake_quant flag "
120 "allows TOCO to perform necessary graph transformaitons on them, "
121 "at the cost of no longer faithfully matching inference and training "
122 "arithmetic."),
123 Flag("allow_custom_ops", parsed_flags.allow_custom_ops.bind(),
124 parsed_flags.allow_custom_ops.default_value(),
125 "If true, allow TOCO to create TF Lite Custom operators for all the "
126 "unsupported TensorFlow ops."),
127 Flag("custom_opdefs", parsed_flags.custom_opdefs.bind(),
128 parsed_flags.custom_opdefs.default_value(),
129 "List of strings representing custom ops OpDefs that are included "
130 "in the GraphDef."),
131 Flag("allow_dynamic_tensors", parsed_flags.allow_dynamic_tensors.bind(),
132 parsed_flags.allow_dynamic_tensors.default_value(),
133 "Boolean flag indicating whether the converter should allow models "
134 "with dynamic Tensor shape. When set to False, the converter will "
135 "generate runtime memory offsets for activation Tensors (with 128 "
136 "bits alignment) and error out on models with undetermined Tensor "
137 "shape. (Default: True)"),
138 Flag(
139 "drop_control_dependency",
140 parsed_flags.drop_control_dependency.bind(),
141 parsed_flags.drop_control_dependency.default_value(),
142 "If true, ignore control dependency requirements in input TensorFlow "
143 "GraphDef. Otherwise an error will be raised upon control dependency "
144 "inputs."),
145 Flag("debug_disable_recurrent_cell_fusion",
146 parsed_flags.debug_disable_recurrent_cell_fusion.bind(),
147 parsed_flags.debug_disable_recurrent_cell_fusion.default_value(),
148 "If true, disable fusion of known identifiable cell subgraphs into "
149 "cells. This includes, for example, specific forms of LSTM cell."),
150 Flag("propagate_fake_quant_num_bits",
151 parsed_flags.propagate_fake_quant_num_bits.bind(),
152 parsed_flags.propagate_fake_quant_num_bits.default_value(),
153 "If true, use FakeQuant* operator num_bits attributes to adjust "
154 "array data_types."),
155 Flag("allow_nudging_weights_to_use_fast_gemm_kernel",
156 parsed_flags.allow_nudging_weights_to_use_fast_gemm_kernel.bind(),
157 parsed_flags.allow_nudging_weights_to_use_fast_gemm_kernel
158 .default_value(),
159 "Some fast uint8 GEMM kernels require uint8 weights to avoid the "
160 "value 0. This flag allows nudging them to 1 to allow proceeding, "
161 "with moderate inaccuracy."),
162 Flag("dedupe_array_min_size_bytes",
163 parsed_flags.dedupe_array_min_size_bytes.bind(),
164 parsed_flags.dedupe_array_min_size_bytes.default_value(),
165 "Minimum size of constant arrays to deduplicate; arrays smaller "
166 "will not be deduplicated."),
167 Flag("split_tflite_lstm_inputs",
168 parsed_flags.split_tflite_lstm_inputs.bind(),
169 parsed_flags.split_tflite_lstm_inputs.default_value(),
170 "Split the LSTM inputs from 5 tensors to 18 tensors for TFLite. "
171 "Ignored if the output format is not TFLite."),
172 Flag("quantize_to_float16", parsed_flags.quantize_to_float16.bind(),
173 parsed_flags.quantize_to_float16.default_value(),
174 "Used in conjuction with post_training_quantize. Specifies that "
175 "the weights should be quantized to fp16 instead of the default "
176 "(int8)"),
177 Flag("quantize_weights", parsed_flags.quantize_weights.bind(),
178 parsed_flags.quantize_weights.default_value(),
179 "Deprecated. Please use --post_training_quantize instead."),
180 Flag("post_training_quantize", parsed_flags.post_training_quantize.bind(),
181 parsed_flags.post_training_quantize.default_value(),
182 "Boolean indicating whether to quantize the weights of the "
183 "converted float model. Model size will be reduced and there will "
184 "be latency improvements (at the cost of accuracy)."),
185 // TODO(b/118822804): Unify the argument definition with `tflite_convert`.
186 // WARNING: Experimental interface, subject to change
187 Flag("enable_select_tf_ops", parsed_flags.enable_select_tf_ops.bind(),
188 parsed_flags.enable_select_tf_ops.default_value(), ""),
189 // WARNING: Experimental interface, subject to change
190 Flag("force_select_tf_ops", parsed_flags.force_select_tf_ops.bind(),
191 parsed_flags.force_select_tf_ops.default_value(), "")};
192 bool asked_for_help =
193 *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
194 if (asked_for_help) {
195 *msg += tensorflow::Flags::Usage(argv[0], flags);
196 return false;
197 } else {
198 return tensorflow::Flags::Parse(argc, argv, flags);
199 }
200 }
201
202 namespace {
203
// Defines the requirements for a given flag. kUseDefault means the default
// should be used in cases where the value isn't specified by the user.
enum class FlagRequirement {
  // No constraint; the flag is copied to TocoFlags only when specified.
  kNone,
  // The user must supply a value; enforcement QCHECK-fails otherwise.
  kMustBeSpecified,
  // The user must NOT supply a value; enforcement QCHECK-fails otherwise.
  kMustNotBeSpecified,
  // Fall back to the flag's default value when the user didn't specify one.
  kUseDefault,
};
212
213 // Enforces the FlagRequirements are met for a given flag.
214 template <typename T>
EnforceFlagRequirement(const T & flag,const string & flag_name,FlagRequirement requirement)215 void EnforceFlagRequirement(const T& flag, const string& flag_name,
216 FlagRequirement requirement) {
217 if (requirement == FlagRequirement::kMustBeSpecified) {
218 QCHECK(flag.specified()) << "Missing required flag " << flag_name;
219 }
220 if (requirement == FlagRequirement::kMustNotBeSpecified) {
221 QCHECK(!flag.specified())
222 << "Given other flags, this flag should not have been specified: "
223 << flag_name;
224 }
225 }
226
227 // Gets the value from the flag if specified. Returns default if the
228 // FlagRequirement is kUseDefault.
229 template <typename T>
GetFlagValue(const Arg<T> & flag,FlagRequirement requirement)230 absl::optional<T> GetFlagValue(const Arg<T>& flag,
231 FlagRequirement requirement) {
232 if (flag.specified()) return flag.value();
233 if (requirement == FlagRequirement::kUseDefault) return flag.default_value();
234 return absl::optional<T>();
235 }
236
237 } // namespace
238
ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags & parsed_toco_flags,TocoFlags * toco_flags)239 void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
240 TocoFlags* toco_flags) {
241 namespace port = toco::port;
242 port::CheckInitGoogleIsDone("InitGoogle is not done yet");
243
244 #define READ_TOCO_FLAG(name, requirement) \
245 do { \
246 EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement); \
247 auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \
248 if (flag_value.has_value()) { \
249 toco_flags->set_##name(flag_value.value()); \
250 } \
251 } while (false)
252
253 #define PARSE_TOCO_FLAG(Type, name, requirement) \
254 do { \
255 EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement); \
256 auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \
257 if (flag_value.has_value()) { \
258 Type x; \
259 QCHECK(Type##_Parse(flag_value.value(), &x)) \
260 << "Unrecognized " << #Type << " value " \
261 << parsed_toco_flags.name.value(); \
262 toco_flags->set_##name(x); \
263 } \
264 } while (false)
265
266 PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kUseDefault);
267 PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kUseDefault);
268 PARSE_TOCO_FLAG(IODataType, inference_type, FlagRequirement::kNone);
269 PARSE_TOCO_FLAG(IODataType, inference_input_type, FlagRequirement::kNone);
270 READ_TOCO_FLAG(default_ranges_min, FlagRequirement::kNone);
271 READ_TOCO_FLAG(default_ranges_max, FlagRequirement::kNone);
272 READ_TOCO_FLAG(default_int16_ranges_min, FlagRequirement::kNone);
273 READ_TOCO_FLAG(default_int16_ranges_max, FlagRequirement::kNone);
274 READ_TOCO_FLAG(drop_fake_quant, FlagRequirement::kNone);
275 READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone);
276 READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone);
277 READ_TOCO_FLAG(drop_control_dependency, FlagRequirement::kNone);
278 READ_TOCO_FLAG(debug_disable_recurrent_cell_fusion, FlagRequirement::kNone);
279 READ_TOCO_FLAG(propagate_fake_quant_num_bits, FlagRequirement::kNone);
280 READ_TOCO_FLAG(allow_nudging_weights_to_use_fast_gemm_kernel,
281 FlagRequirement::kNone);
282 READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone);
283 READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone);
284 READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone);
285 READ_TOCO_FLAG(quantize_to_float16, FlagRequirement::kNone);
286 READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone);
287 READ_TOCO_FLAG(enable_select_tf_ops, FlagRequirement::kNone);
288 READ_TOCO_FLAG(force_select_tf_ops, FlagRequirement::kNone);
289
290 if (parsed_toco_flags.force_select_tf_ops.value() &&
291 !parsed_toco_flags.enable_select_tf_ops.value()) {
292 // TODO(ycling): Consider to enforce `enable_select_tf_ops` when
293 // `force_select_tf_ops` is true.
294 LOG(WARNING) << "--force_select_tf_ops should always be used with "
295 "--enable_select_tf_ops.";
296 }
297
298 // Deprecated flag handling.
299 if (parsed_toco_flags.input_type.specified()) {
300 LOG(WARNING)
301 << "--input_type is deprecated. It was an ambiguous flag that set both "
302 "--input_data_types and --inference_input_type. If you are trying "
303 "to complement the input file with information about the type of "
304 "input arrays, use --input_data_type. If you are trying to control "
305 "the quantization/dequantization of real-numbers input arrays in "
306 "the output file, use --inference_input_type.";
307 toco::IODataType input_type;
308 QCHECK(toco::IODataType_Parse(parsed_toco_flags.input_type.value(),
309 &input_type));
310 toco_flags->set_inference_input_type(input_type);
311 }
312 if (parsed_toco_flags.input_types.specified()) {
313 LOG(WARNING)
314 << "--input_types is deprecated. It was an ambiguous flag that set "
315 "both --input_data_types and --inference_input_type. If you are "
316 "trying to complement the input file with information about the "
317 "type of input arrays, use --input_data_type. If you are trying to "
318 "control the quantization/dequantization of real-numbers input "
319 "arrays in the output file, use --inference_input_type.";
320 std::vector<string> input_types =
321 absl::StrSplit(parsed_toco_flags.input_types.value(), ',');
322 QCHECK(!input_types.empty());
323 for (int i = 1; i < input_types.size(); i++) {
324 QCHECK_EQ(input_types[i], input_types[0]);
325 }
326 toco::IODataType input_type;
327 QCHECK(toco::IODataType_Parse(input_types[0], &input_type));
328 toco_flags->set_inference_input_type(input_type);
329 }
330 if (parsed_toco_flags.quantize_weights.value()) {
331 LOG(WARNING)
332 << "--quantize_weights is deprecated. Falling back to "
333 "--post_training_quantize. Please switch --post_training_quantize.";
334 toco_flags->set_post_training_quantize(
335 parsed_toco_flags.quantize_weights.value());
336 }
337 if (parsed_toco_flags.quantize_weights.value()) {
338 if (toco_flags->inference_type() == IODataType::QUANTIZED_UINT8) {
339 LOG(WARNING)
340 << "--post_training_quantize quantizes a graph of inference_type "
341 "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.";
342 toco_flags->set_inference_type(IODataType::FLOAT);
343 }
344 }
345
346 #undef READ_TOCO_FLAG
347 #undef PARSE_TOCO_FLAG
348 }
349 } // namespace toco
350