/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"

#include <memory>
#include <utility>

#include "absl/types/span.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h"  // from @llvm-project
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/types.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {

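// Validates the user-specified input/output arrays in `model_flags` against
// the "inputs"/"outputs" entries of the module's unique tf.entry_function
// attribute. Returns an InvalidArgument error on any mismatch.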
Status HandleInputOutputArraysWithModule(const toco::ModelFlags& model_flags,
                                         mlir::OwningModuleRef* module) {
  mlir::FuncOp entry_function = nullptr;
  for (auto func : module->get().getOps<mlir::FuncOp>()) {
    if (auto tf_attrs =
            func->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function")) {
      // TODO(b/184697652): There could be multiple entry functions. Handle
      // such cases if the need arises.
      if (entry_function != nullptr) {
        return errors::InvalidArgument(
            "There should be only one tf.entry_function");
      }
      entry_function = func;
    }
  }
  if (entry_function == nullptr) {
    return errors::InvalidArgument("no tf.entry_function found");
  }

  // Get the list of input Op names from the function attribute.
  mlir::DictionaryAttr tf_attrs =
      entry_function->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
  llvm::SmallVector<llvm::StringRef, 4> function_input_names;
  function_input_names.reserve(model_flags.input_arrays().size());
  auto input_attr = tf_attrs.get("inputs");
  if (!input_attr) {
    return errors::InvalidArgument("no inputs attribute found");
  }
  auto input_names = input_attr.cast<mlir::StringAttr>().getValue();
  input_names.split(function_input_names, ",");
  const int function_input_names_size = function_input_names.size();
  if (function_input_names_size != model_flags.input_arrays().size()) {
    return errors::InvalidArgument(
        "input array size mismatch: got ", function_input_names.size(),
        ", expected: ", model_flags.input_arrays().size());
  }
  llvm::StringSet<> function_input_names_set;
  function_input_names_set.insert(function_input_names.begin(),
                                  function_input_names.end());
  for (const auto& input_array : model_flags.input_arrays()) {
    if (function_input_names_set.count(input_array.name()) == 0) {
      return errors::InvalidArgument("input array name (", input_array.name(),
                                     ") does not exist in the given graph");
    }
  }

  // Get the list of output Op names from the function attribute.
  llvm::SmallVector<llvm::StringRef, 4> function_output_names;
  function_output_names.reserve(model_flags.output_arrays().size());
  auto output_attr = tf_attrs.get("outputs");
  if (!output_attr) {
    return errors::InvalidArgument("no outputs attribute found");
  }
  auto output_names = output_attr.cast<mlir::StringAttr>().getValue();
  output_names.split(function_output_names, ",");
  const int function_output_names_size = function_output_names.size();
  if (function_output_names_size != model_flags.output_arrays().size()) {
    return errors::InvalidArgument(
        "output array size mismatch: got ", function_output_names.size(),
        ", expected: ", model_flags.output_arrays().size());
  }
  llvm::StringSet<> function_output_names_set;
  function_output_names_set.insert(function_output_names.begin(),
                                   function_output_names.end());
  for (const auto& output_array : model_flags.output_arrays()) {
    if (function_output_names_set.count(output_array) == 0) {
      return errors::InvalidArgument("output array name (", output_array,
                                     ") does not exist in the given graph");
    }
  }
  return Status::OK();
}

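// Converts the SavedModel referenced by `model_flags` into a TFLite
// FlatBuffer and stores the serialized model in `*result`.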
Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
                                           const toco::TocoFlags& toco_flags,
                                           string* result) {
  mlir::MLIRContext context;
  mlir::TFL::QuantizationSpecs quant_specs;

  // Parse input arrays.
  std::vector<string> node_names;
  std::vector<string> node_dtypes;
  std::vector<llvm::Optional<std::vector<int>>> node_shapes;
  std::vector<llvm::Optional<double>> node_mins;
  std::vector<llvm::Optional<double>> node_maxs;

  // Populate quantization specs.
  TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(
      model_flags, toco_flags, &quant_specs, &node_names, &node_dtypes,
      &node_shapes, &node_mins, &node_maxs));

  internal::WarningUnusedFlags(model_flags, toco_flags);

  // Register all custom ops, including user-specified custom ops.
  TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags));

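  // Collect the SavedModel tags and exported names from the flags; the
  // importer needs both to pick which MetaGraph and signatures to load.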
  auto& saved_model_tags = model_flags.saved_model_tags();
  auto& saved_model_exported_names = model_flags.saved_model_exported_names();
  std::unordered_set<std::string> tags(saved_model_tags.begin(),
                                       saved_model_tags.end());
  auto exported_names_in_vector = std::vector<std::string>(
      saved_model_exported_names.begin(), saved_model_exported_names.end());
  absl::Span<std::string> exported_names(exported_names_in_vector);

  if (exported_names.empty()) {
    return errors::Unimplemented("Need at least one exported name.");
  }

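  // Upgrade legacy graph constructs (e.g. functionalize v1 control flow)
  // while importing the SavedModel into MLIR.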
  tensorflow::GraphImportConfig specs;
  specs.upgrade_legacy = true;

  std::vector<std::string> custom_opdefs(toco_flags.custom_opdefs().begin(),
                                         toco_flags.custom_opdefs().end());
  auto bundle = std::make_unique<tensorflow::SavedModelBundle>();
  TF_ASSIGN_OR_RETURN(
      auto module,
      ImportSavedModel(
          model_flags.saved_model_dir(), model_flags.saved_model_version(),
          tags, absl::MakeSpan(custom_opdefs), exported_names, specs,
          !toco_flags.enable_tflite_resource_variables(), &context, &bundle));

  if (!model_flags.input_arrays().empty() ||
      !model_flags.output_arrays().empty()) {
    TF_RETURN_IF_ERROR(HandleInputOutputArraysWithModule(model_flags, &module));
  }

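  // Translate the conversion flags into the TF-to-TFLite pass configuration.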
  mlir::TFL::PassConfig pass_config(quant_specs);
  bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
  pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
  pass_config.enable_tflite_variables =
      toco_flags.enable_tflite_resource_variables();
  pass_config.unfold_batch_matmul = toco_flags.unfold_batchmatmul();
  pass_config.lower_tensor_list_ops = toco_flags.lower_tensor_list_ops();
  // Disable the unfolding of the 16x16 TF::BatchMatMulOp to avoid the
  // conversion to an unsupported 16x16 TFL::FullyConnectedOp.
  if (toco_flags.inference_type() == toco::IODataType::QUANTIZED_INT16) {
    pass_config.unfold_batch_matmul = false;
  }
  pass_config.unfold_large_splat_constant =
      toco_flags.unfold_large_splat_constant();

  // TODO(b/153507667): Pass the session object when importing logic is removed.
  auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
      model_flags, toco_flags, std::move(module), pass_config, tags, result,
      bundle ? bundle->GetSession() : nullptr);
  return status;
}

}  // namespace tensorflow