/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"

#include <string>
#include <unordered_set>

#include "absl/types/span.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Parser.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/lite/tools/optimize/quantize_weights.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {
namespace {
using mlir::MLIRContext;
using mlir::ModuleOp;
using mlir::Operation;
using mlir::OwningModuleRef;
using stream_executor::port::StatusOr;

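// Returns true if `op` is one of the TF executor Control Flow V1 ops
// (Switch, Merge, Enter, Exit, NextIteration source/sink).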
bool IsControlFlowV1Op(Operation* op) {
  return mlir::isa<mlir::tf_executor::SwitchOp, mlir::tf_executor::MergeOp,
                   mlir::tf_executor::EnterOp, mlir::tf_executor::ExitOp,
                   mlir::tf_executor::NextIterationSinkOp,
                   mlir::tf_executor::NextIterationSourceOp>(op);
}

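// Walks the module and fails with an error diagnostic if any Control Flow V1
// op is found; the TFLite converter only supports Control Flow V2.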
mlir::LogicalResult IsValidGraph(mlir::ModuleOp module) {
  auto result = module.walk([&](Operation* op) {
    return IsControlFlowV1Op(op) ? mlir::WalkResult::interrupt()
                                 : mlir::WalkResult::advance();
  });
  if (result.wasInterrupted()) {
    module.emitError(
        "The graph has Control Flow V1 ops. TFLite converter doesn't support "
        "Control Flow V1 ops. Consider using Control Flow V2 ops instead. See "
        "https://www.tensorflow.org/api_docs/python/tf/compat/v1/"
        "enable_control_flow_v2.");
    return mlir::failure();
  }
  return mlir::success();
}

// Util that registers 'extra_tf_opdefs' to the TF global registry.
// Returns OK on success, or an InvalidArgument error if an OpDef fails to
// parse.
Status RegisterExtraTfOpDefs(absl::Span<const std::string> extra_tf_opdefs) {
  for (const auto& tf_opdefs_string : extra_tf_opdefs) {
    tensorflow::OpDef opdef;
    if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
                                                           &opdef)) {
      LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string;
      return errors::InvalidArgument("fail to parse extra OpDef");
    }
    // Register extra opdefs.
    // TODO(b/133770952): Support shape functions.
    tensorflow::OpRegistry::Global()->Register(
        [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
          *op_reg_data = tensorflow::OpRegistrationData(opdef);
          return Status::OK();
        });
  }
  return Status::OK();
}
}  // namespace

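// Loads a module either from a textual MLIR file or from a GraphDef,
// depending on `input_mlir`. Extra TF OpDefs are registered before importing
// a GraphDef.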
StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
    const std::string& input_filename, bool input_mlir,
    bool use_splatted_constant, const std::vector<std::string>& extra_tf_opdefs,
    const GraphImportConfig& specs, absl::string_view debug_info_file,
    absl::string_view input_arrays, absl::string_view input_dtypes,
    absl::string_view input_shapes, absl::string_view output_arrays,
    llvm::SourceMgr* source_mgr, MLIRContext* context) {
  // Set up the input file.
  std::string error_message;
  auto file = mlir::openInputFile(input_filename, &error_message);
  if (!file) {
    llvm::errs() << error_message << "\n";
    return errors::InvalidArgument("fail to open input file");
  }

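  // If the input is already MLIR, hand the buffer to the source manager and
  // parse it directly.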
  if (input_mlir) {
    source_mgr->AddNewSourceBuffer(std::move(file), llvm::SMLoc());
    return OwningModuleRef(mlir::parseSourceFile(*source_mgr, context));
  }

  // Register extra TF ops passed as OpDef.
  auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs);
  if (!extra_opdefs_status.ok()) return extra_opdefs_status;

  if (use_splatted_constant) {
    return tensorflow::GraphdefToSplattedMlirTranslateFunction(
        file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
        input_shapes, output_arrays, /*control_output_arrays=*/"",
        specs.prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
        /*graph_as_function=*/false, specs.upgrade_legacy,
        /*enable_shape_inference=*/false, context);
  }
  return tensorflow::GraphdefToMlirTranslateFunction(
      file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
      input_shapes, output_arrays, /*control_output_arrays=*/"",
      specs.prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
      /*graph_as_function=*/false, specs.upgrade_legacy,
      /*enable_shape_inference=*/false, context);
}

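// Runs the given pass manager over `module` and serializes the result either
// as textual MLIR (when `export_to_mlir` is set) or as a TFLite flatbuffer
// into `result`, optionally applying post-training weight quantization.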
Status ConvertTFExecutorToTFLOrFlatbuffer(
    mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
    bool emit_select_tf_ops, bool emit_custom_ops,
    const std::unordered_set<std::string>& select_user_tf_ops,
    const mlir::TFL::QuantizationSpecs& quant_specs,
    const std::unordered_set<std::string>& saved_model_tags,
    std::string* result, mlir::PassManager* pass_manager) {
  // Explicitly disable dumping Op details on failures.
  module.getContext()->printOpOnDiagnostic(false);

  // Register a warning handler that only logs to stdout.
  mlir::ScopedDiagnosticHandler s(
      module.getContext(), [](mlir::Diagnostic& diag) {
        if (diag.getSeverity() == mlir::DiagnosticSeverity::Warning) {
          for (auto& note : diag.getNotes()) {
            std::cout << note.str() << "\n";
            LOG(WARNING) << note.str() << "\n";
          }
        }
        return mlir::failure();
      });

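  // Collect any errors emitted while running the pass pipeline so they can be
  // returned as a Status.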
  mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),
                                                    /*propagate=*/true);

  if (failed(IsValidGraph(module)) || failed(pass_manager->run(module))) {
    return statusHandler.ConsumeStatus();
  }

  if (export_to_mlir) {
    llvm::raw_string_ostream os(*result);
    module.print(os);
    return Status::OK();
  }

  // Write MLIR TFLite dialect into FlatBuffer.
  OpOrArgLocNameMapper op_or_arg_name_mapper;
  if (!quant_specs.RunWeightQuantization()) {
    tflite::FlatbufferExportOptions options;
    options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
    options.emit_select_tf_ops = emit_select_tf_ops;
    options.select_user_tf_ops = select_user_tf_ops;
    options.emit_custom_ops = emit_custom_ops;
    options.saved_model_tags = saved_model_tags;
    options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
    if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result)) {
      return statusHandler.ConsumeStatus();
    }
  } else {
    // Post-training weight quantization path. Once MLIR has support for this,
    // we can remove this else statement.
    std::string pre_quantized_result;
    tflite::FlatbufferExportOptions options;
    options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
    options.emit_select_tf_ops = emit_select_tf_ops;
    options.select_user_tf_ops = select_user_tf_ops;
    options.emit_custom_ops = emit_custom_ops;
    options.saved_model_tags = saved_model_tags;
    options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
    if (!tflite::MlirToFlatBufferTranslateFunction(module, options,
                                                   &pre_quantized_result)) {
      return statusHandler.ConsumeStatus();
    }
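    // Run the post-training weight quantizer over the serialized flatbuffer
    // and replace `result` with the quantized model.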
    flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);
    const uint8_t* buffer =
        reinterpret_cast<const uint8_t*>(pre_quantized_result.c_str());
    const ::tflite::Model* input_model = ::tflite::GetModel(buffer);

    ::tflite::optimize::BufferType quantized_type;
    if (quant_specs.inference_type == tensorflow::DT_QINT8) {
      quantized_type = ::tflite::optimize::BufferType::QUANTIZED_INT8;
    } else if (quant_specs.inference_type == tensorflow::DT_HALF) {
      quantized_type = ::tflite::optimize::BufferType::QUANTIZED_FLOAT16;
    } else {
      return errors::InvalidArgument("Quantized type not supported");
    }
    if (::tflite::optimize::QuantizeWeights(&q_builder, input_model,
                                            quantized_type) != kTfLiteOk) {
      return errors::InvalidArgument("Quantize weights transformation failed.");
    }
    const uint8_t* q_buffer = q_builder.GetBufferPointer();
    *result =
        string(reinterpret_cast<const char*>(q_buffer), q_builder.GetSize());
  }

  return Status::OK();
}

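// Imports a SavedModel (v1 signature defs or v2 object graph) into an MLIR
// module, registering any extra TF OpDefs first.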
StatusOr<mlir::OwningModuleRef> ImportSavedModel(
    const std::string& input_filename, const int saved_model_version,
    const std::unordered_set<std::string>& tags,
    absl::Span<const std::string> extra_tf_opdefs,
    absl::Span<std::string> exported_names, const GraphImportConfig& specs,
    mlir::MLIRContext* context) {
  // Register extra TF ops passed as OpDef.
  auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs);
  if (!extra_opdefs_status.ok()) return extra_opdefs_status;

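  // SavedModel v2 uses the object-graph importer; v1 uses the signature-def
  // importer.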
  if (saved_model_version == 2) {
    auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
        input_filename, tags, exported_names, context);
    if (!module_or.status().ok()) return module_or.status();
    return module_or.ConsumeValueOrDie();
  } else if (saved_model_version == 1) {
    MLIRImportOptions options;
    options.upgrade_legacy = specs.upgrade_legacy;
    auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
        input_filename, tags, exported_names, context, options);

    if (!module_or.status().ok()) return module_or.status();
    return module_or.ConsumeValueOrDie();
  } else {
    return tensorflow::errors::InvalidArgument(
        "Should be either saved model v1 or v2");
  }
}

}  // namespace tensorflow