/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"

#include <ostream>
#include <unordered_set>
#include <utility>

#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/types.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"

using stream_executor::port::StatusOr;

namespace tensorflow {
namespace internal {
namespace {

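// The following strings are text-format tensorflow::OpDef protos for TFLite's
// fused/custom ops. RegisterAllCustomOps() below registers them with the
// global op registry so the GraphDef importer can resolve these ops when they
// appear in an input graph.
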
// OpDef string for the TFLite_Detection_PostProcess op.
const char kDetectionPostProcessOp[] =
    "name: 'TFLite_Detection_PostProcess' input_arg: { name: "
    "'raw_outputs/box_encodings' type: DT_FLOAT } input_arg: { name: "
    "'raw_outputs/class_predictions' type: DT_FLOAT } input_arg: { name: "
    "'anchors' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess:1' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess:2' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess:3' type: DT_FLOAT } attr : { name: "
    "'h_scale' type: 'float'} attr : { name: 'max_classes_per_detection' "
    "type: 'int'} attr : { name: 'max_detections' type: 'int'} attr : { "
    "name: 'nms_iou_threshold' type: 'float'} attr : { name: "
    "'nms_score_threshold' type: 'float'} attr : { name: 'num_classes' type: "
    "'int'} attr : { name: 'w_scale' type: 'float'} attr : { name: 'x_scale' "
    "type: 'float'} attr : { name: 'y_scale' type: 'float'} attr { name: "
    "'detections_per_class' type: 'int' default_value { i : 100 }} attr { "
    "name: 'use_regular_nms' type: 'bool' default_value { b : false }}";

const char kUnidirectionalSequenceLstmOp[] =
    "name: 'UnidirectionalSequenceLstm' input_arg: {name: 'Input' type: "
    "DT_FLOAT} input_arg: { name: 'InputToInputWeights' type: DT_FLOAT } "
    "input_arg: { name: 'InputToForgetWeights' type: DT_FLOAT } input_arg: { "
    "name: 'InputToCellWeights' type: DT_FLOAT} input_arg: { name: "
    "'InputToOutputWeights' type: DT_FLOAT } input_arg: { name: "
    "'RecurrentToInputWeights' type: DT_FLOAT} input_arg: { name: "
    "'RecurrentToForgetWeights' type: DT_FLOAT} input_arg: { name: "
    "'RecurrentToCellWeights' type: DT_FLOAT } input_arg: { name: "
    "'RecurrentToOutputWeights' type: DT_FLOAT } input_arg: { name: "
    "'CellToInputWeights' type: DT_FLOAT} input_arg: { name: "
    "'CellToForgetWeights' type: DT_FLOAT } input_arg: { name: "
    "'CellToOutputWeights' type: DT_FLOAT } input_arg: { name: 'InputGateBias' "
    "type: DT_FLOAT } input_arg: { name: 'ForgetGateBias' type: DT_FLOAT } "
    "input_arg: { name: 'kCellGateBias' type: DT_FLOAT } input_arg: { name: "
    "'OutputGateBias' type: DT_FLOAT } input_arg: { name: 'ProjectionWeights' "
    "type: DT_FLOAT } input_arg: { name: 'ProjectionBias' type: DT_FLOAT } "
    "input_arg: { name: 'InputActivationState' type: DT_FLOAT} input_arg: { "
    "name: 'InputCellStateTensor' type: DT_FLOAT } "
    "output_arg: { name: 'Concat' type: DT_FLOAT} "
    "output_arg: { name: "
    "'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} "
    "attr : { name: '_tflite_input_indices' type: 'list(int)'}";

const char kUnidirectionalSequenceRnnOp[] =
    "name: 'UnidirectionalSequenceRnn' input_arg: {name: 'Input' type: "
    "DT_FLOAT} input_arg: { name: 'Weights' type: DT_FLOAT } "
    "input_arg: { name: 'RecurrentWeights' type: DT_FLOAT } input_arg: { "
    "name: 'Bias' type: DT_FLOAT} "
    "input_arg: { name: 'HiddenState' type: DT_FLOAT} "
    "output_arg: { name: "
    "'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: "
    "DT_FLOAT} "
    "attr : { name: '_tflite_input_indices' type: 'list(int)'}";

// Converts the toco::IODataType to tensorflow::DataType. Only contains the
// conversion mapping for constants defined in the TFLite Python API.
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
  switch (dtype) {
    case toco::IODataType::FLOAT:
      return DT_FLOAT;
    case toco::IODataType::FLOAT16:
      return DT_HALF;
    case toco::IODataType::FLOAT64:
      return DT_DOUBLE;
    case toco::IODataType::QUANTIZED_UINT8:
      return DT_QUINT8;
    case toco::IODataType::INT8:
      return DT_QINT8;
    case toco::IODataType::QUANTIZED_INT16:
      return DT_INT16;
    case toco::IODataType::INT32:
      return DT_INT32;
    case toco::IODataType::UINT32:
      return DT_UINT32;
    case toco::IODataType::INT64:
      return DT_INT64;
    case toco::IODataType::UINT64:
      return DT_UINT64;
    case toco::IODataType::STRING:
      return DT_STRING;
    case toco::IODataType::BOOL:
      return DT_BOOL;
    case toco::IODataType::COMPLEX64:
      return DT_COMPLEX64;
    case toco::IODataType::COMPLEX128:
      return DT_COMPLEX128;
    case toco::IODataType::RESOURCE:
      return DT_RESOURCE;
    case toco::IODataType::VARIANT:
      return DT_VARIANT;
    default:
      return DT_INVALID;
  }
}

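// Converts the (mean, std) stats of an input array into the (min, max) range
// of its real values. Quantization maps real = (quantized - mean) / std, so
// e.g. quint8 stats of mean=128, std=127 give a range of about
// ((0 - 128) / 127, (255 - 128) / 127) = (-1.008, 1.0).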
StatusOr<std::pair<double, double>> InputStatsToMinMax(double mean, double std,
                                                       DataType type) {
  // Only qint8 and quint8 are considered here.
  double qmin, qmax;
  if (type == DT_QUINT8) {
    qmin = 0.0;
    qmax = 255.0;
  } else if (type == DT_QINT8) {
    qmin = -128.0;
    qmax = 127.0;
  } else {
    return errors::InvalidArgument(
        "Only DT_QINT8 and DT_QUINT8 inference types are supported.");
  }
  return std::make_pair((qmin - mean) / std, (qmax - mean) / std);
}

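// Parses each entry of `extra_tf_opdefs` as a text-format OpDef proto and
// registers it with the global op registry; ops that are already registered
// are skipped.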
Status RegisterCustomBuiltinOps(const std::vector<string> extra_tf_opdefs) {
  for (const auto& tf_opdefs_string : extra_tf_opdefs) {
    tensorflow::OpDef opdef;
    if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
                                                           &opdef)) {
      return errors::InvalidArgument("Failed to parse extra OpDef");
    }
    // Skip the op if it is already registered.
    const OpRegistrationData* op_reg =
        tensorflow::OpRegistry::Global()->LookUp(opdef.name());
    if (op_reg) continue;

    tensorflow::OpRegistry::Global()->Register(
        [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
          *op_reg_data = tensorflow::OpRegistrationData(opdef);
          return Status::OK();
        });
  }
  return Status::OK();
}

}  // namespace

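// Registers the built-in TFLite op defs above plus any user-supplied OpDefs
// from `toco_flags.custom_opdefs()`. A user entry uses the same text format,
// e.g. (hypothetical op):
//   name: 'MyCustomOp' input_arg: { name: 'x' type: DT_FLOAT }
//   output_arg: { name: 'y' type: DT_FLOAT }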
Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
  // Register any custom OpDefs.
  std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
                                      toco_flags.custom_opdefs().end());
  extra_tf_opdefs.push_back(kDetectionPostProcessOp);
  extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp);
  extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp);
  return RegisterCustomBuiltinOps(extra_tf_opdefs);
}

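// Extracts the quantization configuration requested by the TOCO flags: fills
// `quant_specs` plus the per-input-array vectors (`node_names`, `node_dtypes`,
// `node_shapes`, `node_mins`, `node_maxs`).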
Status PopulateQuantizationSpecs(
    const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
    mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
    std::vector<string>* node_dtypes,
    std::vector<llvm::Optional<std::vector<int>>>* node_shapes,
    std::vector<llvm::Optional<double>>* node_mins,
    std::vector<llvm::Optional<double>>* node_maxs) {
  quant_specs->inference_input_type =
      ConvertIODataTypeToDataType(toco_flags.inference_input_type());
  tensorflow::DataType inference_type =
      ConvertIODataTypeToDataType(toco_flags.inference_type());
  // A non-float `inference_input_type` overrides `inference_type`, since
  // quantization has to be applied to satisfy the requested input type.
  if (quant_specs->inference_input_type != tensorflow::DT_FLOAT) {
    inference_type = quant_specs->inference_input_type;
  }

  for (auto& flag : model_flags.input_arrays()) {
    node_names->push_back(flag.name());
    // TOCO doesn't require `data_type` to be filled for every input.
    // If it's not filled, make it an empty string so the importer will use
    // the data type in the NodeDef.
    auto toco_data_type = flag.data_type();
    if (toco_data_type == ::toco::IODataType::IO_DATA_TYPE_UNKNOWN) {
      node_dtypes->push_back("");
    } else {
      node_dtypes->push_back(
          DataType_Name(ConvertIODataTypeToDataType(toco_data_type)));
    }
    if (flag.shape().unknown_rank()) {
      node_shapes->push_back(llvm::None);
    } else {
      node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
                                              flag.shape().dims().end()));
    }
    // Currently, only QINT8 and QUINT8 inference types require input stats.
    if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
      if (flag.has_mean_value() && flag.has_std_value()) {
        TF_ASSIGN_OR_RETURN(
            auto min_max, InputStatsToMinMax(flag.mean_value(),
                                             flag.std_value(), inference_type));
        node_mins->push_back(min_max.first);
        node_maxs->push_back(min_max.second);
      } else {
        node_mins->push_back(llvm::None);
        node_maxs->push_back(llvm::None);
      }
    }
  }

  if (mlir::TFL::GetInputNodeQuantSpecs(*node_names, *node_mins, *node_maxs,
                                        inference_type, quant_specs)) {
    return errors::InvalidArgument("Failed to get input quant spec.");
  }

  // Extra flags related to post-training quantization. If post-training
  // quantization is enabled, `inference_type` and `inference_input_type` are
  // not used by MLIR passes.
  if (toco_flags.post_training_quantize()) {
    quant_specs->weight_quantization = true;
    if (toco_flags.quantize_to_float16()) {
      quant_specs->inference_type = tensorflow::DT_HALF;
      quant_specs->inference_input_type = tensorflow::DT_HALF;
    } else {
      quant_specs->inference_type = tensorflow::DT_QINT8;
      quant_specs->inference_input_type = tensorflow::DT_QINT8;
    }
  }

  // Other flags.
  if (toco_flags.has_default_ranges_min()) {
    quant_specs->default_ranges.first = toco_flags.default_ranges_min();
  }
  if (toco_flags.has_default_ranges_max()) {
    quant_specs->default_ranges.second = toco_flags.default_ranges_max();
  }

  return ::tensorflow::Status::OK();
}

// Dumps the op graph of the `module` to `filename` in DOT format.
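// The emitted file can be rendered with standard graphviz tooling, e.g.
// `dot -Tpng graph.dot -o graph.png`.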
Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
  std::string error_message;
  auto output = mlir::openOutputFile(filename, &error_message);
  if (!error_message.empty()) {
    return errors::InvalidArgument("Failed to open file ", filename, ": ",
                                   error_message);
  }
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createPrintOpGraphPass(output->os()));
  if (failed(pm.run(module))) {
    return errors::Unknown("Failed to dump Op Graph from MLIR module.");
  }
  output->keep();
  return Status::OK();
}

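// Lowers `module` from the TF dialect to TFLite and serializes it into a
// flatbuffer written to `result`. If `dump_graphviz_dir` is set, the op graph
// is dumped both before and after the conversion passes.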
Status ConvertMLIRToTFLiteFlatBuffer(
    const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module,
    const mlir::TFL::PassConfig& pass_config,
    const std::unordered_set<std::string>& saved_model_tags, string* result,
    llvm::Optional<tensorflow::Session*> session) {
  bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
  bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
  bool emit_custom_ops = toco_flags.allow_custom_ops();

  const std::unordered_set<std::string> select_user_tf_ops(
      toco_flags.select_user_tf_ops().begin(),
      toco_flags.select_user_tf_ops().end());

  if (toco_flags.has_dump_graphviz_dir()) {
    TF_RETURN_IF_ERROR(DumpOpGraphToFile(
        module.get(),
        // rename once we enable the new converter feature flag.
        absl::StrCat(toco_flags.dump_graphviz_dir(), "/toco_AT_IMPORT.dot")));
  }

  mlir::PassManager pm(module->getContext(),
                       mlir::OpPassManager::Nesting::Implicit);
  ::tensorflow::SetCrashReproducer(pm);

  tensorflow::AddTFToTFLConversionPasses(pass_config, &pm, session);
  // Convert back to the outlined while format for export to flatbuffer.
  if (pass_config.legalize_tf_while) {
    pm.addPass(mlir::TFL::CreateWhileOutlinePass());
  }
  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());

  auto status = ConvertTFExecutorToTFLOrFlatbuffer(
      module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
      emit_select_tf_ops, emit_custom_ops, select_user_tf_ops,
      pass_config.quant_specs, saved_model_tags, result, &pm);
  if (toco_flags.has_dump_graphviz_dir()) {
    TF_RETURN_IF_ERROR(DumpOpGraphToFile(
        // rename once we enable the new converter feature flag.
        module.get(), absl::StrCat(toco_flags.dump_graphviz_dir(),
                                   "/toco_AFTER_TRANSFORMATIONS.dot")));
  }

  return status;
}

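// Logs a warning for each flag that is set but has no effect in the
// MLIR-based converter.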
void WarningUnusedFlags(const toco::ModelFlags& model_flags,
                        const toco::TocoFlags& toco_flags) {
  if (toco_flags.output_format()) {
    LOG(WARNING) << "Ignored output_format.";
  }
  if (toco_flags.drop_control_dependency()) {
    LOG(WARNING) << "Ignored drop_control_dependency.";
  }
  if (toco_flags.reorder_across_fake_quant()) {
    LOG(WARNING) << "Ignored reorder_across_fake_quant.";
  }
  if (model_flags.change_concat_input_ranges()) {
    LOG(WARNING) << "Ignored change_concat_input_ranges.";
  }
  if (toco_flags.dump_graphviz_include_video()) {
    LOG(WARNING) << "Ignored dump_graphviz_video.";
  }
  if (model_flags.allow_nonexistent_arrays()) {
353 LOG(WARNING) << "Allow allow_nonexistent_arrays.";
  }
}

}  // namespace internal
}  // namespace tensorflow