/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"

#include <string>
#include <unordered_set>

#include "absl/types/span.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Parser.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/lite/tools/optimize/quantize_weights.h"
#include "tensorflow/lite/tools/optimize/reduced_precision_support.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {
namespace {
using mlir::MLIRContext;
using mlir::ModuleOp;
using mlir::Operation;
using mlir::OwningModuleRef;
using stream_executor::port::StatusOr;

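// Returns true if `op` is a TF executor Control Flow V1 op.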
bool IsControlFlowV1Op(Operation* op) {
  return mlir::isa<mlir::tf_executor::SwitchOp, mlir::tf_executor::MergeOp,
                   mlir::tf_executor::EnterOp, mlir::tf_executor::ExitOp,
                   mlir::tf_executor::NextIterationSinkOp,
                   mlir::tf_executor::NextIterationSourceOp>(op);
}

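// Returns success if `module` contains no Control Flow V1 ops, which the
// TFLite converter does not support; otherwise emits an error and fails.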
mlir::LogicalResult IsValidGraph(mlir::ModuleOp module) {
  auto result = module.walk([&](Operation* op) {
    return IsControlFlowV1Op(op) ? mlir::WalkResult::interrupt()
                                 : mlir::WalkResult::advance();
  });
  if (result.wasInterrupted()) {
    mlir::TFL::AttachErrorCode(
        module.emitError(
            "The graph contains Control Flow V1 ops, which the TFLite "
            "converter does not support. Consider using Control Flow V2 ops "
            "instead. See https://www.tensorflow.org/api_docs/python/tf/compat/"
            "v1/enable_control_flow_v2."),
        tflite::metrics::ConverterErrorData::ERROR_UNSUPPORTED_CONTROL_FLOW_V1);
    return mlir::failure();
  }
  return mlir::success();
}

// Utility that registers `extra_tf_opdefs` with the TF global op registry.
// Returns OK on success, or an error if an OpDef fails to parse.
Status RegisterExtraTfOpDefs(absl::Span<const std::string> extra_tf_opdefs) {
  for (const auto& tf_opdefs_string : extra_tf_opdefs) {
    tensorflow::OpDef opdef;
    if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
                                                           &opdef)) {
      LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string;
      return errors::InvalidArgument("failed to parse extra OpDef");
    }
    // Register extra opdefs.
    // TODO(b/133770952): Support shape functions.
    tensorflow::OpRegistry::Global()->Register(
        [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
          *op_reg_data = tensorflow::OpRegistrationData(opdef);
          return Status::OK();
        });
  }
  return Status::OK();
}
}  // namespace

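// Loads the model from `input_filename`: parses it directly as MLIR when
// `input_mlir` is set, and otherwise imports it as a GraphDef using the
// options carried in `specs` and the remaining arguments.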
StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
    const std::string& input_filename, bool input_mlir,
    bool use_splatted_constant, const std::vector<std::string>& extra_tf_opdefs,
    const GraphImportConfig& specs, absl::string_view debug_info_file,
    absl::string_view input_arrays, absl::string_view input_dtypes,
    absl::string_view input_shapes, absl::string_view output_arrays,
    absl::string_view control_output_arrays, llvm::SourceMgr* source_mgr,
    MLIRContext* context) {
  // Set up the input file.
  std::string error_message;
  auto file = mlir::openInputFile(input_filename, &error_message);
  if (!file) {
    llvm::errs() << error_message << "\n";
    return errors::InvalidArgument("failed to open input file");
  }

  if (input_mlir) {
    source_mgr->AddNewSourceBuffer(std::move(file), llvm::SMLoc());
    return OwningModuleRef(mlir::parseSourceFile(*source_mgr, context));
  }

  // Register extra TF ops passed as OpDefs.
  auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs);
  if (!extra_opdefs_status.ok()) return extra_opdefs_status;

  if (use_splatted_constant) {
    return tensorflow::GraphdefToSplattedMlirTranslateFunction(
        file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
        input_shapes, output_arrays, control_output_arrays,
        specs.prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
        /*graph_as_function=*/false, specs.upgrade_legacy,
        /*enable_shape_inference=*/false, context);
  }
  return tensorflow::GraphdefToMlirTranslateFunction(
      file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
      input_shapes, output_arrays, control_output_arrays,
      specs.prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
      /*graph_as_function=*/false, specs.upgrade_legacy,
      /*enable_shape_inference=*/false, context);
}

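// Runs `pass_manager` over `module` and then exports the result, either as
// MLIR text into `result` (when `export_to_mlir` is set) or as a serialized
// TFLite FlatBuffer, optionally applying post-training weight quantization.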
Status ConvertTFExecutorToTFLOrFlatbuffer(
    mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
    bool emit_select_tf_ops, bool emit_custom_ops, bool allow_all_select_tf_ops,
    const std::unordered_set<std::string>& select_user_tf_ops,
    const mlir::TFL::QuantizationSpecs& quant_specs,
    const std::unordered_set<std::string>& saved_model_tags,
    std::string* result, mlir::PassManager* pass_manager) {
  // Explicitly disable dumping Op details on failures.
  module.getContext()->printOpOnDiagnostic(false);

  // Register a warning handler that only logs warnings to stdout. Returning
  // failure() marks the diagnostic as unhandled so later handlers also see it.
  mlir::ScopedDiagnosticHandler s(
      module.getContext(), [](mlir::Diagnostic& diag) {
        if (diag.getSeverity() == mlir::DiagnosticSeverity::Warning) {
          for (auto& note : diag.getNotes()) {
            std::cout << note.str() << "\n";
            LOG(WARNING) << note.str() << "\n";
          }
        }
        return mlir::failure();
      });

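  // Collect emitted diagnostics and surface them as a tensorflow::Status via
  // ConsumeStatus(); `propagate` also forwards them to earlier handlers.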
  mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),
                                                    /*propagate=*/true);

  if (failed(IsValidGraph(module))) {
    return statusHandler.ConsumeStatus();
  }

  if (failed(pass_manager->run(module))) {
    auto status = statusHandler.ConsumeStatus();
    mlir::TFL::ErrorCollector* collector =
        mlir::TFL::ErrorCollector::GetErrorCollector();
    for (const auto& error_data : collector->CollectedErrors()) {
      if (error_data.subcomponent() == "FreezeGlobalTensorsPass") {
        // LINT.IfChange
        return errors::InvalidArgument(
            "Variable constant folding failed. Please consider enabling the "
            "`experimental_enable_resource_variables` flag in the TFLite "
            "converter object. For example, "
            "converter.experimental_enable_resource_variables = True");
        // LINT.ThenChange(//tensorflow/lite/python/lite_v2_test.py)
      }
    }
    return status;
  }

  if (export_to_mlir) {
    llvm::raw_string_ostream os(*result);
    module.print(os);
    return statusHandler.ConsumeStatus();
  }

  // Write the MLIR TFLite dialect into a FlatBuffer.
  OpOrArgLocNameMapper op_or_arg_name_mapper;
  if (!quant_specs.RunWeightQuantization()) {
    tflite::FlatbufferExportOptions options;
    options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
    options.emit_select_tf_ops = emit_select_tf_ops;
    options.select_user_tf_ops = select_user_tf_ops;
    options.allow_all_select_tf_ops = allow_all_select_tf_ops;
    options.emit_custom_ops = emit_custom_ops;
    options.saved_model_tags = saved_model_tags;
    options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
    if (quant_specs.support_mask !=
        tflite::optimize::ReducedPrecisionSupport::None) {
      options.metadata.insert(
          MetadataForReducedPrecisionSupport(quant_specs.support_mask));
    }
    if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result)) {
      return statusHandler.ConsumeStatus();
    }
  } else {
    // Post-training weight quantization path. Once MLIR has support for this,
    // we can remove this else statement.
    std::string pre_quantized_result;
    tflite::FlatbufferExportOptions options;
    options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
    options.emit_select_tf_ops = emit_select_tf_ops;
    options.select_user_tf_ops = select_user_tf_ops;
    options.allow_all_select_tf_ops = allow_all_select_tf_ops;
    options.emit_custom_ops = emit_custom_ops;
    options.saved_model_tags = saved_model_tags;
    options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
    if (quant_specs.support_mask !=
        tflite::optimize::ReducedPrecisionSupport::None) {
      options.metadata.insert(
          MetadataForReducedPrecisionSupport(quant_specs.support_mask));
    }
    if (!tflite::MlirToFlatBufferTranslateFunction(module, options,
                                                   &pre_quantized_result)) {
      return statusHandler.ConsumeStatus();
    }
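    // Parse the unquantized FlatBuffer back into a tflite::Model and run the
    // TFLite post-training weight quantizer over it.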
    flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);
    const uint8_t* buffer =
        reinterpret_cast<const uint8_t*>(pre_quantized_result.c_str());
    const ::tflite::Model* input_model = ::tflite::GetModel(buffer);

    ::tflite::optimize::BufferType quantized_type;
    if (quant_specs.inference_type == tensorflow::DT_QINT8) {
      quantized_type = ::tflite::optimize::BufferType::QUANTIZED_INT8;
    } else if (quant_specs.inference_type == tensorflow::DT_HALF) {
      quantized_type = ::tflite::optimize::BufferType::QUANTIZED_FLOAT16;
    } else {
      return errors::InvalidArgument("Quantized type not supported");
    }
    if (::tflite::optimize::QuantizeWeights(&q_builder, input_model,
                                            quantized_type) != kTfLiteOk) {
      return errors::InvalidArgument("Quantize weights transformation failed.");
    }
    const uint8_t* q_buffer = q_builder.GetBufferPointer();
    *result =
        string(reinterpret_cast<const char*>(q_buffer), q_builder.GetSize());
  }

  if (mlir::failed(module.verify())) {
    return tensorflow::errors::Unknown("Final module is invalid");
  }
  return Status::OK();
}

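// Imports a SavedModel (v1 or v2) into an MLIR module after registering any
// `extra_tf_opdefs`. For v1 SavedModels, `saved_model_bundle` receives the
// loaded bundle.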
StatusOr<mlir::OwningModuleRef> ImportSavedModel(
    const std::string& input_filename, const int saved_model_version,
    const std::unordered_set<std::string>& tags,
    absl::Span<const std::string> extra_tf_opdefs,
    absl::Span<std::string> exported_names, const GraphImportConfig& specs,
    bool enable_variable_lifting, mlir::MLIRContext* context,
    std::unique_ptr<tensorflow::SavedModelBundle>* saved_model_bundle) {
  // Register extra TF ops passed as OpDefs.
  auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs);
  if (!extra_opdefs_status.ok()) return extra_opdefs_status;

  if (saved_model_version == 2) {
    auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
        input_filename, tags, exported_names, context);
    if (!module_or.status().ok()) return module_or.status();
    return module_or.ConsumeValueOrDie();
  } else if (saved_model_version == 1) {
    MLIRImportOptions options;
    options.upgrade_legacy = specs.upgrade_legacy;
    auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
        input_filename, tags, exported_names, context, options,
        enable_variable_lifting, saved_model_bundle);

    if (!module_or.status().ok()) return module_or.status();
    return module_or.ConsumeValueOrDie();
  } else {
    return tensorflow::errors::InvalidArgument(
        "SavedModel version must be 1 or 2");
  }
}

}  // namespace tensorflow