1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
16
17 #include <ostream>
18 #include <unordered_set>
19 #include <utility>
20
21 #include "llvm/Support/ToolOutputFile.h"
22 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
23 #include "mlir/IR/MLIRContext.h" // from @llvm-project
24 #include "mlir/Pass/Pass.h" // from @llvm-project
25 #include "mlir/Support/FileUtilities.h" // from @llvm-project
26 #include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project
27 #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
28 #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h"
29 #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
30 #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
31 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
32 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
33 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
34 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
35 #include "tensorflow/core/framework/graph.pb.h"
36 #include "tensorflow/core/framework/types.pb.h"
37 #include "tensorflow/core/lib/core/errors.h"
38 #include "tensorflow/core/platform/status.h"
39 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
40 #include "tensorflow/lite/toco/model_flags.pb.h"
41 #include "tensorflow/lite/toco/toco_flags.pb.h"
42 #include "tensorflow/lite/toco/types.pb.h"
43 #include "tensorflow/stream_executor/lib/statusor.h"
44
45 using stream_executor::port::StatusOr;
46
47 namespace tensorflow {
48 namespace internal {
49 namespace {
50
51 // Op def string for TFLite_Detection_PostProcess Op.
// Op def string for TFLite_Detection_PostProcess Op.
// Registered so GraphDef import succeeds for SSD-style models that contain
// this custom op; its attrs mirror the TFLite custom-op implementation.
const char kDetectionPostProcessOp[] =
    "name: 'TFLite_Detection_PostProcess' input_arg: { name: "
    "'raw_outputs/box_encodings' type: DT_FLOAT } input_arg: { name: "
    "'raw_outputs/class_predictions' type: DT_FLOAT } input_arg: { name: "
    "'anchors' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess:1' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess:2' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess:3' type: DT_FLOAT } attr : { name: "
    "'h_scale' type: 'float'} attr : { name: 'max_classes_per_detection' "
    "type: 'int'} attr : { name: 'max_detections' type: 'int'} attr : { "
    "name: 'nms_iou_threshold' type: 'float'} attr : { name: "
    "'nms_score_threshold' type: 'float'} attr : { name: 'num_classes' type: "
    "'int'} attr : { name: 'w_scale' type: 'float'} attr : { name: 'x_scale' "
    "type: 'float'} attr : { name: 'y_scale' type: 'float'} attr { name: "
    "'detections_per_class' type: 'int' default_value { i : 100 }} attr { "
    "name: 'use_regular_nms' type: 'bool' default_value { b : false }}";

// Op def string for the fused UnidirectionalSequenceLstm op emitted by the
// TFLite LSTM fusion; `_tflite_input_indices` records which inputs are fed.
// NOTE(review): 'kCellGateBias' looks like a typo for 'CellGateBias', but the
// name is part of the op's wire contract — confirm against producers before
// ever changing it.
const char kUnidirectionalSequenceLstmOp[] =
    "name: 'UnidirectionalSequenceLstm' input_arg: {name: 'Input' type: "
    "DT_FLOAT} input_arg: { name: 'InputToInputWeights' type: DT_FLOAT } "
    "input_arg: { name: 'InputToForgetWeights' type: DT_FLOAT } input_arg: { "
    "name: 'InputToCellWeights' type: DT_FLOAT} input_arg: { name: "
    "'InputToOutputWeights' type: DT_FLOAT } input_arg: { name: "
    "'RecurrentToInputWeights' type: DT_FLOAT} input_arg: { name: "
    "'RecurrentToForgetWeights' type: DT_FLOAT} input_arg: { name: "
    "'RecurrentToCellWeights' type: DT_FLOAT } input_arg: { name: "
    "'RecurrentToOutputWeights' type: DT_FLOAT } input_arg: { name: "
    "'CellToInputWeights' type: DT_FLOAT} input_arg: { name: "
    "'CellToForgetWeights' type: DT_FLOAT } input_arg: { name: "
    "'CellToOutputWeights' type: DT_FLOAT } input_arg: { name: 'InputGateBias' "
    "type: DT_FLOAT } input_arg: { name: 'ForgetGateBias' type: DT_FLOAT } "
    "input_arg: { name: 'kCellGateBias' type: DT_FLOAT } input_arg: { name: "
    "'OutputGateBias' type: DT_FLOAT } input_arg: { name: 'ProjectionWeights' "
    "type: DT_FLOAT } input_arg: { name: 'ProjectionBias' type: DT_FLOAT } "
    "input_arg: { name: 'InputActivationState' type: DT_FLOAT} input_arg: { "
    "name: 'InputCellStateTensor' type: DT_FLOAT } "
    "output_arg: { name: 'Concat' type: DT_FLOAT} "
    "output_arg: { name: "
    "'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} "
    "attr : { name: '_tflite_input_indices' type: 'list(int)'}";

// Op def string for the fused UnidirectionalSequenceRnn op (same role as the
// LSTM op def above, for the simple-RNN fusion).
const char kUnidirectionalSequenceRnnOp[] =
    "name: 'UnidirectionalSequenceRnn' input_arg: {name: 'Input' type: "
    "DT_FLOAT} input_arg: { name: 'Weights' type: DT_FLOAT } "
    "input_arg: { name: 'RecurrentWeights' type: DT_FLOAT } input_arg: { "
    "name: 'Bias' type: DT_FLOAT} "
    "input_arg: { name: 'HiddenState' type: DT_FLOAT} "
    "output_arg: { name: "
    "'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: "
    "DT_FLOAT} "
    "attr : { name: '_tflite_input_indices' type: 'list(int)'}";
104
105 // Converts the toco::IODataType to tensorflow::DataType. Only contains the
106 // conversion mapping for constants defined in TFLite Python API.
ConvertIODataTypeToDataType(toco::IODataType dtype)107 DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
108 switch (dtype) {
109 case toco::IODataType::FLOAT:
110 return DT_FLOAT;
111 case toco::IODataType::FLOAT16:
112 return DT_HALF;
113 case toco::IODataType::FLOAT64:
114 return DT_DOUBLE;
115 case toco::IODataType::QUANTIZED_UINT8:
116 return DT_QUINT8;
117 case toco::IODataType::QUANTIZED_INT8:
118 return DT_QINT8;
119 case toco::IODataType::QUANTIZED_INT16:
120 return DT_QINT16;
121 case toco::IODataType::INT8:
122 return DT_INT8;
123 case toco::IODataType::INT16:
124 return DT_INT16;
125 case toco::IODataType::INT32:
126 return DT_INT32;
127 case toco::IODataType::UINT32:
128 return DT_UINT32;
129 case toco::IODataType::INT64:
130 return DT_INT64;
131 case toco::IODataType::UINT8:
132 return DT_UINT8;
133 case toco::IODataType::UINT64:
134 return DT_UINT64;
135 case toco::IODataType::STRING:
136 return DT_STRING;
137 case toco::IODataType::BOOL:
138 return DT_BOOL;
139 case toco::IODataType::COMPLEX64:
140 return DT_COMPLEX64;
141 case toco::IODataType::COMPLEX128:
142 return DT_COMPLEX128;
143 case toco::IODataType::RESOURCE:
144 return DT_RESOURCE;
145 case toco::IODataType::VARIANT:
146 return DT_VARIANT;
147 default:
148 return DT_INVALID;
149 }
150 }
151
InputStatsToMinMax(double mean,double std,DataType type)152 StatusOr<std::pair<double, double>> InputStatsToMinMax(double mean, double std,
153 DataType type) {
154 // Only qint8 and quint8 are considered here.
155 double qmin, qmax;
156 if (type == DT_QUINT8) {
157 qmin = 0.0;
158 qmax = 255.0;
159 } else if (type == DT_QINT8) {
160 qmin = -128.0;
161 qmax = 127.0;
162 } else {
163 return errors::InvalidArgument("Only int8 and uint8 are considered.");
164 }
165 return std::make_pair((qmin - mean) / std, (qmax - mean) / std);
166 }
167
RegisterCustomBuiltinOps(const std::vector<string> extra_tf_opdefs)168 Status RegisterCustomBuiltinOps(const std::vector<string> extra_tf_opdefs) {
169 for (const auto& tf_opdefs_string : extra_tf_opdefs) {
170 tensorflow::OpDef opdef;
171 if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
172 &opdef)) {
173 return errors::InvalidArgument("fail to parse extra OpDef");
174 }
175 // Make sure the op is not already registered. If registered continue.
176 const OpRegistrationData* op_reg =
177 tensorflow::OpRegistry::Global()->LookUp(opdef.name());
178 if (op_reg) continue;
179
180 tensorflow::OpRegistry::Global()->Register(
181 [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
182 *op_reg_data = tensorflow::OpRegistrationData(opdef);
183 return Status::OK();
184 });
185 }
186 return Status::OK();
187 }
188
189 } // namespace
190
RegisterAllCustomOps(const toco::TocoFlags & toco_flags)191 Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
192 // Register any custom OpDefs.
193 std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
194 toco_flags.custom_opdefs().end());
195 extra_tf_opdefs.push_back(kDetectionPostProcessOp);
196 extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp);
197 extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp);
198 return RegisterCustomBuiltinOps(extra_tf_opdefs);
199 }
200
// Fills `quant_specs` and the per-input-node parallel vectors (`node_names`,
// `node_dtypes`, `node_shapes`, `node_mins`, `node_maxs`) from the TOCO
// model/converter flags so the MLIR importer and quantization passes can
// consume them. Returns InvalidArgument if the collected per-node stats
// cannot be turned into a valid input quant spec.
Status PopulateQuantizationSpecs(
    const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
    mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
    std::vector<string>* node_dtypes,
    std::vector<llvm::Optional<std::vector<int>>>* node_shapes,
    std::vector<llvm::Optional<double>>* node_mins,
    std::vector<llvm::Optional<double>>* node_maxs) {
  quant_specs->inference_input_type =
      ConvertIODataTypeToDataType(toco_flags.inference_input_type());
  tensorflow::DataType inference_type =
      ConvertIODataTypeToDataType(toco_flags.inference_type());
  // Use non-float flag `inference_input_type` to override the `inference_type`
  // because we have to apply quantization to satisfy that.
  if (quant_specs->inference_input_type != tensorflow::DT_FLOAT) {
    inference_type = quant_specs->inference_input_type;
  }

  // One entry per input array; the five vectors stay index-aligned (except
  // mins/maxs, which are only populated for quantized inference types).
  for (auto& flag : model_flags.input_arrays()) {
    node_names->push_back(flag.name());
    // TOCO doesn't required `data_type` to be filled for every input.
    // If it's not filled, make it an empty string so the importer will use
    // the data type in the NodeDef.
    auto toco_data_type = flag.data_type();
    if (toco_data_type == ::toco::IODataType::IO_DATA_TYPE_UNKNOWN) {
      node_dtypes->push_back("");
    } else {
      node_dtypes->push_back(
          DataType_Name(ConvertIODataTypeToDataType(toco_data_type)));
    }
    // llvm::None signals "unknown rank" to the importer; otherwise copy the
    // dims (which may still contain -1 for unknown dimensions).
    if (flag.shape().unknown_rank()) {
      node_shapes->push_back(llvm::None);
    } else {
      node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
                                              flag.shape().dims().end()));
    }
    // Currently, only UINT8 and INT8 require inputs stats
    if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
      if (flag.has_mean_value() && flag.has_std_value()) {
        TF_ASSIGN_OR_RETURN(
            auto min_max, InputStatsToMinMax(flag.mean_value(),
                                             flag.std_value(), inference_type));
        node_mins->push_back(min_max.first);
        node_maxs->push_back(min_max.second);
      } else {
        // Stats missing for this input: record explicit "unknown" entries so
        // the vectors stay aligned with node_names.
        node_mins->push_back(llvm::None);
        node_maxs->push_back(llvm::None);
      }
    }
  }

  // GetInputNodeQuantSpecs returns true on failure.
  if (mlir::TFL::GetInputNodeQuantSpecs(*node_names, *node_mins, *node_maxs,
                                        inference_type, quant_specs)) {
    return errors::InvalidArgument("Failed to get input quant spec.");
  }

  // Some extra flag related to post training quantization. If post-training
  // quantization is enabled, `inference_type` and `inference_input_type` are
  // not used by MLIR passes.
  if (toco_flags.post_training_quantize()) {
    quant_specs->weight_quantization = true;
    if (toco_flags.quantize_to_float16()) {
      quant_specs->inference_type = tensorflow::DT_HALF;
      quant_specs->inference_input_type = tensorflow::DT_HALF;
    } else {
      quant_specs->inference_type = tensorflow::DT_QINT8;
      quant_specs->inference_input_type = tensorflow::DT_QINT8;
    }
  }

  // Other flags.
  if (toco_flags.has_default_ranges_min()) {
    quant_specs->default_ranges.first = toco_flags.default_ranges_min();
  }
  if (toco_flags.has_default_ranges_max()) {
    quant_specs->default_ranges.second = toco_flags.default_ranges_max();
  }

  return ::tensorflow::Status::OK();
}
280
281 // Dumps the op graph of the `module` to `filename` in DOT format.
DumpOpGraphToFile(mlir::ModuleOp module,const std::string & filename)282 Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
283 std::string error_message;
284 auto output = mlir::openOutputFile(filename, &error_message);
285 if (!error_message.empty()) {
286 return errors::InvalidArgument("Failed to open file in ", filename);
287 }
288 mlir::PassManager pm(module.getContext());
289 pm.addPass(mlir::createPrintOpGraphPass(output->os()));
290 if (failed(pm.run(module))) {
291 return errors::Unknown("Failed to dump Op Graph from MLIR module.");
292 }
293 output->keep();
294 return Status::OK();
295 }
296
// Lowers the imported TF MLIR `module` through the TF-to-TFL pass pipeline
// and serializes it as a TFLite flatbuffer into `result`. Optionally dumps
// the op graph in DOT format before and after the transformations when
// `dump_graphviz_dir` is set in `toco_flags`.
Status ConvertMLIRToTFLiteFlatBuffer(
    const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
    mlir::OwningModuleRef module, const mlir::TFL::PassConfig& pass_config,
    const std::unordered_set<std::string>& saved_model_tags, string* result,
    llvm::Optional<tensorflow::Session*> session) {
  // force_select_tf_ops suppresses builtin TFLite op emission entirely.
  bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
  bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
  bool emit_custom_ops = toco_flags.allow_custom_ops();
  bool allow_all_select_tf_ops = toco_flags.allow_all_select_tf_ops();

  // User-specified TF ops that may be exported through the select-TF-ops
  // (flex) path even without a registered kernel.
  const std::unordered_set<std::string> select_user_tf_ops(
      toco_flags.select_user_tf_ops().begin(),
      toco_flags.select_user_tf_ops().end());

  if (toco_flags.has_dump_graphviz_dir()) {
    TF_RETURN_IF_ERROR(DumpOpGraphToFile(
        module.get(),
        // rename once we enable the new converter feature flag.
        absl::StrCat(toco_flags.dump_graphviz_dir(), "/toco_AT_IMPORT.dot")));
  }

  // Implicit nesting lets passes of mixed op granularity share one manager.
  mlir::PassManager pm(module->getContext(),
                       mlir::OpPassManager::Nesting::Implicit);
  ::tensorflow::SetCrashReproducer(pm);
  // Collect per-op conversion errors for structured reporting to the caller;
  // must be attached before passes run.
  pm.addInstrumentation(
      std::make_unique<mlir::TFL::ErrorCollectorInstrumentation>(
          module->getContext()));

  tensorflow::AddTFToTFLConversionPasses(model_flags, toco_flags, pass_config,
                                         &pm, session);
  // Convert back to outlined while format for export back to flatbuffer.
  if (pass_config.legalize_tf_while) {
    pm.addPass(mlir::TFL::CreateWhileOutlinePass());
  }
  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());

  auto status = ConvertTFExecutorToTFLOrFlatbuffer(
      module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
      emit_select_tf_ops, emit_custom_ops, allow_all_select_tf_ops,
      select_user_tf_ops, pass_config.quant_specs, saved_model_tags, result,
      &pm);
  // Dump the post-transformation graph even if conversion failed partway;
  // `status` is returned afterwards either way.
  if (toco_flags.has_dump_graphviz_dir()) {
    TF_RETURN_IF_ERROR(DumpOpGraphToFile(
        // rename once we enable the new converter feature flag.
        module.get(), absl::StrCat(toco_flags.dump_graphviz_dir(),
                                   "/toco_AFTER_TRANSFORMATIONS.dot")));
  }

  return status;
}
347
WarningUnusedFlags(const toco::ModelFlags & model_flags,const toco::TocoFlags & toco_flags)348 void WarningUnusedFlags(const toco::ModelFlags& model_flags,
349 const toco::TocoFlags& toco_flags) {
350 if (toco_flags.output_format()) {
351 LOG(WARNING) << "Ignored output_format.";
352 }
353 if (toco_flags.drop_control_dependency()) {
354 LOG(WARNING) << "Ignored drop_control_dependency.";
355 }
356 if (toco_flags.reorder_across_fake_quant()) {
357 LOG(WARNING) << "Ignored reorder_across_fake_quant.";
358 }
359 if (model_flags.change_concat_input_ranges()) {
360 LOG(WARNING) << "Ignored change_concat_input_ranges.";
361 }
362 if (toco_flags.dump_graphviz_include_video()) {
363 LOG(WARNING) << "Ignored dump_graphviz_video.";
364 }
365 if (model_flags.allow_nonexistent_arrays()) {
366 LOG(WARNING) << "Allow allow_nonexistent_arrays.";
367 }
368 }
369
370 } // namespace internal
371 } // namespace tensorflow
372