1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/toco_tooling.h"
16
17 #include <cstdlib>
18 #include <memory>
19 #include <set>
20
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_join.h"
23 #include "tensorflow/core/platform/logging.h"
24 #include "tensorflow/lite/toco/allocate_transient_arrays.h"
25 #include "tensorflow/lite/toco/dump_graphviz.h"
26 #include "tensorflow/lite/toco/export_tensorflow.h"
27 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
28 #include "tensorflow/lite/toco/import_tensorflow.h"
29 #include "tensorflow/lite/toco/model.h"
30 #include "tensorflow/lite/toco/model_flags.pb.h"
31 #include "tensorflow/lite/toco/tflite/export.h"
32 #include "tensorflow/lite/toco/tflite/import.h"
33 #include "tensorflow/lite/toco/toco_flags.pb.h"
34 #include "tensorflow/lite/toco/tooling_util.h"
35
36 namespace toco {
37 namespace {
38 // CHECK-fails if the model contains a kUnsupported operation.
CheckUnsupportedOperations(const Model & model)39 void CheckUnsupportedOperations(const Model& model) {
40 std::set<std::string> unsupported_ops;
41 for (auto& op : model.operators) {
42 if (op->type == OperatorType::kUnsupported) {
43 unsupported_ops.insert(
44 static_cast<const TensorFlowUnsupportedOperator*>(op.get())
45 ->tensorflow_op);
46 }
47 }
48 QCHECK(unsupported_ops.empty())
49 << "These unsupported ops were not removed by graph transformations: "
50 << absl::StrJoin(unsupported_ops, ", ");
51 }
52
MakeGeneralGraphTransformationsSet(GraphTransformationsSet * transformations)53 void MakeGeneralGraphTransformationsSet(
54 GraphTransformationsSet* transformations) {
55 CHECK(transformations->empty());
56 transformations->Add(new ConvertExpandDimsToReshape);
57 transformations->Add(new ConvertMatrixDiagV2OrV3ToV1);
58 transformations->Add(new ConvertMatrixSetDiagV2OrV3ToV1);
59 transformations->Add(new ConvertSqueezeToReshape);
60 transformations->Add(new ConvertTrivialAddNToAdd);
61 transformations->Add(new ConvertTrivialPackToReshape);
62 transformations->Add(new ConvertTrivialTileToConcat);
63 transformations->Add(new ConvertTrivialTransposeToReshape);
64 transformations->Add(new ConvertReorderAxes);
65 transformations->Add(new ResolveReshapeAttributes);
66 transformations->Add(new ResolveTransposeAttributes);
67 transformations->Add(new PropagateActivationFunctionIntoConstants);
68 transformations->Add(new PropagateArrayDataTypes);
69 transformations->Add(new PropagateFixedSizes);
70 transformations->Add(new RemoveSuccessiveTranspose);
71 transformations->Add(new RemoveTensorFlowAssert);
72 transformations->Add(new RemoveTensorFlowIdentity);
73 transformations->Add(new RemoveTrivialConcatenation);
74 transformations->Add(new RemoveTrivialConcatenationInput);
75 transformations->Add(new RemoveTrivialFakeQuant);
76 transformations->Add(new RemoveTrivialSlice);
77 transformations->Add(new RemoveUnusedOp);
78 transformations->Add(new EnsureBiasVectors);
79 transformations->Add(new ResolveReorderAxes);
80 transformations->Add(new UnrollBatchMatMul);
81 transformations->Add(new ResolveTensorFlowMatMul);
82 transformations->Add(new FuseBinaryIntoPrecedingAffine);
83 transformations->Add(new FuseBinaryIntoFollowingAffine);
84 transformations->Add(new FuseBroadcastIntoFollowingBinary);
85 transformations->Add(new MergeReshapeIntoPrecedingTranspose);
86 transformations->Add(new MoveBinaryOperatorBeforeReshape);
87 transformations->Add(new ReorderElementwiseUnary);
88 transformations->Add(new ReorderReshapeTranspose);
89 transformations->Add(new ResolveBatchNormalization);
90 transformations->Add(new ResolveConstantBinaryOperator);
91 transformations->Add(new ResolveConstantFill);
92 transformations->Add(new ResolveConstantGather);
93 transformations->Add(new ResolveConstantPack);
94 transformations->Add(new ResolveConstantRandomUniform);
95 transformations->Add(new ResolveConstantRange);
96 transformations->Add(new ResolveConstantReshape);
97 transformations->Add(new ResolveConstantSelect);
98 transformations->Add(new ResolveConstantSlice);
99 transformations->Add(new ResolveConstantStridedSlice);
100 transformations->Add(new ResolveConstantTile);
101 transformations->Add(new ResolveConstantTranspose);
102 transformations->Add(new ResolveConstantUnaryOperator);
103 transformations->Add(new ResolveTensorFlowMerge);
104 transformations->Add(new ResolveSqueezeAttributes);
105 transformations->Add(new ResolveTensorFlowSwitch);
106 transformations->Add(new ResolveTensorFlowConcat);
107 transformations->Add(new ResolveMultiplyByZero);
108 transformations->Add(new IdentifyHardSwish);
109 transformations->Add(new IdentifyL2Normalization);
110 transformations->Add(new IdentifyL2Pool);
111 transformations->Add(new IdentifyRelu1);
112 transformations->Add(new IdentifyPRelu);
113 transformations->Add(new RemoveTrivialBinaryOperator);
114 transformations->Add(new ResolveFakeQuantArgsFromVars);
115 transformations->Add(new ReadArrayMinmaxAndNarrowRangeFromFakeQuant);
116 transformations->Add(new ResolveSpaceToBatchNDAttributes);
117 transformations->Add(new ResolveBatchToSpaceNDAttributes);
118 transformations->Add(new ResolvePadAttributes);
119 transformations->Add(new ResolvePadV2Attributes);
120 transformations->Add(new ResolveStridedSliceAttributes);
121 transformations->Add(new ResolveSliceAttributes);
122 transformations->Add(new ResolveReduceAttributes);
123 transformations->Add(new ResolveConstantShapeOrRank);
124 transformations->Add(new MakeInitialDequantizeOperator);
125 transformations->Add(new UnpartitionEmbeddingLookup);
126 transformations->Add(new ResolveGatherAttributes);
127 }
128
SupportsQuantization(FileFormat format)129 bool SupportsQuantization(FileFormat format) {
130 return (format == GRAPHVIZ_DOT || format == TFLITE);
131 }
132
SupportsFusedActivationFunction(FileFormat format)133 bool SupportsFusedActivationFunction(FileFormat format) {
134 return (format == GRAPHVIZ_DOT || format == TFLITE);
135 }
136
SupportsLstmCell(FileFormat format)137 bool SupportsLstmCell(FileFormat format) {
138 return (format == TENSORFLOW_GRAPHDEF || format == GRAPHVIZ_DOT ||
139 format == TFLITE);
140 }
141
SupportsPreallocatedWorkspace(FileFormat format)142 bool SupportsPreallocatedWorkspace(FileFormat format) {
143 return (format == TFLITE);
144 }
145
SupportsShuffledFCWeights(FileFormat format)146 bool SupportsShuffledFCWeights(FileFormat format) { return format == TFLITE; }
147
IsRealValued(toco::ArrayDataType type)148 bool IsRealValued(toco::ArrayDataType type) {
149 // TODO(benoitjacob) - this is hardcoding that uint8 and int16 are only used
150 // for quantized real-number values, and no other integer type is ever used
151 // for that. This is dirty, should be resolved as part of a more general push
152 // to more explicitly distinguish between true-integers and
153 // integers used as quantized values representing real numbers.
154 return static_cast<bool>(type == toco::ArrayDataType::kFloat ||
155 type == toco::ArrayDataType::kUint8 ||
156 type == toco::ArrayDataType::kInt16);
157 }
158
SetFinalDataTypeOnInputs(const TocoFlags & toco_flags,Model * model)159 void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) {
160 const FileFormat output_format = toco_flags.output_format();
161 ArrayDataType type;
162 if (!SupportsQuantization(output_format)) {
163 // Data type is implicitly float for non-quantized formats
164 type = ArrayDataType::kFloat;
165 } else if (toco_flags.has_inference_input_type()) {
166 type = ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type());
167 } else if (toco_flags.has_inference_type()) {
168 type = ConvertIODataTypeToArrayDataType(toco_flags.inference_type());
169 } else {
170 // Nothing to do. Data types stay as-is.
171 return;
172 }
173
174 for (int i = 0; i < model->flags.input_arrays_size(); i++) {
175 std::string const& array_name = model->flags.input_arrays(i).name();
176 auto* array = &model->GetArray(array_name);
177 // Note that the notion of changing data types only applies to real-numbers
178 // arrays (see the documentation for inference_input_type).
179 // TODO(benoitjacob) this is assuming that uint8 arrays are quantized,
180 // i.e. represent real numbers by means of quantization parameters,
181 // and not plain integer uint8 input arrays.
182 if (!IsRealValued(array->data_type)) {
183 // Ignore non-real data types.
184 continue;
185 }
186 // The enum value QUANTIZED_UINT8 for --inference_type and
187 // --inference_input_type has long meant just 'QUANTIZED', being used as
188 // well in mixed 8-bit / 16-bit quantized models. However,
189 // ConvertIODataTypeToArrayDataType still interpretes it as meaning 8bit,
190 // and people have run into issues in the situation where they have an
191 // already mixed 8-bit / 16-bit quantized model in TFLITE format and
192 // want to run it again through toco, without having to re-specify all the
193 // extra array info that was used in the (complicated) process of initially
194 // quantizing that model. In order to have --inference_type=QUANTIZED_UINT8
195 // just work in that case, we implement the logic that when an array is
196 // already quantized, if --inference_type is quantized (so we're not
197 // asking to dequantize here), no change of quantized data type is to be
198 // recorded.
199 if (array->data_type != toco::ArrayDataType::kFloat &&
200 type != toco::ArrayDataType::kFloat) {
201 continue;
202 }
203
204 array->final_data_type = type;
205 }
206 }
207
208 } // namespace
209
Import(const TocoFlags & toco_flags,const ModelFlags & model_flags,const std::string & input_file_contents)210 std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
211 const ModelFlags& model_flags,
212 const std::string& input_file_contents) {
213 std::unique_ptr<Model> model;
214 switch (toco_flags.input_format()) {
215 case TENSORFLOW_GRAPHDEF: {
216 TensorFlowImportFlags tf_import_flags;
217 tf_import_flags.drop_control_dependency =
218 toco_flags.has_drop_control_dependency()
219 ? toco_flags.drop_control_dependency()
220 : (toco_flags.output_format() != TENSORFLOW_GRAPHDEF);
221
222 tf_import_flags.import_all_ops_as_unsupported =
223 toco_flags.force_select_tf_ops();
224
225 model = ImportTensorFlowGraphDef(model_flags, tf_import_flags,
226 input_file_contents);
227 break;
228 }
229 case TFLITE:
230 model = toco::tflite::Import(model_flags, input_file_contents);
231 ResolveModelFlags(model_flags, model.get());
232 CheckInvariants(*model);
233 break;
234 default:
235 LOG(FATAL) << "Unhandled input_format='"
236 << FileFormat_Name(toco_flags.input_format()) << "'";
237 }
238
239 LogDump(kLogLevelModelChanged, "AT IMPORT", *model);
240
241 return model;
242 }
243
TransformWithStatus(const TocoFlags & toco_flags,Model * model)244 tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags,
245 Model* model) {
246 const FileFormat output_format = toco_flags.output_format();
247 const IODataType inference_type = toco_flags.inference_type();
248
249 const bool quantize_output =
250 SupportsQuantization(output_format) &&
251 (inference_type == QUANTIZED_UINT8 || inference_type == QUANTIZED_INT16);
252
253 if (quantize_output) {
254 QCHECK_NE(toco_flags.inference_input_type(), FLOAT)
255 << "Quantized inference is not allowed with float inputs.";
256 }
257
258 // Clean up after import.
259 SetFinalDataTypeOnInputs(toco_flags, model);
260 UseArraysExtraInfo(model, quantize_output);
261 FinishBuildingRNNStates(model);
262
263 // Remove unused ops before performing any other optimizations. This is to
264 // stop optimizations from crossing the input/output boundaries. For example
265 // this will stop BatchNorm fusing if the output node is in between a conv
266 // and BatchNorm layers.
267 TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
268 model, "Removing unused ops", {new toco::RemoveUnusedOp}));
269
270 GraphTransformationsSet transformations;
271 MakeGeneralGraphTransformationsSet(&transformations);
272 auto* remove_trivial_reshape = new RemoveTrivialReshape;
273 transformations.Add(remove_trivial_reshape);
274 auto* resolve_constant_fake_quant = new ResolveConstantFakeQuant;
275 if (quantize_output) {
276 resolve_constant_fake_quant->set_propagate_fake_quant_num_bits(
277 toco_flags.propagate_fake_quant_num_bits());
278 }
279 transformations.Add(resolve_constant_fake_quant);
280 if (SupportsFusedActivationFunction(output_format)) {
281 transformations.Add(new FuseActivationFunctions);
282 } else {
283 transformations.Add(new UnfuseActivationFunctions);
284 }
285 if (toco_flags.drop_fake_quant()) {
286 transformations.Add(new DropFakeQuant);
287 } else {
288 // See the doc for --reorder_across_fake_quant: that flag is needed to
289 // support some existing models, e.g. WordLens, that have FakeQuant
290 // nodes in the wrong places.
291 // TODO(benoitjacob): drop special casing when we can.
292 if ((quantize_output && toco_flags.reorder_across_fake_quant())) {
293 transformations.Add(new DropFakeQuant);
294 }
295 }
296 transformations.Add(new ConvertPureConvToDepthwise);
297 if (SupportsLstmCell(output_format)) {
298 if (!toco_flags.debug_disable_recurrent_cell_fusion()) {
299 transformations.Add(new IdentifyLstmCell);
300 }
301 if (output_format == TFLITE && toco_flags.split_tflite_lstm_inputs()) {
302 transformations.Add(new toco::SplitLstmCellInputs);
303 } else {
304 transformations.Add(new toco::MergeLstmCellInputs);
305 }
306 }
307 transformations.Add(new ResolveConstantConcatenation);
308 // TODO(b/116063589): TF GraphDef doesn't support dilations on its depthwise
309 // conv, so we need to make sure we don't convert to dilated depthwise conv
310 // when outputing to TF GraphDef.
311 auto* identify_dilated_conv = new IdentifyDilatedConv;
312 if (output_format == TENSORFLOW_GRAPHDEF) {
313 identify_dilated_conv->set_identify_depthwise_conv(false);
314 }
315 transformations.Add(identify_dilated_conv);
316 TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
317 model, "general graph transformations", transformations));
318
319 if (quantize_output) {
320 if (toco_flags.propagate_fake_quant_num_bits()) {
321 TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
322 model, "fake quant propagation graph transformations",
323 {new PropagateFakeQuantNumBits}));
324 }
325 TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
326 model, "pre-quantization graph transformations",
327 {
328 new HardcodeMinMax,
329 new DropFakeQuant,
330 }));
331 }
332
333 // Try to merge bidirectional sequence lstm or rnn if present.
334 GraphTransformationsSet bidirectional_transformations;
335 bidirectional_transformations.Add(new RemoveUnusedOp);
336 bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceLstm);
337 bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceRnn);
338 bidirectional_transformations.Add(
339 new toco::GroupDynamicBidirectionalSequenceRnn);
340 bidirectional_transformations.Add(
341 new toco::GroupDynamicBidirectionalSequenceLstm);
342 TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
343 model, "Group bidirectional sequence lstm/rnn",
344 bidirectional_transformations));
345
346 // Fix any issues with IO edges. This must happen after any transform that
347 // may modify the structure of the edges.
348 FixEdgeArrays(model);
349 FixOperatorOrdering(model);
350
351 if (quantize_output) {
352 // If the user specified default min/max ranges we need to set all arrays
353 // that didn't either have a min/max specified or get one set via
354 // HardcodeMinMax or PropagateFakeQuantNumBits. This may require running
355 // HardcodeMinMax to move changes through the graph as we make changes.
356 auto propagate_default_min_max =
357 absl::make_unique<PropagateDefaultMinMax>();
358 bool has_default_ranges_flag = (toco_flags.has_default_ranges_min() &&
359 toco_flags.has_default_ranges_max());
360 if (has_default_ranges_flag) {
361 propagate_default_min_max->DefineTypeRange(
362 ArrayDataType::kUint8, toco_flags.default_ranges_min(),
363 toco_flags.default_ranges_max());
364 }
365 if (toco_flags.has_default_int16_ranges_min() &&
366 toco_flags.has_default_int16_ranges_max()) {
367 propagate_default_min_max->DefineTypeRange(
368 ArrayDataType::kInt16, toco_flags.default_int16_ranges_min(),
369 toco_flags.default_int16_ranges_max());
370 }
371 if (propagate_default_min_max->has_any_ranges_defined()) {
372 TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
373 model, "default min-max range propagation graph transformations",
374 {
375 propagate_default_min_max.release(),
376 new HardcodeMinMax,
377 }));
378 }
379
380 CheckIsReadyForQuantization(*model);
381 auto* ensure_safe_for_int8_kernels =
382 new EnsureUint8WeightsSafeForFastInt8Kernels;
383 ensure_safe_for_int8_kernels->set_allow_nudging_weights(
384 toco_flags.allow_nudging_weights_to_use_fast_gemm_kernel());
385 ensure_safe_for_int8_kernels->set_has_default_ranges_flag(
386 has_default_ranges_flag);
387 TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
388 model, "quantization graph transformations",
389 {
390 new RemoveTrivialQuantizedActivationFunc,
391 new RemoveTrivialQuantizedMinMax,
392 new Quantize,
393 new RemoveFinalDequantizeOp,
394 ensure_safe_for_int8_kernels,
395 }));
396 if (SupportsShuffledFCWeights(output_format)) {
397 TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
398 model, "shuffling of FC weights", {new ShuffleFCWeights}));
399 }
400 } else {
401 GraphTransformationsSet dequantization_transformations{new Dequantize};
402 // Dequantize creates FakeQuant nodes. We may want to discard
403 // those immediately.
404 if (toco_flags.drop_fake_quant()) {
405 dequantization_transformations.Add(new DropFakeQuant);
406 }
407
408 TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
409 model, "dequantization graph transformations",
410 dequantization_transformations));
411 }
412
413 // It's actually unfortunate we have to put the graph transformation here:
414 // If user choose to use broadcast mul to do nearset neighbor upsampling. That
415 // is:
416 // Input [1, 20, 1, 20, 1, 64] * ones [1, 3, 1, 3, 1, 1]
417 // The problem is if the input is quantized, then the quantization parameters
418 // will be slightly different for the input and the output. (although the
419 // difference is really small).
420 // But, since we're changing this pattern to be pack-based which enforce
421 // the quantization parameters to be exactly the same.
422 // So we have to wait for all quantization parameters being resolved and
423 // propagated and create our own.
424 // We may need to revisit this logic later.
425 GraphTransformationsSet nearest_upsample_transformations;
426 nearest_upsample_transformations.Add(new toco::IdentifyNearestUpsample);
427 TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
428 model, "Identify nearest upsample.", nearest_upsample_transformations));
429
430 if (output_format == TENSORFLOW_GRAPHDEF) {
431 EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(model);
432 }
433
434 // Deduplicate large constant arrays.
435 DedupeConstantArrays(model, toco_flags.dedupe_array_min_size_bytes());
436
437 LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model);
438
439 if (output_format != GRAPHVIZ_DOT && output_format != TFLITE) {
440 // By now there shouldn't be any unsupported ops when exporting to
441 // TensorFlow GraphDef.
442 CheckUnsupportedOperations(*model);
443 }
444
445 if (SupportsPreallocatedWorkspace(output_format)) {
446 AllocateTransientArrays(model, kDefaultTransientDataAlignment);
447 LogDump(kLogLevelModelChanged, "AFTER ALLOCATION", *model);
448 }
449
450 CheckModelCounts(*model);
451 CheckFinalDataTypesSatisfied(*model);
452
453 // Estimate and log the number of arithmetic ops
454 int64_t ops_count = 0;
455 if (EstimateArithmeticOpsCount(*model, &ops_count)) {
456 LOG(INFO) << "Estimated count of arithmetic ops: " << ops_count
457 << " ops, equivalently " << ops_count / 2 << " MACs";
458 }
459 model->ops_count = ops_count;
460 int64_t params_count = 0;
461
462 // Compute and log the number of parameters
463 for (const auto& array_pair : model->GetArrayMap()) {
464 const Array& array = *array_pair.second;
465 if (!array.buffer) {
466 // not a parameter array
467 continue;
468 }
469 params_count += RequiredBufferSizeForShape(array.shape());
470 }
471 LOG(INFO) << "Number of parameters: " << params_count;
472 return tensorflow::Status::OK();
473 }
474
Export(const TocoFlags & toco_flags,const Model & model,bool allow_custom_ops,std::string * output_file_contents)475 tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model,
476 bool allow_custom_ops,
477 std::string* output_file_contents) {
478 switch (toco_flags.output_format()) {
479 case TENSORFLOW_GRAPHDEF:
480 ExportTensorFlowGraphDef(model, output_file_contents);
481 break;
482 case TFLITE: {
483 toco::tflite::ExportParams params;
484
485 params.enable_select_tf_ops =
486 toco_flags.force_select_tf_ops() || toco_flags.enable_select_tf_ops();
487 params.allow_custom_ops = allow_custom_ops;
488 params.allow_dynamic_tensors = toco_flags.allow_dynamic_tensors();
489
490 if (toco_flags.post_training_quantize()) {
491 if (toco_flags.quantize_to_float16()) {
492 params.quantize_weights = tflite::QuantizedBufferType::FLOAT16;
493 } else {
494 params.quantize_weights = tflite::QuantizedBufferType::INT8;
495 }
496 }
497 auto status = toco::tflite::Export(model, output_file_contents, params);
498 if (!status.ok()) {
499 LOG(ERROR) << status.error_message();
500 }
501 return status;
502 } break;
503 case GRAPHVIZ_DOT:
504 DumpGraphviz(model, output_file_contents, "Computation Graph");
505 break;
506 default:
507 LOG(FATAL) << "Unhandled output_format='"
508 << FileFormat_Name(toco_flags.output_format()) << "'";
509 }
510 return tensorflow::Status();
511 }
512
513 } // namespace toco
514