/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"

#include <iostream>
#include <string>
#include <unordered_set>
#include <utility>

#include "absl/types/span.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Parser.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/lite/tools/optimize/quantize_weights.h"
#include "tensorflow/lite/tools/optimize/reduced_precision_support.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {
namespace {
using mlir::MLIRContext;
using mlir::ModuleOp;
using mlir::Operation;
using mlir::OwningModuleRef;
using stream_executor::port::StatusOr;

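// Returns true if `op` is one of the tf_executor dialect ops that implement
// Control Flow V1 semantics (Switch/Merge/Enter/Exit/NextIteration).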
bool IsControlFlowV1Op(Operation* op) {
  return mlir::isa<mlir::tf_executor::SwitchOp, mlir::tf_executor::MergeOp,
                   mlir::tf_executor::EnterOp, mlir::tf_executor::ExitOp,
                   mlir::tf_executor::NextIterationSinkOp,
                   mlir::tf_executor::NextIterationSourceOp>(op);
}

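// Fails if `module` contains any Control Flow V1 op, which the TFLite
// converter does not support, attaching the matching converter error code so
// the failure can be surfaced with a suggested fix.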
mlir::LogicalResult IsValidGraph(mlir::ModuleOp module) {
  auto result = module.walk([&](Operation* op) {
    return IsControlFlowV1Op(op) ? mlir::WalkResult::interrupt()
                                 : mlir::WalkResult::advance();
  });
  if (result.wasInterrupted()) {
    mlir::TFL::AttachErrorCode(
        module.emitError(
            "The graph has Control Flow V1 ops. TFLite converter doesn't "
            "support Control Flow V1 ops. Consider using Control Flow V2 ops "
            "instead. See https://www.tensorflow.org/api_docs/python/tf/compat/"
            "v1/enable_control_flow_v2."),
        tflite::metrics::ConverterErrorData::ERROR_UNSUPPORTED_CONTROL_FLOW_V1);
    return mlir::failure();
  }
  return mlir::success();
}

// Util that registers 'extra_tf_opdefs' to the TF global registry.
// Returns OK on success, or a failure status if registration failed.
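// Each entry of `extra_tf_opdefs` is expected to be an OpDef in text-proto
// form; for example (a hypothetical custom op, for illustration only):
//
//   name: "MyCustomOp"
//   input_arg { name: "x" type: DT_FLOAT }
//   output_arg { name: "y" type: DT_FLOAT }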
Status RegisterExtraTfOpDefs(absl::Span<const std::string> extra_tf_opdefs) {
  for (const auto& tf_opdefs_string : extra_tf_opdefs) {
    tensorflow::OpDef opdef;
    if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
                                                           &opdef)) {
      LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string;
      return errors::InvalidArgument("failed to parse extra OpDef");
    }
    // Register extra opdefs.
    // TODO(b/133770952): Support shape functions.
    tensorflow::OpRegistry::Global()->Register(
        [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
          *op_reg_data = tensorflow::OpRegistrationData(opdef);
          return Status::OK();
        });
  }
  return Status::OK();
}
}  // namespace

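// Loads the model at `input_filename` into an MLIR module. When `input_mlir`
// is true the file is parsed as textual MLIR through `source_mgr`; otherwise
// it is imported as a GraphDef, via the splatted-constant import variant when
// `use_splatted_constant` is set. Extra OpDefs are registered first so the
// importer can resolve ops that are not built in.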
StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
    const std::string& input_filename, bool input_mlir,
    bool use_splatted_constant, const std::vector<std::string>& extra_tf_opdefs,
    const GraphImportConfig& specs, absl::string_view debug_info_file,
    absl::string_view input_arrays, absl::string_view input_dtypes,
    absl::string_view input_shapes, absl::string_view output_arrays,
    absl::string_view control_output_arrays, llvm::SourceMgr* source_mgr,
    MLIRContext* context) {
  // Set up the input file.
  std::string error_message;
  auto file = mlir::openInputFile(input_filename, &error_message);
  if (!file) {
    llvm::errs() << error_message << "\n";
    return errors::InvalidArgument("failed to open input file");
  }

  if (input_mlir) {
    source_mgr->AddNewSourceBuffer(std::move(file), llvm::SMLoc());
    return OwningModuleRef(mlir::parseSourceFile(*source_mgr, context));
  }

  // Register extra TF ops passed as OpDef.
  auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs);
  if (!extra_opdefs_status.ok()) return extra_opdefs_status;

  if (use_splatted_constant) {
    return tensorflow::GraphdefToSplattedMlirTranslateFunction(
        file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
        input_shapes, output_arrays, control_output_arrays,
        specs.prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
        /*graph_as_function=*/false, specs.upgrade_legacy,
        /*enable_shape_inference=*/false, context);
  }
  return tensorflow::GraphdefToMlirTranslateFunction(
      file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
      input_shapes, output_arrays, control_output_arrays,
      specs.prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
      /*graph_as_function=*/false, specs.upgrade_legacy,
      /*enable_shape_inference=*/false, context);
}

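// Runs `pass_manager` over `module`, then serializes the result: textual MLIR
// into `result` when `export_to_mlir` is set, a TFLite FlatBuffer otherwise.
// If `quant_specs.RunWeightQuantization()` is true, the FlatBuffer is
// additionally post-processed by the post-training weight quantizer.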
Status ConvertTFExecutorToTFLOrFlatbuffer(
    mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
    bool emit_select_tf_ops, bool emit_custom_ops, bool allow_all_select_tf_ops,
    const std::unordered_set<std::string>& select_user_tf_ops,
    const mlir::TFL::QuantizationSpecs& quant_specs,
    const std::unordered_set<std::string>& saved_model_tags,
    std::string* result, mlir::PassManager* pass_manager) {
  // Explicitly disable dumping Op details on failures.
  module.getContext()->printOpOnDiagnostic(false);

  // Register a warning handler that only logs warnings to stdout; returning
  // failure lets every diagnostic propagate to the handlers registered below.
  mlir::ScopedDiagnosticHandler s(
      module.getContext(), [](mlir::Diagnostic& diag) {
        if (diag.getSeverity() == mlir::DiagnosticSeverity::Warning) {
          for (auto& note : diag.getNotes()) {
            std::cout << note.str() << "\n";
            LOG(WARNING) << note.str() << "\n";
          }
        }
        return mlir::failure();
      });

  mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),
                                                    /*propagate=*/true);

  if (failed(IsValidGraph(module))) {
    return statusHandler.ConsumeStatus();
  }

  if (failed(pass_manager->run(module))) {
    auto status = statusHandler.ConsumeStatus();
    mlir::TFL::ErrorCollector* collector =
        mlir::TFL::ErrorCollector::GetErrorCollector();
    for (const auto& error_data : collector->CollectedErrors()) {
      if (error_data.subcomponent() == "FreezeGlobalTensorsPass") {
        // LINT.IfChange
        return errors::InvalidArgument(
            "Variable constant folding failed. Please consider enabling the "
            "`experimental_enable_resource_variables` flag in the TFLite "
            "converter object. For example, "
            "converter.experimental_enable_resource_variables = True");
        // LINT.ThenChange(//tensorflow/lite/python/lite_v2_test.py)
      }
    }
    return status;
  }

  if (export_to_mlir) {
    llvm::raw_string_ostream os(*result);
    module.print(os);
    return statusHandler.ConsumeStatus();
  }

  // Write the MLIR TFLite dialect into a FlatBuffer.
  OpOrArgLocNameMapper op_or_arg_name_mapper;
  if (!quant_specs.RunWeightQuantization()) {
    tflite::FlatbufferExportOptions options;
    options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
    options.emit_select_tf_ops = emit_select_tf_ops;
    options.select_user_tf_ops = select_user_tf_ops;
    options.allow_all_select_tf_ops = allow_all_select_tf_ops;
    options.emit_custom_ops = emit_custom_ops;
    options.saved_model_tags = saved_model_tags;
    options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
    if (quant_specs.support_mask !=
        tflite::optimize::ReducedPrecisionSupport::None) {
      options.metadata.insert(
          MetadataForReducedPrecisionSupport(quant_specs.support_mask));
    }
    if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result)) {
      return statusHandler.ConsumeStatus();
    }
  } else {
    // Post-training weight quantization path. Once MLIR has support for this,
    // we can remove this else branch.
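    // The module is first serialized to an intermediate FlatBuffer
    // (`pre_quantized_result`); QuantizeWeights then rewrites its weight
    // buffers to int8 or float16 depending on `quant_specs.inference_type`,
    // and the quantized FlatBuffer replaces the unquantized one in `result`.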
    std::string pre_quantized_result;
    tflite::FlatbufferExportOptions options;
    options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
    options.emit_select_tf_ops = emit_select_tf_ops;
    options.select_user_tf_ops = select_user_tf_ops;
    options.allow_all_select_tf_ops = allow_all_select_tf_ops;
    options.emit_custom_ops = emit_custom_ops;
    options.saved_model_tags = saved_model_tags;
    options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
    if (quant_specs.support_mask !=
        tflite::optimize::ReducedPrecisionSupport::None) {
      options.metadata.insert(
          MetadataForReducedPrecisionSupport(quant_specs.support_mask));
    }
    if (!tflite::MlirToFlatBufferTranslateFunction(module, options,
                                                   &pre_quantized_result)) {
      return statusHandler.ConsumeStatus();
    }
    flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);
    const uint8_t* buffer =
        reinterpret_cast<const uint8_t*>(pre_quantized_result.c_str());
    const ::tflite::Model* input_model = ::tflite::GetModel(buffer);

    ::tflite::optimize::BufferType quantized_type;
    if (quant_specs.inference_type == tensorflow::DT_QINT8) {
      quantized_type = ::tflite::optimize::BufferType::QUANTIZED_INT8;
    } else if (quant_specs.inference_type == tensorflow::DT_HALF) {
      quantized_type = ::tflite::optimize::BufferType::QUANTIZED_FLOAT16;
    } else {
      return errors::InvalidArgument("Quantized type not supported");
    }
    if (::tflite::optimize::QuantizeWeights(&q_builder, input_model,
                                            quantized_type) != kTfLiteOk) {
      return errors::InvalidArgument("Quantize weights transformation failed.");
    }
    const uint8_t* q_buffer = q_builder.GetBufferPointer();
    *result =
        string(reinterpret_cast<const char*>(q_buffer), q_builder.GetSize());
  }

  if (mlir::failed(module.verify())) {
    return tensorflow::errors::Unknown("Final module is invalid");
  }
  return Status::OK();
}

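// Imports the SavedModel at `input_filename` into an MLIR module: version 2
// models go through the object-graph importer, version 1 models through the
// SignatureDef importer (which also hands back the loaded
// `saved_model_bundle`); any other version is rejected.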
StatusOr<mlir::OwningModuleRef> ImportSavedModel(
    const std::string& input_filename, const int saved_model_version,
    const std::unordered_set<std::string>& tags,
    absl::Span<const std::string> extra_tf_opdefs,
    absl::Span<std::string> exported_names, const GraphImportConfig& specs,
    bool enable_variable_lifting, mlir::MLIRContext* context,
    std::unique_ptr<tensorflow::SavedModelBundle>* saved_model_bundle) {
  // Register extra TF ops passed as OpDef.
  auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs);
  if (!extra_opdefs_status.ok()) return extra_opdefs_status;

  if (saved_model_version == 2) {
    auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
        input_filename, tags, exported_names, context);
    if (!module_or.status().ok()) return module_or.status();
    return module_or.ConsumeValueOrDie();
  } else if (saved_model_version == 1) {
    MLIRImportOptions options;
    options.upgrade_legacy = specs.upgrade_legacy;
    auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
        input_filename, tags, exported_names, context, options,
        enable_variable_lifting, saved_model_bundle);

    if (!module_or.status().ok()) return module_or.status();
    return module_or.ConsumeValueOrDie();
  } else {
    return tensorflow::errors::InvalidArgument(
        "Expected either SavedModel v1 or v2");
  }
}

}  // namespace tensorflow