/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"

#include <ostream>
#include <unordered_set>
#include <utility>

#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/types.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"

using stream_executor::port::StatusOr;

namespace tensorflow {
namespace internal {
namespace {

// Op def string for TFLite_Detection_PostProcess Op.
const char kDetectionPostProcessOp[] =
    "name: 'TFLite_Detection_PostProcess' input_arg: { name: "
    "'raw_outputs/box_encodings' type: DT_FLOAT } input_arg: { name: "
    "'raw_outputs/class_predictions' type: DT_FLOAT } input_arg: { name: "
    "'anchors' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess:1' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess:2' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess:3' type: DT_FLOAT } attr : { name: "
    "'h_scale' type: 'float'} attr : { name: 'max_classes_per_detection' "
    "type: 'int'} attr : { name: 'max_detections' type: 'int'} attr : { "
    "name: 'nms_iou_threshold' type: 'float'} attr : { name: "
    "'nms_score_threshold' type: 'float'} attr : { name: 'num_classes' type: "
    "'int'} attr : { name: 'w_scale' type: 'float'} attr : { name: 'x_scale' "
    "type: 'float'} attr : { name: 'y_scale' type: 'float'} attr { name: "
    "'detections_per_class' type: 'int' default_value { i : 100 }} attr { "
    "name: 'use_regular_nms' type: 'bool' default_value { b : false }}";

const char kUnidirectionalSequenceLstmOp[] =
    "name: 'UnidirectionalSequenceLstm' input_arg: {name: 'Input' type: "
    "DT_FLOAT} input_arg: { name: 'InputToInputWeights' type: DT_FLOAT } "
    "input_arg: { name: 'InputToForgetWeights' type: DT_FLOAT } input_arg: { "
    "name: 'InputToCellWeights' type: DT_FLOAT} input_arg: { name: "
    "'InputToOutputWeights' type: DT_FLOAT } input_arg: { name: "
    "'RecurrentToInputWeights' type: DT_FLOAT} input_arg: { name: "
    "'RecurrentToForgetWeights' type: DT_FLOAT} input_arg: { name: "
    "'RecurrentToCellWeights' type: DT_FLOAT } input_arg: { name: "
    "'RecurrentToOutputWeights' type: DT_FLOAT } input_arg: { name: "
    "'CellToInputWeights' type: DT_FLOAT} input_arg: { name: "
    "'CellToForgetWeights' type: DT_FLOAT } input_arg: { name: "
    "'CellToOutputWeights' type: DT_FLOAT } input_arg: { name: 'InputGateBias' "
    "type: DT_FLOAT } input_arg: { name: 'ForgetGateBias' type: DT_FLOAT } "
    "input_arg: { name: 'kCellGateBias' type: DT_FLOAT } input_arg: { name: "
    "'OutputGateBias' type: DT_FLOAT } input_arg: { name: 'ProjectionWeights' "
    "type: DT_FLOAT } input_arg: { name: 'ProjectionBias' type: DT_FLOAT } "
    "input_arg: { name: 'InputActivationState' type: DT_FLOAT} input_arg: { "
    "name: 'InputCellStateTensor' type: DT_FLOAT } "
    "output_arg: { name: 'Concat' type: DT_FLOAT} "
    "output_arg: { name: "
    "'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} "
    "attr : { name: '_tflite_input_indices' type: 'list(int)'}";

const char kUnidirectionalSequenceRnnOp[] =
    "name: 'UnidirectionalSequenceRnn' input_arg: {name: 'Input' type: "
    "DT_FLOAT} input_arg: { name: 'Weights' type: DT_FLOAT } "
    "input_arg: { name: 'RecurrentWeights' type: DT_FLOAT } input_arg: { "
    "name: 'Bias' type: DT_FLOAT} "
    "input_arg: { name: 'HiddenState' type: DT_FLOAT} "
    "output_arg: { name: "
    "'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: "
    "DT_FLOAT} "
    "attr : { name: '_tflite_input_indices' type: 'list(int)'}";

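// The OpDef strings above are protobuf TextFormat; they are registered with
// the global op registry in RegisterAllCustomOps() below so that imported
// graphs containing these TFLite-specific custom ops can be parsed.
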
// Converts the toco::IODataType to tensorflow::DataType. Only contains the
// conversion mapping for constants defined in the TFLite Python API.
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
  switch (dtype) {
    case toco::IODataType::FLOAT:
      return DT_FLOAT;
    case toco::IODataType::FLOAT16:
      return DT_HALF;
    case toco::IODataType::FLOAT64:
      return DT_DOUBLE;
    case toco::IODataType::QUANTIZED_UINT8:
      return DT_QUINT8;
    case toco::IODataType::QUANTIZED_INT8:
      return DT_QINT8;
    case toco::IODataType::QUANTIZED_INT16:
      return DT_QINT16;
    case toco::IODataType::INT8:
      return DT_INT8;
    case toco::IODataType::INT16:
      return DT_INT16;
    case toco::IODataType::INT32:
      return DT_INT32;
    case toco::IODataType::UINT32:
      return DT_UINT32;
    case toco::IODataType::INT64:
      return DT_INT64;
    case toco::IODataType::UINT8:
      return DT_UINT8;
    case toco::IODataType::UINT64:
      return DT_UINT64;
    case toco::IODataType::STRING:
      return DT_STRING;
    case toco::IODataType::BOOL:
      return DT_BOOL;
    case toco::IODataType::COMPLEX64:
      return DT_COMPLEX64;
    case toco::IODataType::COMPLEX128:
      return DT_COMPLEX128;
    case toco::IODataType::RESOURCE:
      return DT_RESOURCE;
    case toco::IODataType::VARIANT:
      return DT_VARIANT;
    default:
      return DT_INVALID;
  }
}

StatusOr<std::pair<double, double>> InputStatsToMinMax(double mean, double std,
                                                       DataType type) {
  // Only qint8 and quint8 are considered here.
  double qmin, qmax;
  if (type == DT_QUINT8) {
    qmin = 0.0;
    qmax = 255.0;
  } else if (type == DT_QINT8) {
    qmin = -128.0;
    qmax = 127.0;
  } else {
    return errors::InvalidArgument("Only int8 and uint8 are considered.");
  }
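  // The returned pair is the float range implied by the quantization stats.
  // As a worked example (assuming quint8 with mean = 127.5 and std = 127.5):
  //   min = (0 - 127.5) / 127.5 = -1.0, max = (255 - 127.5) / 127.5 = 1.0.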
  return std::make_pair((qmin - mean) / std, (qmax - mean) / std);
}

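// Registers each OpDef (serialized in protobuf TextFormat, like the constants
// above) with the global op registry, skipping ops that are already present.
// As a hypothetical minimal example, a string such as
//   "name: 'MyCustomOp' input_arg: { name: 'x' type: DT_FLOAT } "
//   "output_arg: { name: 'y' type: DT_FLOAT }"
// would parse into an OpDef with one float input and one float output.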
Status RegisterCustomBuiltinOps(const std::vector<string>& extra_tf_opdefs) {
  for (const auto& tf_opdefs_string : extra_tf_opdefs) {
    tensorflow::OpDef opdef;
    if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
                                                           &opdef)) {
      return errors::InvalidArgument("failed to parse extra OpDef");
    }
    // Skip the op if it is already registered.
    const OpRegistrationData* op_reg =
        tensorflow::OpRegistry::Global()->LookUp(opdef.name());
    if (op_reg) continue;

    tensorflow::OpRegistry::Global()->Register(
        [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
          *op_reg_data = tensorflow::OpRegistrationData(opdef);
          return Status::OK();
        });
  }
  return Status::OK();
}

}  // namespace

Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
  // Register any custom OpDefs.
  std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
                                      toco_flags.custom_opdefs().end());
  extra_tf_opdefs.push_back(kDetectionPostProcessOp);
  extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp);
  extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp);
  return RegisterCustomBuiltinOps(extra_tf_opdefs);
}

Status PopulateQuantizationSpecs(
    const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
    mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
    std::vector<string>* node_dtypes,
    std::vector<llvm::Optional<std::vector<int>>>* node_shapes,
    std::vector<llvm::Optional<double>>* node_mins,
    std::vector<llvm::Optional<double>>* node_maxs) {
  quant_specs->inference_input_type =
      ConvertIODataTypeToDataType(toco_flags.inference_input_type());
  tensorflow::DataType inference_type =
      ConvertIODataTypeToDataType(toco_flags.inference_type());
  // If `inference_input_type` is non-float, use it to override
  // `inference_type`, because quantization has to be applied to satisfy it.
  if (quant_specs->inference_input_type != tensorflow::DT_FLOAT) {
    inference_type = quant_specs->inference_input_type;
  }

  for (auto& flag : model_flags.input_arrays()) {
    node_names->push_back(flag.name());
    // TOCO doesn't require `data_type` to be filled for every input.
    // If it's not filled, make it an empty string so the importer will use
    // the data type in the NodeDef.
    auto toco_data_type = flag.data_type();
    if (toco_data_type == ::toco::IODataType::IO_DATA_TYPE_UNKNOWN) {
      node_dtypes->push_back("");
    } else {
      node_dtypes->push_back(
          DataType_Name(ConvertIODataTypeToDataType(toco_data_type)));
    }
    if (flag.shape().unknown_rank()) {
      node_shapes->push_back(llvm::None);
    } else {
      node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
                                              flag.shape().dims().end()));
    }
    // Currently, only UINT8 and INT8 require input stats.
    if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
      if (flag.has_mean_value() && flag.has_std_value()) {
        TF_ASSIGN_OR_RETURN(
            auto min_max, InputStatsToMinMax(flag.mean_value(),
                                             flag.std_value(), inference_type));
        node_mins->push_back(min_max.first);
        node_maxs->push_back(min_max.second);
      } else {
        node_mins->push_back(llvm::None);
        node_maxs->push_back(llvm::None);
      }
    }
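    // For float inference types nothing is appended here, so node_mins and
    // node_maxs stay empty in that path.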
  }

  if (mlir::TFL::GetInputNodeQuantSpecs(*node_names, *node_mins, *node_maxs,
                                        inference_type, quant_specs)) {
    return errors::InvalidArgument("Failed to get input quant spec.");
  }

  // Some extra flags related to post-training quantization. If post-training
  // quantization is enabled, `inference_type` and `inference_input_type` are
  // not used by MLIR passes.
  if (toco_flags.post_training_quantize()) {
    quant_specs->weight_quantization = true;
    if (toco_flags.quantize_to_float16()) {
      quant_specs->inference_type = tensorflow::DT_HALF;
      quant_specs->inference_input_type = tensorflow::DT_HALF;
    } else {
      quant_specs->inference_type = tensorflow::DT_QINT8;
      quant_specs->inference_input_type = tensorflow::DT_QINT8;
    }
  }

  // Other flags.
  if (toco_flags.has_default_ranges_min()) {
    quant_specs->default_ranges.first = toco_flags.default_ranges_min();
  }
  if (toco_flags.has_default_ranges_max()) {
    quant_specs->default_ranges.second = toco_flags.default_ranges_max();
  }

  return ::tensorflow::Status::OK();
}

// Dumps the op graph of the `module` to `filename` in DOT format.
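// The resulting .dot file can be rendered with the standard Graphviz tools,
// e.g. `dot -Tpng graph.dot -o graph.png` (assuming Graphviz is installed).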
Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
  std::string error_message;
  auto output = mlir::openOutputFile(filename, &error_message);
  if (!error_message.empty()) {
    return errors::InvalidArgument("Failed to open file: ", filename);
  }
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createPrintOpGraphPass(output->os()));
  if (failed(pm.run(module))) {
    return errors::Unknown("Failed to dump Op Graph from MLIR module.");
  }
  output->keep();
  return Status::OK();
}

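// Lowers `module` through the TF-to-TFLite pass pipeline and serializes the
// final result into `result` as a TFLite flatbuffer. When `dump_graphviz_dir`
// is set, the op graph is dumped before and after the transformations.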
Status ConvertMLIRToTFLiteFlatBuffer(
    const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
    mlir::OwningModuleRef module, const mlir::TFL::PassConfig& pass_config,
    const std::unordered_set<std::string>& saved_model_tags, string* result,
    llvm::Optional<tensorflow::Session*> session) {
  bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
  bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
  bool emit_custom_ops = toco_flags.allow_custom_ops();
  bool allow_all_select_tf_ops = toco_flags.allow_all_select_tf_ops();

  const std::unordered_set<std::string> select_user_tf_ops(
      toco_flags.select_user_tf_ops().begin(),
      toco_flags.select_user_tf_ops().end());

  if (toco_flags.has_dump_graphviz_dir()) {
    TF_RETURN_IF_ERROR(DumpOpGraphToFile(
        module.get(),
        // rename once we enable the new converter feature flag.
        absl::StrCat(toco_flags.dump_graphviz_dir(), "/toco_AT_IMPORT.dot")));
  }

  mlir::PassManager pm(module->getContext(),
                       mlir::OpPassManager::Nesting::Implicit);
  ::tensorflow::SetCrashReproducer(pm);
  pm.addInstrumentation(
      std::make_unique<mlir::TFL::ErrorCollectorInstrumentation>(
          module->getContext()));

  tensorflow::AddTFToTFLConversionPasses(model_flags, toco_flags, pass_config,
                                         &pm, session);
  // Convert back to outlined while format for export back to flatbuffer.
  if (pass_config.legalize_tf_while) {
    pm.addPass(mlir::TFL::CreateWhileOutlinePass());
  }
  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());

  auto status = ConvertTFExecutorToTFLOrFlatbuffer(
      module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
      emit_select_tf_ops, emit_custom_ops, allow_all_select_tf_ops,
      select_user_tf_ops, pass_config.quant_specs, saved_model_tags, result,
      &pm);
  if (toco_flags.has_dump_graphviz_dir()) {
    TF_RETURN_IF_ERROR(DumpOpGraphToFile(
        // rename once we enable the new converter feature flag.
        module.get(), absl::StrCat(toco_flags.dump_graphviz_dir(),
                                   "/toco_AFTER_TRANSFORMATIONS.dot")));
  }

  return status;
}

void WarningUnusedFlags(const toco::ModelFlags& model_flags,
                        const toco::TocoFlags& toco_flags) {
  if (toco_flags.output_format()) {
    LOG(WARNING) << "Ignored output_format.";
  }
  if (toco_flags.drop_control_dependency()) {
    LOG(WARNING) << "Ignored drop_control_dependency.";
  }
  if (toco_flags.reorder_across_fake_quant()) {
    LOG(WARNING) << "Ignored reorder_across_fake_quant.";
  }
  if (model_flags.change_concat_input_ranges()) {
    LOG(WARNING) << "Ignored change_concat_input_ranges.";
  }
  if (toco_flags.dump_graphviz_include_video()) {
    LOG(WARNING) << "Ignored dump_graphviz_include_video.";
  }
  if (model_flags.allow_nonexistent_arrays()) {
    LOG(WARNING) << "Allowing allow_nonexistent_arrays.";
  }
}

}  // namespace internal
}  // namespace tensorflow