• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/toco_tooling.h"
16 
17 #include <cstdlib>
18 #include <memory>
19 #include <set>
20 
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_join.h"
23 #include "tensorflow/core/platform/logging.h"
24 #include "tensorflow/lite/toco/allocate_transient_arrays.h"
25 #include "tensorflow/lite/toco/dump_graphviz.h"
26 #include "tensorflow/lite/toco/export_tensorflow.h"
27 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
28 #include "tensorflow/lite/toco/import_tensorflow.h"
29 #include "tensorflow/lite/toco/model.h"
30 #include "tensorflow/lite/toco/model_flags.pb.h"
31 #include "tensorflow/lite/toco/tflite/export.h"
32 #include "tensorflow/lite/toco/tflite/import.h"
33 #include "tensorflow/lite/toco/toco_flags.pb.h"
34 #include "tensorflow/lite/toco/tooling_util.h"
35 
36 namespace toco {
37 namespace {
38 // CHECK-fails if the model contains a kUnsupported operation.
CheckUnsupportedOperations(const Model & model)39 void CheckUnsupportedOperations(const Model& model) {
40   std::set<std::string> unsupported_ops;
41   for (auto& op : model.operators) {
42     if (op->type == OperatorType::kUnsupported) {
43       unsupported_ops.insert(
44           static_cast<const TensorFlowUnsupportedOperator*>(op.get())
45               ->tensorflow_op);
46     }
47   }
48   QCHECK(unsupported_ops.empty())
49       << "These unsupported ops were not removed by graph transformations: "
50       << absl::StrJoin(unsupported_ops, ", ");
51 }
52 
MakeGeneralGraphTransformationsSet(GraphTransformationsSet * transformations)53 void MakeGeneralGraphTransformationsSet(
54     GraphTransformationsSet* transformations) {
55   CHECK(transformations->empty());
56   transformations->Add(new ConvertExpandDimsToReshape);
57   transformations->Add(new ConvertMatrixDiagV2OrV3ToV1);
58   transformations->Add(new ConvertMatrixSetDiagV2OrV3ToV1);
59   transformations->Add(new ConvertSqueezeToReshape);
60   transformations->Add(new ConvertTrivialAddNToAdd);
61   transformations->Add(new ConvertTrivialPackToReshape);
62   transformations->Add(new ConvertTrivialTileToConcat);
63   transformations->Add(new ConvertTrivialTransposeToReshape);
64   transformations->Add(new ConvertReorderAxes);
65   transformations->Add(new ResolveReshapeAttributes);
66   transformations->Add(new ResolveTransposeAttributes);
67   transformations->Add(new PropagateActivationFunctionIntoConstants);
68   transformations->Add(new PropagateArrayDataTypes);
69   transformations->Add(new PropagateFixedSizes);
70   transformations->Add(new RemoveSuccessiveTranspose);
71   transformations->Add(new RemoveTensorFlowAssert);
72   transformations->Add(new RemoveTensorFlowIdentity);
73   transformations->Add(new RemoveTrivialConcatenation);
74   transformations->Add(new RemoveTrivialConcatenationInput);
75   transformations->Add(new RemoveTrivialFakeQuant);
76   transformations->Add(new RemoveTrivialSlice);
77   transformations->Add(new RemoveUnusedOp);
78   transformations->Add(new EnsureBiasVectors);
79   transformations->Add(new ResolveReorderAxes);
80   transformations->Add(new UnrollBatchMatMul);
81   transformations->Add(new ResolveTensorFlowMatMul);
82   transformations->Add(new FuseBinaryIntoPrecedingAffine);
83   transformations->Add(new FuseBinaryIntoFollowingAffine);
84   transformations->Add(new FuseBroadcastIntoFollowingBinary);
85   transformations->Add(new MergeReshapeIntoPrecedingTranspose);
86   transformations->Add(new MoveBinaryOperatorBeforeReshape);
87   transformations->Add(new ReorderElementwiseUnary);
88   transformations->Add(new ReorderReshapeTranspose);
89   transformations->Add(new ResolveBatchNormalization);
90   transformations->Add(new ResolveConstantBinaryOperator);
91   transformations->Add(new ResolveConstantFill);
92   transformations->Add(new ResolveConstantGather);
93   transformations->Add(new ResolveConstantPack);
94   transformations->Add(new ResolveConstantRandomUniform);
95   transformations->Add(new ResolveConstantRange);
96   transformations->Add(new ResolveConstantReshape);
97   transformations->Add(new ResolveConstantSelect);
98   transformations->Add(new ResolveConstantSlice);
99   transformations->Add(new ResolveConstantStridedSlice);
100   transformations->Add(new ResolveConstantTile);
101   transformations->Add(new ResolveConstantTranspose);
102   transformations->Add(new ResolveConstantUnaryOperator);
103   transformations->Add(new ResolveTensorFlowMerge);
104   transformations->Add(new ResolveSqueezeAttributes);
105   transformations->Add(new ResolveTensorFlowSwitch);
106   transformations->Add(new ResolveTensorFlowConcat);
107   transformations->Add(new ResolveMultiplyByZero);
108   transformations->Add(new IdentifyHardSwish);
109   transformations->Add(new IdentifyL2Normalization);
110   transformations->Add(new IdentifyL2Pool);
111   transformations->Add(new IdentifyRelu1);
112   transformations->Add(new IdentifyPRelu);
113   transformations->Add(new RemoveTrivialBinaryOperator);
114   transformations->Add(new ResolveFakeQuantArgsFromVars);
115   transformations->Add(new ReadArrayMinmaxAndNarrowRangeFromFakeQuant);
116   transformations->Add(new ResolveSpaceToBatchNDAttributes);
117   transformations->Add(new ResolveBatchToSpaceNDAttributes);
118   transformations->Add(new ResolvePadAttributes);
119   transformations->Add(new ResolvePadV2Attributes);
120   transformations->Add(new ResolveStridedSliceAttributes);
121   transformations->Add(new ResolveSliceAttributes);
122   transformations->Add(new ResolveReduceAttributes);
123   transformations->Add(new ResolveConstantShapeOrRank);
124   transformations->Add(new MakeInitialDequantizeOperator);
125   transformations->Add(new UnpartitionEmbeddingLookup);
126   transformations->Add(new ResolveGatherAttributes);
127 }
128 
SupportsQuantization(FileFormat format)129 bool SupportsQuantization(FileFormat format) {
130   return (format == GRAPHVIZ_DOT || format == TFLITE);
131 }
132 
SupportsFusedActivationFunction(FileFormat format)133 bool SupportsFusedActivationFunction(FileFormat format) {
134   return (format == GRAPHVIZ_DOT || format == TFLITE);
135 }
136 
SupportsLstmCell(FileFormat format)137 bool SupportsLstmCell(FileFormat format) {
138   return (format == TENSORFLOW_GRAPHDEF || format == GRAPHVIZ_DOT ||
139           format == TFLITE);
140 }
141 
SupportsPreallocatedWorkspace(FileFormat format)142 bool SupportsPreallocatedWorkspace(FileFormat format) {
143   return (format == TFLITE);
144 }
145 
SupportsShuffledFCWeights(FileFormat format)146 bool SupportsShuffledFCWeights(FileFormat format) { return format == TFLITE; }
147 
IsRealValued(toco::ArrayDataType type)148 bool IsRealValued(toco::ArrayDataType type) {
149   // TODO(benoitjacob) - this is hardcoding that uint8 and int16 are only used
150   // for quantized real-number values, and no other integer type is ever used
151   // for that. This is dirty, should be resolved as part of a more general push
152   // to more explicitly distinguish between true-integers and
153   // integers used as quantized values representing real numbers.
154   return static_cast<bool>(type == toco::ArrayDataType::kFloat ||
155                            type == toco::ArrayDataType::kUint8 ||
156                            type == toco::ArrayDataType::kInt16);
157 }
158 
SetFinalDataTypeOnInputs(const TocoFlags & toco_flags,Model * model)159 void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) {
160   const FileFormat output_format = toco_flags.output_format();
161   ArrayDataType type;
162   if (!SupportsQuantization(output_format)) {
163     // Data type is implicitly float for non-quantized formats
164     type = ArrayDataType::kFloat;
165   } else if (toco_flags.has_inference_input_type()) {
166     type = ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type());
167   } else if (toco_flags.has_inference_type()) {
168     type = ConvertIODataTypeToArrayDataType(toco_flags.inference_type());
169   } else {
170     // Nothing to do. Data types stay as-is.
171     return;
172   }
173 
174   for (int i = 0; i < model->flags.input_arrays_size(); i++) {
175     std::string const& array_name = model->flags.input_arrays(i).name();
176     auto* array = &model->GetArray(array_name);
177     // Note that the notion of changing data types only applies to real-numbers
178     // arrays (see the documentation for inference_input_type).
179     // TODO(benoitjacob) this is assuming that uint8 arrays are quantized,
180     // i.e. represent real numbers by means of quantization parameters,
181     // and not plain integer uint8 input arrays.
182     if (!IsRealValued(array->data_type)) {
183       // Ignore non-real data types.
184       continue;
185     }
186     // The enum value QUANTIZED_UINT8 for --inference_type and
187     // --inference_input_type has long meant just 'QUANTIZED', being used as
188     // well in mixed 8-bit / 16-bit quantized models. However,
189     // ConvertIODataTypeToArrayDataType still interpretes it as meaning 8bit,
190     // and people have run into issues in the situation where they have an
191     // already mixed 8-bit / 16-bit quantized model in TFLITE format and
192     // want to run it again through toco, without having to re-specify all the
193     // extra array info that was used in the (complicated) process of initially
194     // quantizing that model. In order to have --inference_type=QUANTIZED_UINT8
195     // just work in that case, we implement the logic that when an array is
196     // already quantized, if  --inference_type is quantized (so we're not
197     // asking to dequantize here), no change of quantized data type is to be
198     // recorded.
199     if (array->data_type != toco::ArrayDataType::kFloat &&
200         type != toco::ArrayDataType::kFloat) {
201       continue;
202     }
203 
204     array->final_data_type = type;
205   }
206 }
207 
208 }  // namespace
209 
Import(const TocoFlags & toco_flags,const ModelFlags & model_flags,const std::string & input_file_contents)210 std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
211                               const ModelFlags& model_flags,
212                               const std::string& input_file_contents) {
213   std::unique_ptr<Model> model;
214   switch (toco_flags.input_format()) {
215     case TENSORFLOW_GRAPHDEF: {
216       TensorFlowImportFlags tf_import_flags;
217       tf_import_flags.drop_control_dependency =
218           toco_flags.has_drop_control_dependency()
219               ? toco_flags.drop_control_dependency()
220               : (toco_flags.output_format() != TENSORFLOW_GRAPHDEF);
221 
222       tf_import_flags.import_all_ops_as_unsupported =
223           toco_flags.force_select_tf_ops();
224 
225       model = ImportTensorFlowGraphDef(model_flags, tf_import_flags,
226                                        input_file_contents);
227       break;
228     }
229     case TFLITE:
230       model = toco::tflite::Import(model_flags, input_file_contents);
231       ResolveModelFlags(model_flags, model.get());
232       CheckInvariants(*model);
233       break;
234     default:
235       LOG(FATAL) << "Unhandled input_format='"
236                  << FileFormat_Name(toco_flags.input_format()) << "'";
237   }
238 
239   LogDump(kLogLevelModelChanged, "AT IMPORT", *model);
240 
241   return model;
242 }
243 
TransformWithStatus(const TocoFlags & toco_flags,Model * model)244 tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags,
245                                        Model* model) {
246   const FileFormat output_format = toco_flags.output_format();
247   const IODataType inference_type = toco_flags.inference_type();
248 
249   const bool quantize_output =
250       SupportsQuantization(output_format) &&
251       (inference_type == QUANTIZED_UINT8 || inference_type == QUANTIZED_INT16);
252 
253   if (quantize_output) {
254     QCHECK_NE(toco_flags.inference_input_type(), FLOAT)
255         << "Quantized inference is not allowed with float inputs.";
256   }
257 
258   // Clean up after import.
259   SetFinalDataTypeOnInputs(toco_flags, model);
260   UseArraysExtraInfo(model, quantize_output);
261   FinishBuildingRNNStates(model);
262 
263   // Remove unused ops before performing any other optimizations. This is to
264   // stop optimizations from crossing the input/output boundaries. For example
265   // this will stop BatchNorm fusing if the output node is in between a conv
266   // and BatchNorm layers.
267   TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
268       model, "Removing unused ops", {new toco::RemoveUnusedOp}));
269 
270   GraphTransformationsSet transformations;
271   MakeGeneralGraphTransformationsSet(&transformations);
272   auto* remove_trivial_reshape = new RemoveTrivialReshape;
273   transformations.Add(remove_trivial_reshape);
274   auto* resolve_constant_fake_quant = new ResolveConstantFakeQuant;
275   if (quantize_output) {
276     resolve_constant_fake_quant->set_propagate_fake_quant_num_bits(
277         toco_flags.propagate_fake_quant_num_bits());
278   }
279   transformations.Add(resolve_constant_fake_quant);
280   if (SupportsFusedActivationFunction(output_format)) {
281     transformations.Add(new FuseActivationFunctions);
282   } else {
283     transformations.Add(new UnfuseActivationFunctions);
284   }
285   if (toco_flags.drop_fake_quant()) {
286     transformations.Add(new DropFakeQuant);
287   } else {
288     // See the doc for --reorder_across_fake_quant: that flag is needed to
289     // support some existing models, e.g. WordLens, that have FakeQuant
290     // nodes in the wrong places.
291     // TODO(benoitjacob): drop special casing when we can.
292     if ((quantize_output && toco_flags.reorder_across_fake_quant())) {
293       transformations.Add(new DropFakeQuant);
294     }
295   }
296   transformations.Add(new ConvertPureConvToDepthwise);
297   if (SupportsLstmCell(output_format)) {
298     if (!toco_flags.debug_disable_recurrent_cell_fusion()) {
299       transformations.Add(new IdentifyLstmCell);
300     }
301     if (output_format == TFLITE && toco_flags.split_tflite_lstm_inputs()) {
302       transformations.Add(new toco::SplitLstmCellInputs);
303     } else {
304       transformations.Add(new toco::MergeLstmCellInputs);
305     }
306   }
307   transformations.Add(new ResolveConstantConcatenation);
308   // TODO(b/116063589): TF GraphDef doesn't support dilations on its depthwise
309   // conv, so we need to make sure we don't convert to dilated depthwise conv
310   // when outputing to TF GraphDef.
311   auto* identify_dilated_conv = new IdentifyDilatedConv;
312   if (output_format == TENSORFLOW_GRAPHDEF) {
313     identify_dilated_conv->set_identify_depthwise_conv(false);
314   }
315   transformations.Add(identify_dilated_conv);
316   TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
317       model, "general graph transformations", transformations));
318 
319   if (quantize_output) {
320     if (toco_flags.propagate_fake_quant_num_bits()) {
321       TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
322           model, "fake quant propagation graph transformations",
323           {new PropagateFakeQuantNumBits}));
324     }
325     TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
326         model, "pre-quantization graph transformations",
327         {
328             new HardcodeMinMax,
329             new DropFakeQuant,
330         }));
331   }
332 
333   // Try to merge bidirectional sequence lstm or rnn if present.
334   GraphTransformationsSet bidirectional_transformations;
335   bidirectional_transformations.Add(new RemoveUnusedOp);
336   bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceLstm);
337   bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceRnn);
338   bidirectional_transformations.Add(
339       new toco::GroupDynamicBidirectionalSequenceRnn);
340   bidirectional_transformations.Add(
341       new toco::GroupDynamicBidirectionalSequenceLstm);
342   TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
343       model, "Group bidirectional sequence lstm/rnn",
344       bidirectional_transformations));
345 
346   // Fix any issues with IO edges. This must happen after any transform that
347   // may modify the structure of the edges.
348   FixEdgeArrays(model);
349   FixOperatorOrdering(model);
350 
351   if (quantize_output) {
352     // If the user specified default min/max ranges we need to set all arrays
353     // that didn't either have a min/max specified or get one set via
354     // HardcodeMinMax or PropagateFakeQuantNumBits. This may require running
355     // HardcodeMinMax to move changes through the graph as we make changes.
356     auto propagate_default_min_max =
357         absl::make_unique<PropagateDefaultMinMax>();
358     bool has_default_ranges_flag = (toco_flags.has_default_ranges_min() &&
359                                     toco_flags.has_default_ranges_max());
360     if (has_default_ranges_flag) {
361       propagate_default_min_max->DefineTypeRange(
362           ArrayDataType::kUint8, toco_flags.default_ranges_min(),
363           toco_flags.default_ranges_max());
364     }
365     if (toco_flags.has_default_int16_ranges_min() &&
366         toco_flags.has_default_int16_ranges_max()) {
367       propagate_default_min_max->DefineTypeRange(
368           ArrayDataType::kInt16, toco_flags.default_int16_ranges_min(),
369           toco_flags.default_int16_ranges_max());
370     }
371     if (propagate_default_min_max->has_any_ranges_defined()) {
372       TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
373           model, "default min-max range propagation graph transformations",
374           {
375               propagate_default_min_max.release(),
376               new HardcodeMinMax,
377           }));
378     }
379 
380     CheckIsReadyForQuantization(*model);
381     auto* ensure_safe_for_int8_kernels =
382         new EnsureUint8WeightsSafeForFastInt8Kernels;
383     ensure_safe_for_int8_kernels->set_allow_nudging_weights(
384         toco_flags.allow_nudging_weights_to_use_fast_gemm_kernel());
385     ensure_safe_for_int8_kernels->set_has_default_ranges_flag(
386         has_default_ranges_flag);
387     TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
388         model, "quantization graph transformations",
389         {
390             new RemoveTrivialQuantizedActivationFunc,
391             new RemoveTrivialQuantizedMinMax,
392             new Quantize,
393             new RemoveFinalDequantizeOp,
394             ensure_safe_for_int8_kernels,
395         }));
396     if (SupportsShuffledFCWeights(output_format)) {
397       TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
398           model, "shuffling of FC weights", {new ShuffleFCWeights}));
399     }
400   } else {
401     GraphTransformationsSet dequantization_transformations{new Dequantize};
402     // Dequantize creates FakeQuant nodes. We may want to discard
403     // those immediately.
404     if (toco_flags.drop_fake_quant()) {
405       dequantization_transformations.Add(new DropFakeQuant);
406     }
407 
408     TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
409         model, "dequantization graph transformations",
410         dequantization_transformations));
411   }
412 
413   // It's actually unfortunate we have to put the graph transformation here:
414   // If user choose to use broadcast mul to do nearset neighbor upsampling. That
415   // is:
416   //    Input [1, 20, 1, 20, 1, 64] * ones [1, 3, 1, 3, 1, 1]
417   // The problem is if the input is quantized, then the quantization parameters
418   // will be slightly different for the input and the output. (although the
419   // difference is really small).
420   // But, since we're changing this pattern to be pack-based which enforce
421   // the quantization parameters to be exactly the same.
422   // So we have to wait for all quantization parameters being resolved and
423   // propagated and create our own.
424   // We may need to revisit this logic later.
425   GraphTransformationsSet nearest_upsample_transformations;
426   nearest_upsample_transformations.Add(new toco::IdentifyNearestUpsample);
427   TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
428       model, "Identify nearest upsample.", nearest_upsample_transformations));
429 
430   if (output_format == TENSORFLOW_GRAPHDEF) {
431     EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(model);
432   }
433 
434   // Deduplicate large constant arrays.
435   DedupeConstantArrays(model, toco_flags.dedupe_array_min_size_bytes());
436 
437   LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model);
438 
439   if (output_format != GRAPHVIZ_DOT && output_format != TFLITE) {
440     // By now there shouldn't be any unsupported ops when exporting to
441     // TensorFlow GraphDef.
442     CheckUnsupportedOperations(*model);
443   }
444 
445   if (SupportsPreallocatedWorkspace(output_format)) {
446     AllocateTransientArrays(model, kDefaultTransientDataAlignment);
447     LogDump(kLogLevelModelChanged, "AFTER ALLOCATION", *model);
448   }
449 
450   CheckModelCounts(*model);
451   CheckFinalDataTypesSatisfied(*model);
452 
453   // Estimate and log the number of arithmetic ops
454   int64_t ops_count = 0;
455   if (EstimateArithmeticOpsCount(*model, &ops_count)) {
456     LOG(INFO) << "Estimated count of arithmetic ops: " << ops_count
457               << " ops, equivalently " << ops_count / 2 << " MACs";
458   }
459   model->ops_count = ops_count;
460   int64_t params_count = 0;
461 
462   // Compute and log the number of parameters
463   for (const auto& array_pair : model->GetArrayMap()) {
464     const Array& array = *array_pair.second;
465     if (!array.buffer) {
466       // not a parameter array
467       continue;
468     }
469     params_count += RequiredBufferSizeForShape(array.shape());
470   }
471   LOG(INFO) << "Number of parameters: " << params_count;
472   return tensorflow::Status::OK();
473 }
474 
Export(const TocoFlags & toco_flags,const Model & model,bool allow_custom_ops,std::string * output_file_contents)475 tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model,
476                           bool allow_custom_ops,
477                           std::string* output_file_contents) {
478   switch (toco_flags.output_format()) {
479     case TENSORFLOW_GRAPHDEF:
480       ExportTensorFlowGraphDef(model, output_file_contents);
481       break;
482     case TFLITE: {
483       toco::tflite::ExportParams params;
484 
485       params.enable_select_tf_ops =
486           toco_flags.force_select_tf_ops() || toco_flags.enable_select_tf_ops();
487       params.allow_custom_ops = allow_custom_ops;
488       params.allow_dynamic_tensors = toco_flags.allow_dynamic_tensors();
489 
490       if (toco_flags.post_training_quantize()) {
491         if (toco_flags.quantize_to_float16()) {
492           params.quantize_weights = tflite::QuantizedBufferType::FLOAT16;
493         } else {
494           params.quantize_weights = tflite::QuantizedBufferType::INT8;
495         }
496       }
497       auto status = toco::tflite::Export(model, output_file_contents, params);
498       if (!status.ok()) {
499         LOG(ERROR) << status.error_message();
500       }
501       return status;
502     } break;
503     case GRAPHVIZ_DOT:
504       DumpGraphviz(model, output_file_contents, "Computation Graph");
505       break;
506     default:
507       LOG(FATAL) << "Unhandled output_format='"
508                  << FileFormat_Name(toco_flags.output_format()) << "'";
509   }
510   return tensorflow::Status();
511 }
512 
513 }  // namespace toco
514