/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "absl/strings/str_split.h"
#include "absl/types/span.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/AsmState.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/stream_executor/lib/statusor.h"

using mlir::FuncOp;
using mlir::MLIRContext;
using mlir::ModuleOp;
using stream_executor::port::StatusOr;

// Debugging flag to print the mapping of function results to flatbuffer
// output buffers.
// NOLINTNEXTLINE
static llvm::cl::opt<bool> print_function_result_mapping(
    "print-function-result-mapping",
    llvm::cl::desc(
        "Print the mapping of function result to flatbuffer output buffer"),
    llvm::cl::init(false));

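// Type of post-training quantization to apply to weight buffers, if any.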
// NOLINTNEXTLINE
static llvm::cl::opt<std::string> weight_quantization(
    "weight_quantization",
    llvm::cl::desc("The type of the quantized weight buffer. Must be NONE, "
                   "INT8, or FLOAT16."),
    llvm::cl::init("NONE"));

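// Exit status codes returned by this translation tool.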
enum TranslationStatus { kTrSuccess, kTrFailure };

static int PrintFunctionResultMapping(const std::string &result,
                                      ModuleOp module) {
  // Build the model from the resulting string so we can extract the return
  // values from their source of truth.
  auto model =
      tflite::FlatBufferModel::BuildFromBuffer(result.data(), result.size());
  if (!model) return kTrFailure;

  // Use an unknown location wherever there is no terminator from which to
  // take the location of a return value.
  auto unknown_loc = mlir::UnknownLoc::get(module.getContext());

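  // Prints the name and buffer index of result `id` of `subgraph`, plus the
  // MLIR location returned by `loc`, if provided.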
  auto print_buffer = [&](const tflite::SubGraph &subgraph, int id, int buffer,
                          std::function<mlir::Location(int)> loc) {
    const auto &output_tensor = (*subgraph.tensors())[buffer];
    std::cout << "\tname: '"
              << (output_tensor->name() ? output_tensor->name()->str()
                                        : "<<unnamed>>")
              << "' buffer: " << buffer;
    if (loc) std::cout << llvm::formatv(" {0}", loc(id)).str();
    std::cout << '\n';
  };

  // For every subgraph, print its name (if available), each result's output
  // buffer number, and the location of the return value (if available).
  for (auto *subgraph : *(*model)->subgraphs()) {
    std::string subgraph_name =
        subgraph->name() ? subgraph->name()->str() : "<<unnamed subgraph>>";

    std::cout << '\'' << subgraph_name << "' inputs:\n";
    int i = 0;
    for (auto input : *subgraph->inputs())
      print_buffer(*subgraph, i++, input, nullptr);

    std::cout << '\'' << subgraph_name << "' outputs:\n";
    mlir::Operation *terminator = nullptr;
    if (subgraph->name()) {
      if (auto fn = module.lookupSymbol<FuncOp>(subgraph->name()->str()))
        terminator = fn.back().getTerminator();
    }
    i = 0;
    for (auto output : *subgraph->outputs()) {
      print_buffer(*subgraph, i++, output, [&](int i) {
        return terminator ? terminator->getOperand(i).getLoc() : unknown_loc;
      });
    }
  }
  return kTrSuccess;
}

int main(int argc, char **argv) {
  // TODO(jpienaar): Revise the command line option parsing here.
  tensorflow::InitMlir y(&argc, &argv);

  // TODO(antiagainst): We are pulling in multiple transformations as follows.
  // Each transformation has its own set of command-line options; options of
  // one transformation can essentially be aliases of another's. For example,
  // the -tfl-annotate-inputs pass has -tfl-input-arrays,
  // -tfl-input-data-types, and -tfl-input-shapes, which are the same as the
  // -graphdef-to-mlir transformation's -tf_input_arrays, -tf_input_data_types,
  // and -tf_input_shapes, respectively. We need to disable the duplicated ones
  // to provide a cleaner command-line option interface, which also means we
  // need to relay the value set on one option to all of its aliases.
  mlir::registerAsmPrinterCLOptions();
  mlir::registerMLIRContextCLOptions();
  mlir::registerPassManagerCLOptions();
  llvm::cl::ParseCommandLineOptions(
      argc, argv, "TF GraphDef to TFLite FlatBuffer converter\n");
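  // Example invocation (a sketch; these flags are declared in
  // tf_tfl_translate_cl.cc and tf_mlir_translate_cl.cc, so verify the
  // registered names with --help):
  //   tf_tfl_translate graph.pb -o model.tflite
  //     --tf-input-arrays=input --tf-input-shapes=1,224,224,3
  //     --tf-input-data-types=DT_FLOAT --tf-output-arrays=output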

  MLIRContext context;
  llvm::SourceMgr source_mgr;
  mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);

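  // The imported module and, when loading from a SavedModel, the set of tags
  // identifying the MetaGraphDef to load.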
  StatusOr<mlir::OwningModuleRef> module;
  std::unordered_set<std::string> tags;

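  // Configure the GraphDef importer: optionally upgrade legacy control flow
  // and prune nodes that are unreachable from the outputs.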
  tensorflow::GraphImportConfig specs;
  specs.upgrade_legacy = upgrade_legacy;
  specs.prune_unused_nodes = true;

  if (!select_user_tf_ops.empty() && !emit_select_tf_ops) {
    llvm::errs() << "You must specify `emit-select-tf-ops=true` when passing "
                    "the `select-user-tf-ops` flag.\n";
    return kTrFailure;
  }

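  // Receives the loaded SavedModel bundle when importing from a SavedModel.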
  std::unique_ptr<tensorflow::SavedModelBundle> bundle;

  // TODO(b/147435528): We need to test the e2e behavior once the graph
  // freezing inside MLIR is done.
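  // Import from a SavedModel: object-graph (TF2) imports use saved model
  // version 2; signature-def (TF1) imports use version 1.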
  if (import_saved_model_object_graph || import_saved_model_signature_defs) {
    int saved_model_version;
    if (import_saved_model_object_graph) {
      saved_model_version = 2;
    } else {
      saved_model_version = 1;
    }
    if (input_mlir)
      module = tensorflow::errors::InvalidArgument(
          "Importing saved model should not have input_mlir set");

    tags = absl::StrSplit(saved_model_tags, ',');
    std::vector<std::string> exported_names_vector =
        absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
    absl::Span<std::string> exported_names(exported_names_vector);

    if (exported_names.size() != 1) {
      llvm::errs() << "There should be only one exported name\n";
      return kTrFailure;
    }
    std::vector<std::string> extra_opdefs(custom_opdefs.begin(),
                                          custom_opdefs.end());
    module = tensorflow::ImportSavedModel(
        input_file_name, saved_model_version, tags, extra_opdefs,
        exported_names, specs, /*enable_variable_lifting=*/true, &context,
        &bundle);
  } else {
    module = tensorflow::LoadFromGraphdefOrMlirSource(
        input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
        specs, debug_info_file, input_arrays, input_dtypes, input_shapes,
        output_arrays, control_output_arrays, &source_mgr, &context);
  }

  // If an error occurred, the library calls above have already logged the
  // error message, so we can just return here.
  if (!module.ok()) return kTrFailure;

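  // Build the lowering pipeline. Implicit nesting lets function-level passes
  // be added directly to this module-level pass manager.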
  mlir::PassManager pm(&context, mlir::OpPassManager::Nesting::Implicit);
  mlir::applyPassManagerCLOptions(pm);

  // Set the quantization specifications from the command line flags.
  mlir::TFL::QuantizationSpecs quant_specs;
  if (mlir::TFL::ParseInputNodeQuantSpecs(input_arrays, min_values, max_values,
                                          inference_type, &quant_specs)) {
    llvm::errs() << "Failed to get input quant spec.";
    return kTrFailure;
  }
  if (weight_quantization != "NONE") {
    quant_specs.weight_quantization = true;
    if (weight_quantization == "INT8") {
      quant_specs.inference_type = tensorflow::DT_QINT8;
    } else if (weight_quantization == "FLOAT16") {
      quant_specs.inference_type = tensorflow::DT_HALF;
    } else {
      llvm::errs() << "Unknown weight quantization " << weight_quantization;
      return kTrFailure;
    }
  }
  if (!emit_quant_adaptor_ops) {
    quant_specs.inference_input_type = quant_specs.inference_type;
  }

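  // Preload serialized calibration statistics, if any were provided, so the
  // quantizer can consume them.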
  if (!quant_stats_file_name.empty()) {
    std::string error_message;
    auto file = mlir::openInputFile(quant_stats_file_name, &error_message);
    if (!file) {
      llvm::errs() << "failed to open quant stats file: "
                   << quant_stats_file_name << '\n';
      return kTrFailure;
    }
    quant_specs.serialized_quant_stats = file->getBuffer().str();
  }

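  // Populate the remaining pass configuration from the command-line flags.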
  mlir::TFL::PassConfig pass_config(quant_specs);
  pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
  pass_config.lower_tensor_list_ops = lower_tensor_list_ops;
  pass_config.legalize_tf_while = convert_tf_while_to_tfl_while;
  pass_config.unfold_batch_matmul = unfold_batchmatmul;
  pass_config.unfold_large_splat_constant = unfold_large_splat_constant;
  pass_config.guarantee_all_funcs_one_use = guarantee_all_funcs_one_use;

  // TODO(b/153507667): Pass the session object when importing logic is
  // removed.
  tensorflow::AddTFToTFLConversionPasses(pass_config, &pm,
                                         /*session=*/llvm::None);
  // TODO(b/150901738): Move those into tf_tfl_translate.cc.
  // Convert back to outlined while format for export back to flatbuffer.
  if (pass_config.legalize_tf_while) {
    pm.addPass(mlir::TFL::CreateWhileOutlinePass());
  }
  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());

  // Read the list of user-specified select TF ops into a set.
  std::unordered_set<std::string> select_user_ops_set;
  llvm::SmallVector<llvm::StringRef, 2> user_ops;
  (llvm::StringRef(select_user_tf_ops))
      .split(user_ops, ',', /*MaxSplit=*/-1,
             /*KeepEmpty=*/false);
  llvm::for_each(user_ops, [&select_user_ops_set](llvm::StringRef op_name) {
    select_user_ops_set.insert(op_name.str());
  });

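  // Run the pipeline and serialize the result, either as MLIR text or as a
  // TFLite flatbuffer, depending on `output_mlir`.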
  std::string result;
  auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
      module.ValueOrDie().get(), output_mlir, emit_builtin_tflite_ops,
      emit_select_tf_ops, emit_custom_ops, allow_all_select_tf_ops,
      select_user_ops_set, quant_specs, tags, &result, &pm);
  if (!status.ok()) return kTrFailure;

  std::string error_msg;
  auto output = mlir::openOutputFile(output_file_name, &error_msg);
  if (output == nullptr) {
    llvm::errs() << error_msg << '\n';
    return kTrFailure;
  }
  output->os() << result;
  output->keep();

  // Print out debugging info related to function mapping.
  if (print_function_result_mapping)
    return PrintFunctionResultMapping(result, module.ValueOrDie().get());
  return kTrSuccess;
}