/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

18 #include "absl/strings/str_split.h"
19 #include "llvm/ADT/None.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/InitLLVM.h"
27 #include "llvm/Support/SourceMgr.h"
28 #include "llvm/Support/ToolOutputFile.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include "mlir/IR/AsmState.h"  // from @llvm-project
31 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
32 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
33 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
34 #include "mlir/Pass/Pass.h"  // from @llvm-project
35 #include "mlir/Pass/PassManager.h"  // from @llvm-project
36 #include "mlir/Support/FileUtilities.h"  // from @llvm-project
37 #include "tensorflow/cc/saved_model/loader.h"
38 #include "tensorflow/compiler/mlir/init_mlir.h"
39 #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
40 #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
41 #include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h"
42 #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
43 #include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h"
44 #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
45 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
46 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
47 #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
48 #include "tensorflow/core/framework/types.pb.h"
49 #include "tensorflow/core/platform/errors.h"
50 #include "tensorflow/lite/model.h"
51 #include "tensorflow/lite/schema/schema_generated.h"
52 #include "tensorflow/stream_executor/lib/statusor.h"
53 
using mlir::FuncOp;
using mlir::MLIRContext;
using mlir::ModuleOp;
using stream_executor::port::StatusOr;

// Debugging flag to print function mapping in the flatbuffer.
// NOLINTNEXTLINE
static llvm::cl::opt<bool> print_function_result_mapping(
    "print-function-result-mapping",
    llvm::cl::desc(
        "Print the mapping of function result to flatbuffer output buffer"),
    llvm::cl::init(false));

// NOLINTNEXTLINE
static llvm::cl::opt<std::string> weight_quantization(
    "weight_quantization",
    llvm::cl::desc("The type of the quantized weight buffer. Must be NONE, "
                   "INT8, or FLOAT16."),
    llvm::cl::init("NONE"));

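// Exit status of the translation; main() returns these as process exit codes.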
enum TranslationStatus { kTrSuccess, kTrFailure };

static int PrintFunctionResultMapping(const std::string &result,
                                      ModuleOp module) {
  // Build a model from the result string to extract the return values from
  // their source of truth.
  auto model =
      tflite::FlatBufferModel::BuildFromBuffer(result.data(), result.size());
  if (!model) return kTrFailure;

  // Use an unknown location wherever there is no terminator from which to
  // take the return value's location.
  auto unknown_loc = mlir::UnknownLoc::get(module.getContext());

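  // Prints the name and buffer index of one subgraph tensor; when a location
  // callback is supplied, also prints the location associated with result
  // `id`.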
  auto print_buffer = [&](const tflite::SubGraph &subgraph, int id, int buffer,
                          std::function<mlir::Location(int)> loc) {
    const auto &output_tensor = (*subgraph.tensors())[buffer];
    std::cout << "\tname: '"
              << (output_tensor->name() ? output_tensor->name()->str()
                                        : "<<unnamed>>")
              << "' buffer: " << buffer;
    if (loc) std::cout << llvm::formatv(" {0}", loc(id)).str();
    std::cout << '\n';
  };

  // For every subgraph, print the name (if available) and, for each result,
  // its output buffer number and the location of the return value (if
  // available).
  for (auto *subgraph : *(*model)->subgraphs()) {
    std::string subgraph_name =
        subgraph->name() ? subgraph->name()->str() : "<<unnamed subgraph>>";

    std::cout << '\'' << subgraph_name << "' inputs:\n";
    int i = 0;
    for (auto input : *subgraph->inputs())
      print_buffer(*subgraph, i++, input, nullptr);

    std::cout << '\'' << subgraph_name << "' outputs:\n";
    mlir::Operation *terminator = nullptr;
    if (subgraph->name()) {
      if (auto fn = module.lookupSymbol<FuncOp>(subgraph->name()->str()))
        terminator = fn.back().getTerminator();
    }
    i = 0;
    for (auto output : *subgraph->outputs()) {
      print_buffer(*subgraph, i++, output, [&](int i) {
        return terminator ? terminator->getOperand(i).getLoc() : unknown_loc;
      });
    }
  }
  return kTrSuccess;
}

int main(int argc, char **argv) {
  // TODO(jpienaar): Revise the command line option parsing here.
  tensorflow::InitMlir y(&argc, &argv);

  // TODO(antiagainst): We are pulling in multiple transformations as follows.
  // Each transformation has its own set of command-line options; options of
  // one transformation can essentially be aliases to another. For example,
  // -tfl-annotate-inputs has -tfl-input-arrays, -tfl-input-data-types, and
  // -tfl-input-shapes, which are the same as the -graphdef-to-mlir
  // transformation's -tf_input_arrays, -tf_input_data_types, and
  // -tf_input_shapes, respectively. We need to disable duplicated ones to
  // provide a cleaner command-line option interface. That also means we need
  // to relay the value set in one option to all its aliases.
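  // Register MLIR's global command-line flags so the parse below picks them
  // up.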
  mlir::registerAsmPrinterCLOptions();
  mlir::registerMLIRContextCLOptions();
  mlir::registerPassManagerCLOptions();
  llvm::cl::ParseCommandLineOptions(
      argc, argv, "TF GraphDef to TFLite FlatBuffer converter\n");

  MLIRContext context;
  llvm::SourceMgr source_mgr;
  mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);

  StatusOr<mlir::OwningModuleRef> module;
  std::unordered_set<std::string> tags;

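  // Options for the GraphDef importer: optionally upgrade legacy graphs and
  // prune nodes that are not needed for the requested outputs.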
  tensorflow::GraphImportConfig specs;
  specs.upgrade_legacy = upgrade_legacy;
  specs.prune_unused_nodes = true;

  if (!select_user_tf_ops.empty() && !emit_select_tf_ops) {
    llvm::errs() << "You must specify `emit-select-tf-ops=true` when passing "
                    "the `select-user-tf-ops` flag.";
    return kTrFailure;
  }

  std::unique_ptr<tensorflow::SavedModelBundle> bundle;

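  // Import from either a SavedModel (signature-defs for v1, object graph for
  // v2) or a GraphDef/MLIR source file.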
  // TODO(b/147435528): We need to test the e2e behavior once the graph
  // freezing inside MLIR is done.
  if (import_saved_model_object_graph || import_saved_model_signature_defs) {
    int saved_model_version;
    if (import_saved_model_object_graph) {
      saved_model_version = 2;
    } else {
      saved_model_version = 1;
    }
    if (input_mlir)
      module = tensorflow::errors::InvalidArgument(
          "Importing a saved model should not have input_mlir set");

    tags = absl::StrSplit(saved_model_tags, ',');
    std::vector<std::string> exported_names_vector =
        absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
    absl::Span<std::string> exported_names(exported_names_vector);

    if (exported_names.size() != 1) {
      llvm::errs() << "There should be only one exported name";
      return kTrFailure;
    }
    std::vector<std::string> extra_opdefs(custom_opdefs.begin(),
                                          custom_opdefs.end());
    module = tensorflow::ImportSavedModel(
        input_file_name, saved_model_version, tags, extra_opdefs,
        exported_names, specs, /*enable_variable_lifting=*/true, &context,
        &bundle);
  } else {
    module = tensorflow::LoadFromGraphdefOrMlirSource(
        input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
        specs, debug_info_file, input_arrays, input_dtypes, input_shapes,
        output_arrays, control_output_arrays, &source_mgr, &context);
  }

  // If an error occurred, the library call above already logged the error
  // message, so we can just return here.
  if (!module.ok()) return kTrFailure;

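  // Build the lowering pipeline; pass-manager flags given on the command line
  // are applied to it.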
  mlir::PassManager pm(&context, mlir::OpPassManager::Nesting::Implicit);
  mlir::applyPassManagerCLOptions(pm);

  // Set the quantization specifications from the command line flags.
  mlir::TFL::QuantizationSpecs quant_specs;
  if (mlir::TFL::ParseInputNodeQuantSpecs(input_arrays, min_values, max_values,
                                          inference_type, &quant_specs)) {
    llvm::errs() << "Failed to get input quant spec.";
    return kTrFailure;
  }
  if (weight_quantization != "NONE") {
    quant_specs.weight_quantization = true;
    if (weight_quantization == "INT8") {
      quant_specs.inference_type = tensorflow::DT_QINT8;
    } else if (weight_quantization == "FLOAT16") {
      quant_specs.inference_type = tensorflow::DT_HALF;
    } else {
      llvm::errs() << "Unknown weight quantization " << weight_quantization;
      return kTrFailure;
    }
  }
  if (!emit_quant_adaptor_ops) {
    quant_specs.inference_input_type = quant_specs.inference_type;
  }

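  // Optionally load serialized calibration statistics to drive quantization.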
  if (!quant_stats_file_name.empty()) {
    std::string error_message;
    auto file = mlir::openInputFile(quant_stats_file_name, &error_message);
    if (!file) {
      llvm::errs() << "failed to open quant stats file: "
                   << quant_stats_file_name;
      return kTrFailure;
    }
    quant_specs.serialized_quant_stats = file->getBuffer().str();
  }

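  // Mirror the remaining command-line flags into the converter's pass
  // configuration.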
  mlir::TFL::PassConfig pass_config(quant_specs);
  pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
  pass_config.lower_tensor_list_ops = lower_tensor_list_ops;
  pass_config.legalize_tf_while = convert_tf_while_to_tfl_while;
  pass_config.unfold_batch_matmul = unfold_batchmatmul;
  pass_config.unfold_large_splat_constant = unfold_large_splat_constant;
  pass_config.guarantee_all_funcs_one_use = guarantee_all_funcs_one_use;

  // TODO(b/153507667): Pass the session object when importing logic is removed.
  tensorflow::AddTFToTFLConversionPasses(pass_config, &pm,
                                         /*session=*/llvm::None);
  // TODO(b/150901738): Move those into tf_tfl_translate.cc.
  // Convert back to the outlined while format for export back to flatbuffer.
  if (pass_config.legalize_tf_while) {
    pm.addPass(mlir::TFL::CreateWhileOutlinePass());
  }
  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());

  // Read the list of user-specified Select TF ops.
  std::unordered_set<std::string> select_user_ops_set;
  llvm::SmallVector<llvm::StringRef, 2> user_ops;
  (llvm::StringRef(select_user_tf_ops))
      .split(user_ops, ',', /*MaxSplit=*/-1,
             /*KeepEmpty=*/false);
  llvm::for_each(user_ops, [&select_user_ops_set](llvm::StringRef op_name) {
    select_user_ops_set.insert(op_name.str());
  });

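  // Run the pipeline and serialize the module into `result`, either as MLIR
  // text or as a TFLite flatbuffer, depending on output_mlir.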
  std::string result;
  auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
      module.ValueOrDie().get(), output_mlir, emit_builtin_tflite_ops,
      emit_select_tf_ops, emit_custom_ops, allow_all_select_tf_ops,
      select_user_ops_set, quant_specs, tags, &result, &pm);
  if (!status.ok()) return kTrFailure;

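  // Write the serialized result to the requested output file.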
  std::string error_msg;
  auto output = mlir::openOutputFile(output_file_name, &error_msg);
  if (output == nullptr) {
    llvm::errs() << error_msg << '\n';
    return kTrFailure;
  }
  output->os() << result;
  output->keep();

  // Print out debugging info related to function mapping.
  if (print_function_result_mapping)
    return PrintFunctionResultMapping(result, module.ValueOrDie().get());
  return kTrSuccess;
}