/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"

#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h"  // from @llvm-project
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/types.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
48
49 namespace tensorflow {
50
HandleInputOutputArraysWithModule(const toco::ModelFlags & model_flags,mlir::OwningModuleRef * module)51 Status HandleInputOutputArraysWithModule(const toco::ModelFlags& model_flags,
52 mlir::OwningModuleRef* module) {
53 mlir::FuncOp entry_function = nullptr;
54 for (auto func : module->get().getOps<mlir::FuncOp>()) {
55 if (auto tf_attrs =
56 func->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function")) {
57 // TODO(b/184697652): There could be multiple entry functions. Let's
58 // handle such cases if there are any needs for that.
59 if (entry_function != nullptr) {
60 return errors::InvalidArgument(
61 "There should be only one tf.entry_function");
62 }
63 entry_function = func;
64 }
65 }
66 if (entry_function == nullptr) {
67 return errors::InvalidArgument("no tf.entry_function found");
68 }
69
70 // Get the list of input Op names from the function attribute.
71 mlir::DictionaryAttr tf_attrs =
72 entry_function->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
73 llvm::SmallVector<llvm::StringRef, 4> function_input_names;
74 function_input_names.reserve(model_flags.input_arrays().size());
75 auto input_attr = tf_attrs.get("inputs");
76 if (!input_attr) {
77 return errors::InvalidArgument("no inputs attribute found");
78 }
79 auto input_names = input_attr.cast<mlir::StringAttr>().getValue();
80 input_names.split(function_input_names, ",");
81 const int function_input_names_size = function_input_names.size();
82 if (function_input_names_size != model_flags.input_arrays().size()) {
83 return errors::InvalidArgument(
84 "input array size mismatch: got ", function_input_names.size(),
85 ", expected: ", model_flags.input_arrays().size());
86 }
87 llvm::StringSet<> function_input_names_set;
88 function_input_names_set.insert(function_input_names.begin(),
89 function_input_names.end());
90 for (const auto& input_array : model_flags.input_arrays()) {
91 if (function_input_names_set.count(input_array.name()) == 0) {
92 return errors::InvalidArgument("input array name (", input_array.name(),
93 ") does not exist in the given graph");
94 }
95 }
96
97 // Get the list of output Op names from the function attribute.
98 llvm::SmallVector<llvm::StringRef, 4> function_output_names;
99 function_output_names.reserve(model_flags.output_arrays().size());
100 auto output_attr = tf_attrs.get("outputs");
101 if (!output_attr) {
102 return errors::InvalidArgument("no outputs attribute found");
103 }
104 auto output_names = output_attr.cast<mlir::StringAttr>().getValue();
105 output_names.split(function_output_names, ",");
106 const int function_output_names_size = function_output_names.size();
107 if (function_output_names_size != model_flags.output_arrays().size()) {
108 return errors::InvalidArgument(
109 "output array size mismatch: got ", function_output_names.size(),
110 ", expected: ", model_flags.output_arrays().size());
111 }
112 llvm::StringSet<> function_output_names_set;
113 function_output_names_set.insert(function_output_names.begin(),
114 function_output_names.end());
115 for (const auto& output_array : model_flags.output_arrays()) {
116 if (function_output_names_set.count(output_array) == 0) {
117 return errors::InvalidArgument("output array name (", output_array,
118 ") does not exist in the given graph");
119 }
120 }
121 return Status::OK();
122 }
123
ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags & model_flags,const toco::TocoFlags & toco_flags,string * result)124 Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
125 const toco::TocoFlags& toco_flags,
126 string* result) {
127 mlir::MLIRContext context;
128 mlir::TFL::QuantizationSpecs quant_specs;
129
130 // Parse input arrays.
131 std::vector<string> node_names;
132 std::vector<string> node_dtypes;
133 std::vector<llvm::Optional<std::vector<int>>> node_shapes;
134 std::vector<llvm::Optional<double>> node_mins;
135 std::vector<llvm::Optional<double>> node_maxs;
136
137 // Populate quantization specs.
138 TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(
139 model_flags, toco_flags, &quant_specs, &node_names, &node_dtypes,
140 &node_shapes, &node_mins, &node_maxs));
141
142 internal::WarningUnusedFlags(model_flags, toco_flags);
143
144 // Register all custom ops, including user-specified custom ops.
145 TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags));
146
147 auto& saved_model_tags = model_flags.saved_model_tags();
148 auto& saved_model_exported_names = model_flags.saved_model_exported_names();
149 std::unordered_set<std::string> tags(saved_model_tags.begin(),
150 saved_model_tags.end());
151 auto exported_names_in_vector = std::vector<std::string>(
152 saved_model_exported_names.begin(), saved_model_exported_names.end());
153 absl::Span<std::string> exported_names(exported_names_in_vector);
154
155 if (exported_names.empty()) {
156 return errors::Unimplemented("Need at least one exported name.");
157 }
158
159 tensorflow::GraphImportConfig specs;
160 specs.upgrade_legacy = true;
161
162 std::vector<std::string> custom_opdefs(toco_flags.custom_opdefs().begin(),
163 toco_flags.custom_opdefs().end());
164 auto bundle = std::make_unique<tensorflow::SavedModelBundle>();
165 TF_ASSIGN_OR_RETURN(
166 auto module,
167 ImportSavedModel(
168 model_flags.saved_model_dir(), model_flags.saved_model_version(),
169 tags, absl::MakeSpan(custom_opdefs), exported_names, specs,
170 !toco_flags.enable_tflite_resource_variables(), &context, &bundle));
171
172 if (!model_flags.input_arrays().empty() ||
173 !model_flags.output_arrays().empty()) {
174 TF_RETURN_IF_ERROR(HandleInputOutputArraysWithModule(model_flags, &module));
175 }
176
177 mlir::TFL::PassConfig pass_config(quant_specs);
178 bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
179 pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
180 pass_config.enable_tflite_variables =
181 toco_flags.enable_tflite_resource_variables();
182 pass_config.unfold_batch_matmul = toco_flags.unfold_batchmatmul();
183 pass_config.lower_tensor_list_ops = toco_flags.lower_tensor_list_ops();
184 // Disable the unfolding of the 16x16 TF::BatchMatMulOp to avoid the
185 // conversion to an unsupported 16x16 TFL::FullyConnectedOp.
186 if (toco_flags.inference_type() == toco::IODataType::QUANTIZED_INT16) {
187 pass_config.unfold_batch_matmul = false;
188 }
189 pass_config.unfold_large_splat_constant =
190 toco_flags.unfold_large_splat_constant();
191
192 // TODO(b/153507667): Pass the session object when importing logic is removed.
193 auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
194 model_flags, toco_flags, std::move(module), pass_config, tags, result,
195 bundle ? bundle->GetSession() : nullptr);
196 return status;
197 }
198
199 } // namespace tensorflow
200