/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

#include "absl/types/span.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Parser.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/lite/tools/optimize/quantize_weights.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {
namespace {
using mlir::MLIRContext;
using mlir::ModuleOp;
using mlir::Operation;
using mlir::OwningModuleRef;
using stream_executor::port::StatusOr;

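// Returns true if `op` is one of the tf_executor control-flow-v1 ops
// (Switch/Merge/Enter/Exit/NextIteration), none of which the TFLite
// converter supports.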
bool IsControlFlowV1Op(Operation* op) {
  return mlir::isa<mlir::tf_executor::SwitchOp, mlir::tf_executor::MergeOp,
                   mlir::tf_executor::EnterOp, mlir::tf_executor::ExitOp,
                   mlir::tf_executor::NextIterationSinkOp,
                   mlir::tf_executor::NextIterationSourceOp>(op);
}

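// Walks the module and fails with an error if it contains any
// control-flow-v1 ops, pointing users at tf.compat.v1.enable_control_flow_v2.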
mlir::LogicalResult IsValidGraph(mlir::ModuleOp module) {
  auto result = module.walk([&](Operation* op) {
    return IsControlFlowV1Op(op) ? mlir::WalkResult::interrupt()
                                 : mlir::WalkResult::advance();
  });
  if (result.wasInterrupted()) {
    module.emitError(
        "The graph has Control Flow V1 ops. TFLite converter doesn't support "
        "Control Flow V1 ops. Consider using Control Flow V2 ops instead. See "
        "https://www.tensorflow.org/api_docs/python/tf/compat/v1/"
        "enable_control_flow_v2.");
    return mlir::failure();
  }
  return mlir::success();
}

// Util that registers 'extra_tf_opdefs' with the TF global op registry.
// Returns OK on success, or an error if an OpDef fails to parse.
Status RegisterExtraTfOpDefs(absl::Span<const std::string> extra_tf_opdefs) {
  for (const auto& tf_opdefs_string : extra_tf_opdefs) {
    tensorflow::OpDef opdef;
    if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
                                                           &opdef)) {
      LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string;
      return errors::InvalidArgument("fail to parse extra OpDef");
    }
    // Register extra opdefs.
    // TODO(b/133770952): Support shape functions.
    tensorflow::OpRegistry::Global()->Register(
        [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
          *op_reg_data = tensorflow::OpRegistrationData(opdef);
          return Status::OK();
        });
  }
  return Status::OK();
}
}  // namespace

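// Loads `input_filename` either as textual MLIR (when `input_mlir` is true)
// or as a GraphDef, and imports it into an MLIR module. The remaining
// arguments configure the GraphDef importer.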
StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
    const std::string& input_filename, bool input_mlir,
    bool use_splatted_constant, const std::vector<std::string>& extra_tf_opdefs,
    const GraphImportConfig& specs, absl::string_view debug_info_file,
    absl::string_view input_arrays, absl::string_view input_dtypes,
    absl::string_view input_shapes, absl::string_view output_arrays,
    llvm::SourceMgr* source_mgr, MLIRContext* context) {
  // Set up the input file.
  std::string error_message;
  auto file = mlir::openInputFile(input_filename, &error_message);
  if (!file) {
    llvm::errs() << error_message << "\n";
    return errors::InvalidArgument("fail to open input file");
  }

  if (input_mlir) {
    source_mgr->AddNewSourceBuffer(std::move(file), llvm::SMLoc());
    return OwningModuleRef(mlir::parseSourceFile(*source_mgr, context));
  }

  // Register extra TF ops passed as OpDef.
  auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs);
  if (!extra_opdefs_status.ok()) return extra_opdefs_status;

  if (use_splatted_constant) {
    return tensorflow::GraphdefToSplattedMlirTranslateFunction(
        file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
        input_shapes, output_arrays, /*control_output_arrays=*/"",
        specs.prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
        /*graph_as_function=*/false, specs.upgrade_legacy,
        /*enable_shape_inference=*/false, context);
  }
  return tensorflow::GraphdefToMlirTranslateFunction(
      file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
      input_shapes, output_arrays, /*control_output_arrays=*/"",
      specs.prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
      /*graph_as_function=*/false, specs.upgrade_legacy,
      /*enable_shape_inference=*/false, context);
}

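// Runs the given pass manager over `module` and serializes the result either
// as textual MLIR (when `export_to_mlir` is true) or as a TFLite FlatBuffer
// written into `*result`. When weight quantization is requested via
// `quant_specs`, the exported FlatBuffer is post-processed by the
// quantize_weights tool.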
Status ConvertTFExecutorToTFLOrFlatbuffer(
    mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
    bool emit_select_tf_ops, bool emit_custom_ops,
    const std::unordered_set<std::string>& select_user_tf_ops,
    const mlir::TFL::QuantizationSpecs& quant_specs,
    const std::unordered_set<std::string>& saved_model_tags,
    std::string* result, mlir::PassManager* pass_manager) {
  // Explicitly disable dumping Op details on failures.
  module.getContext()->printOpOnDiagnostic(false);

  // Register a diagnostic handler that logs warnings to stdout and to the
  // warning log; returning failure marks the diagnostic as unhandled so it
  // still propagates.
  mlir::ScopedDiagnosticHandler s(
      module.getContext(), [](mlir::Diagnostic& diag) {
        if (diag.getSeverity() == mlir::DiagnosticSeverity::Warning) {
          for (auto& note : diag.getNotes()) {
            std::cout << note.str() << "\n";
            LOG(WARNING) << note.str() << "\n";
          }
        }
        return mlir::failure();
      });

  mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),
                                                    /*propagate=*/true);

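  // Reject graphs that still contain control-flow-v1 ops, then run the
  // conversion passes; any emitted diagnostics become the returned status.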
  if (failed(IsValidGraph(module)) || failed(pass_manager->run(module))) {
    return statusHandler.ConsumeStatus();
  }

  if (export_to_mlir) {
    llvm::raw_string_ostream os(*result);
    module.print(os);
    return Status::OK();
  }

  // Write the MLIR TFLite dialect into a FlatBuffer.
  OpOrArgLocNameMapper op_or_arg_name_mapper;
  if (!quant_specs.RunWeightQuantization()) {
    tflite::FlatbufferExportOptions options;
    options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
    options.emit_select_tf_ops = emit_select_tf_ops;
    options.select_user_tf_ops = select_user_tf_ops;
    options.emit_custom_ops = emit_custom_ops;
    options.saved_model_tags = saved_model_tags;
    options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
    if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result)) {
      return statusHandler.ConsumeStatus();
    }
  } else {
    // Post-training weight quantization path. Once MLIR has support for this,
    // we can remove this else statement.
    std::string pre_quantized_result;
    tflite::FlatbufferExportOptions options;
    options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
    options.emit_select_tf_ops = emit_select_tf_ops;
    options.select_user_tf_ops = select_user_tf_ops;
    options.emit_custom_ops = emit_custom_ops;
    options.saved_model_tags = saved_model_tags;
    options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
    if (!tflite::MlirToFlatBufferTranslateFunction(module, options,
                                                   &pre_quantized_result)) {
      return statusHandler.ConsumeStatus();
    }
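    // Round-trip the unquantized FlatBuffer through the quantize_weights
    // tool: parse it back into a tflite::Model, then re-serialize it with
    // quantized weight buffers.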
    flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);
    const uint8_t* buffer =
        reinterpret_cast<const uint8_t*>(pre_quantized_result.c_str());
    const ::tflite::Model* input_model = ::tflite::GetModel(buffer);

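    // Map the requested inference type to the quantize_weights buffer type:
    // DT_QINT8 yields int8 weights, DT_HALF yields float16 weights.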
    ::tflite::optimize::BufferType quantized_type;
    if (quant_specs.inference_type == tensorflow::DT_QINT8) {
      quantized_type = ::tflite::optimize::BufferType::QUANTIZED_INT8;
    } else if (quant_specs.inference_type == tensorflow::DT_HALF) {
      quantized_type = ::tflite::optimize::BufferType::QUANTIZED_FLOAT16;
    } else {
      return errors::InvalidArgument("Quantized type not supported");
    }
    if (::tflite::optimize::QuantizeWeights(&q_builder, input_model,
                                            quantized_type) != kTfLiteOk) {
      return errors::InvalidArgument("Quantize weights transformation failed.");
    }
    const uint8_t* q_buffer = q_builder.GetBufferPointer();
    *result =
        string(reinterpret_cast<const char*>(q_buffer), q_builder.GetSize());
  }

  return Status::OK();
}

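// Imports a SavedModel (v1 via SignatureDefs, v2 via the object graph) from
// `input_filename` into an MLIR module. A hypothetical call site might look
// like:
//
//   mlir::MLIRContext context;
//   GraphImportConfig specs;
//   auto module_or = ImportSavedModel(
//       "/path/to/saved_model", /*saved_model_version=*/2, {"serve"},
//       /*extra_tf_opdefs=*/{}, /*exported_names=*/{}, specs, &context);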
StatusOr<mlir::OwningModuleRef> ImportSavedModel(
    const std::string& input_filename, const int saved_model_version,
    const std::unordered_set<std::string>& tags,
    absl::Span<const std::string> extra_tf_opdefs,
    absl::Span<std::string> exported_names, const GraphImportConfig& specs,
    mlir::MLIRContext* context) {
  // Register extra TF ops passed as OpDef.
  auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs);
  if (!extra_opdefs_status.ok()) return extra_opdefs_status;

  if (saved_model_version == 2) {
    auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
        input_filename, tags, exported_names, context);
    if (!module_or.status().ok()) return module_or.status();
    return module_or.ConsumeValueOrDie();
  } else if (saved_model_version == 1) {
    MLIRImportOptions options;
    options.upgrade_legacy = specs.upgrade_legacy;
    auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
        input_filename, tags, exported_names, context, options);

    if (!module_or.status().ok()) return module_or.status();
    return module_or.ConsumeValueOrDie();
  } else {
    return tensorflow::errors::InvalidArgument(
        "Should be either saved model v1 or v2");
  }
}

}  // namespace tensorflow