/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"

#include <ostream>
#include <unordered_set>
#include <utility>

#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/types.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"

using stream_executor::port::StatusOr;

namespace tensorflow {
namespace internal {
namespace {

// Op def string for TFLite_Detection_PostProcess Op.
const char kDetectionPostProcessOp[] =
    "name: 'TFLite_Detection_PostProcess' input_arg: { name: "
    "'raw_outputs/box_encodings' type: DT_FLOAT } input_arg: { name: "
    "'raw_outputs/class_predictions' type: DT_FLOAT } input_arg: { name: "
    "'anchors' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess:1' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess:2' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess:3' type: DT_FLOAT } attr : { name: "
    "'h_scale' type: 'float'} attr : { name: 'max_classes_per_detection' "
    "type: 'int'} attr : { name: 'max_detections' type: 'int'} attr : { "
    "name: 'nms_iou_threshold' type: 'float'} attr : { name: "
    "'nms_score_threshold' type: 'float'} attr : { name: 'num_classes' type: "
    "'int'} attr : { name: 'w_scale' type: 'float'} attr : { name: 'x_scale' "
    "type: 'float'} attr : { name: 'y_scale' type: 'float'} attr { name: "
    "'detections_per_class' type: 'int' default_value { i : 100 }} attr { "
    "name: 'use_regular_nms' type: 'bool' default_value { b : false }}";

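// Op def string for the UnidirectionalSequenceLstm op; registered in
// RegisterAllCustomOps below together with the other custom ops so the
// importer can resolve graphs containing it.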
const char kUnidirectionalSequenceLstmOp[] =
    "name: 'UnidirectionalSequenceLstm' input_arg: {name: 'Input' type: "
    "DT_FLOAT} input_arg: { name: 'InputToInputWeights' type: DT_FLOAT } "
    "input_arg: { name: 'InputToForgetWeights' type: DT_FLOAT } input_arg: { "
    "name: 'InputToCellWeights' type: DT_FLOAT} input_arg: { name: "
    "'InputToOutputWeights' type: DT_FLOAT } input_arg: { name: "
    "'RecurrentToInputWeights' type: DT_FLOAT} input_arg: { name: "
    "'RecurrentToForgetWeights' type: DT_FLOAT} input_arg: { name: "
    "'RecurrentToCellWeights' type: DT_FLOAT } input_arg: { name: "
    "'RecurrentToOutputWeights' type: DT_FLOAT } input_arg: { name: "
    "'CellToInputWeights' type: DT_FLOAT} input_arg: { name: "
    "'CellToForgetWeights' type: DT_FLOAT } input_arg: { name: "
    "'CellToOutputWeights' type: DT_FLOAT } input_arg: { name: 'InputGateBias' "
    "type: DT_FLOAT } input_arg: { name: 'ForgetGateBias' type: DT_FLOAT } "
    "input_arg: { name: 'kCellGateBias' type: DT_FLOAT } input_arg: { name: "
    "'OutputGateBias' type: DT_FLOAT } input_arg: { name: 'ProjectionWeights' "
    "type: DT_FLOAT } input_arg: { name: 'ProjectionBias' type: DT_FLOAT } "
    "input_arg: { name: 'InputActivationState' type: DT_FLOAT} input_arg: { "
    "name: 'InputCellStateTensor' type: DT_FLOAT } "
    "output_arg: { name: 'Concat' type: DT_FLOAT} "
    "output_arg: { name: "
    "'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} "
    "attr : { name: '_tflite_input_indices' type: 'list(int)'}";

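// Op def string for the UnidirectionalSequenceRnn op; registered together
// with the LSTM op above.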
const char kUnidirectionalSequenceRnnOp[] =
    "name: 'UnidirectionalSequenceRnn' input_arg: {name: 'Input' type: "
    "DT_FLOAT} input_arg: { name: 'Weights' type: DT_FLOAT } "
    "input_arg: { name: 'RecurrentWeights' type: DT_FLOAT } input_arg: { "
    "name: 'Bias' type: DT_FLOAT} "
    "input_arg: { name: 'HiddenState' type: DT_FLOAT} "
    "output_arg: { name: "
    "'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: "
    "DT_FLOAT} "
    "attr : { name: '_tflite_input_indices' type: 'list(int)'}";

// Converts a toco::IODataType to a tensorflow::DataType. Only contains the
// conversion mappings for constants defined in the TFLite Python API.
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
  switch (dtype) {
    case toco::IODataType::FLOAT:
      return DT_FLOAT;
    case toco::IODataType::FLOAT16:
      return DT_HALF;
    case toco::IODataType::FLOAT64:
      return DT_DOUBLE;
    case toco::IODataType::QUANTIZED_UINT8:
      return DT_QUINT8;
    case toco::IODataType::INT8:
      return DT_QINT8;
    case toco::IODataType::QUANTIZED_INT16:
      return DT_INT16;
    case toco::IODataType::INT32:
      return DT_INT32;
    case toco::IODataType::UINT32:
      return DT_UINT32;
    case toco::IODataType::INT64:
      return DT_INT64;
    case toco::IODataType::UINT64:
      return DT_UINT64;
    case toco::IODataType::STRING:
      return DT_STRING;
    case toco::IODataType::BOOL:
      return DT_BOOL;
    case toco::IODataType::COMPLEX64:
      return DT_COMPLEX64;
    case toco::IODataType::COMPLEX128:
      return DT_COMPLEX128;
    case toco::IODataType::RESOURCE:
      return DT_RESOURCE;
    case toco::IODataType::VARIANT:
      return DT_VARIANT;
    default:
      return DT_INVALID;
  }
}

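// Maps the quantized input stats (mean, std) of an input array to the float
// range they represent. The stats encode the dequantization
//   real_value = (quantized_value - mean) / std,
// so dequantizing the integer extremes [qmin, qmax] of the quantized type
// yields the float min/max. For example, DT_QUINT8 (qmin=0, qmax=255) with
// mean=128 and std=128 gives the range [-1.0, 0.9921875].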
StatusOr<std::pair<double, double>> InputStatsToMinMax(double mean, double std,
                                                       DataType type) {
  // Only qint8 and quint8 are considered here.
  double qmin, qmax;
  if (type == DT_QUINT8) {
    qmin = 0.0;
    qmax = 255.0;
  } else if (type == DT_QINT8) {
    qmin = -128.0;
    qmax = 127.0;
  } else {
    return errors::InvalidArgument("Only QINT8 and QUINT8 are supported.");
  }
  return std::make_pair((qmin - mean) / std, (qmax - mean) / std);
}

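// Parses each OpDef in `extra_tf_opdefs` from its text-format string and
// registers it with the global TensorFlow op registry; OpDefs whose names are
// already registered are skipped.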
Status RegisterCustomBuiltinOps(const std::vector<string>& extra_tf_opdefs) {
  for (const auto& tf_opdefs_string : extra_tf_opdefs) {
    tensorflow::OpDef opdef;
    if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
                                                           &opdef)) {
      return errors::InvalidArgument("failed to parse extra OpDef");
    }
    // Make sure the op is not already registered; if it is, skip it.
    const OpRegistrationData* op_reg =
        tensorflow::OpRegistry::Global()->LookUp(opdef.name());
    if (op_reg) continue;

    tensorflow::OpRegistry::Global()->Register(
        [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
          *op_reg_data = tensorflow::OpRegistrationData(opdef);
          return Status::OK();
        });
  }
  return Status::OK();
}

}  // namespace

Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
  // Register any custom OpDefs.
  std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
                                      toco_flags.custom_opdefs().end());
  extra_tf_opdefs.push_back(kDetectionPostProcessOp);
  extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp);
  extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp);
  return RegisterCustomBuiltinOps(extra_tf_opdefs);
}

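// Fills `quant_specs` and the per-input vectors (names, dtypes, shapes, and,
// for quantized inference types, min/max ranges derived from the input stats)
// from the TOCO model and converter flags.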
Status PopulateQuantizationSpecs(
    const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
    mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
    std::vector<string>* node_dtypes,
    std::vector<llvm::Optional<std::vector<int>>>* node_shapes,
    std::vector<llvm::Optional<double>>* node_mins,
    std::vector<llvm::Optional<double>>* node_maxs) {
  quant_specs->inference_input_type =
      ConvertIODataTypeToDataType(toco_flags.inference_input_type());
  tensorflow::DataType inference_type =
      ConvertIODataTypeToDataType(toco_flags.inference_type());
  // A non-float `inference_input_type` overrides `inference_type`, because
  // quantization has to be applied to satisfy it.
  if (quant_specs->inference_input_type != tensorflow::DT_FLOAT) {
    inference_type = quant_specs->inference_input_type;
  }

  for (auto& flag : model_flags.input_arrays()) {
    node_names->push_back(flag.name());
    // TOCO doesn't require `data_type` to be filled for every input.
    // If it's not filled, make it an empty string so the importer will use
    // the data type in the NodeDef.
    auto toco_data_type = flag.data_type();
    if (toco_data_type == ::toco::IODataType::IO_DATA_TYPE_UNKNOWN) {
      node_dtypes->push_back("");
    } else {
      node_dtypes->push_back(
          DataType_Name(ConvertIODataTypeToDataType(toco_data_type)));
    }
    if (flag.shape().unknown_rank()) {
      node_shapes->push_back(llvm::None);
    } else {
      node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
                                              flag.shape().dims().end()));
    }
    // Currently, only the QUINT8 and QINT8 inference types require input
    // stats.
    if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
      if (flag.has_mean_value() && flag.has_std_value()) {
        TF_ASSIGN_OR_RETURN(
            auto min_max, InputStatsToMinMax(flag.mean_value(),
                                             flag.std_value(), inference_type));
        node_mins->push_back(min_max.first);
        node_maxs->push_back(min_max.second);
      } else {
        node_mins->push_back(llvm::None);
        node_maxs->push_back(llvm::None);
      }
    }
  }

  if (mlir::TFL::GetInputNodeQuantSpecs(*node_names, *node_mins, *node_maxs,
                                        inference_type, quant_specs)) {
    return errors::InvalidArgument("Failed to get input quant spec.");
  }

  // Extra flags related to post-training quantization. If post-training
  // quantization is enabled, `inference_type` and `inference_input_type` are
  // not used by the MLIR passes.
  if (toco_flags.post_training_quantize()) {
    quant_specs->weight_quantization = true;
    if (toco_flags.quantize_to_float16()) {
      quant_specs->inference_type = tensorflow::DT_HALF;
      quant_specs->inference_input_type = tensorflow::DT_HALF;
    } else {
      quant_specs->inference_type = tensorflow::DT_QINT8;
      quant_specs->inference_input_type = tensorflow::DT_QINT8;
    }
  }

  // Other flags.
  if (toco_flags.has_default_ranges_min()) {
    quant_specs->default_ranges.first = toco_flags.default_ranges_min();
  }
  if (toco_flags.has_default_ranges_max()) {
    quant_specs->default_ranges.second = toco_flags.default_ranges_max();
  }

  return ::tensorflow::Status::OK();
}

// Dumps the op graph of the `module` to `filename` in DOT format.
Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
  std::string error_message;
  auto output = mlir::openOutputFile(filename, &error_message);
  if (!error_message.empty()) {
    return errors::InvalidArgument("Failed to open file ", filename);
  }
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createPrintOpGraphPass(output->os()));
  if (failed(pm.run(module))) {
    return errors::Unknown("Failed to dump Op Graph from MLIR module.");
  }
  output->keep();
  return Status::OK();
}

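// Runs the TF-to-TFLite conversion pass pipeline on `module` and serializes
// the result to a TFLite flatbuffer in `result`. If `dump_graphviz_dir` is
// set, the op graph is dumped in DOT format before and after the
// transformation passes.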
Status ConvertMLIRToTFLiteFlatBuffer(
    const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module,
    const mlir::TFL::PassConfig& pass_config,
    const std::unordered_set<std::string>& saved_model_tags, string* result,
    llvm::Optional<tensorflow::Session*> session) {
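  // Decide which categories of ops the exporter may emit; forcing select TF
  // ops disables the TFLite builtins entirely.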
  bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
  bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
  bool emit_custom_ops = toco_flags.allow_custom_ops();

  const std::unordered_set<std::string> select_user_tf_ops(
      toco_flags.select_user_tf_ops().begin(),
      toco_flags.select_user_tf_ops().end());

  if (toco_flags.has_dump_graphviz_dir()) {
    TF_RETURN_IF_ERROR(DumpOpGraphToFile(
        module.get(),
        // Rename once we enable the new converter feature flag.
        absl::StrCat(toco_flags.dump_graphviz_dir(), "/toco_AT_IMPORT.dot")));
  }

  mlir::PassManager pm(module->getContext(),
                       mlir::OpPassManager::Nesting::Implicit);
  ::tensorflow::SetCrashReproducer(pm);

  tensorflow::AddTFToTFLConversionPasses(pass_config, &pm, session);
  // Convert back to outlined while format for export back to flatbuffer.
  if (pass_config.legalize_tf_while) {
    pm.addPass(mlir::TFL::CreateWhileOutlinePass());
  }
  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());

  auto status = ConvertTFExecutorToTFLOrFlatbuffer(
      module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
      emit_select_tf_ops, emit_custom_ops, select_user_tf_ops,
      pass_config.quant_specs, saved_model_tags, result, &pm);
  if (toco_flags.has_dump_graphviz_dir()) {
    TF_RETURN_IF_ERROR(DumpOpGraphToFile(
        // Rename once we enable the new converter feature flag.
        module.get(), absl::StrCat(toco_flags.dump_graphviz_dir(),
                                   "/toco_AFTER_TRANSFORMATIONS.dot")));
  }

  return status;
}

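// Logs a warning for each legacy TOCO flag that is set but ignored by the
// MLIR-based converter.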
void WarningUnusedFlags(const toco::ModelFlags& model_flags,
                        const toco::TocoFlags& toco_flags) {
  if (toco_flags.output_format()) {
    LOG(WARNING) << "Ignored output_format.";
  }
  if (toco_flags.drop_control_dependency()) {
    LOG(WARNING) << "Ignored drop_control_dependency.";
  }
  if (toco_flags.reorder_across_fake_quant()) {
    LOG(WARNING) << "Ignored reorder_across_fake_quant.";
  }
  if (model_flags.change_concat_input_ranges()) {
    LOG(WARNING) << "Ignored change_concat_input_ranges.";
  }
  if (toco_flags.dump_graphviz_include_video()) {
    LOG(WARNING) << "Ignored dump_graphviz_video.";
  }
  if (model_flags.allow_nonexistent_arrays()) {
    LOG(WARNING) << "Allowing allow_nonexistent_arrays.";
  }
}

}  // namespace internal
}  // namespace tensorflow