/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"

#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/types.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {

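// Checks that the input/output array names given in `model_flags` all exist
// in the module's `tf.entry_function` attribute, returning an InvalidArgument
// error on any mismatch.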
Status HandleInputOutputArraysWithModule(const toco::ModelFlags& model_flags,
                                         mlir::OwningModuleRef* module) {
  mlir::FuncOp entry_function = nullptr;
  for (auto func : module->get().getOps<mlir::FuncOp>()) {
    if (auto tf_attrs =
            func->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function")) {
      // TODO(jaesung): There could be multiple entry functions. Let's handle
      // such cases if there are any needs for that.
      if (entry_function != nullptr) {
        return errors::InvalidArgument(
            "There should be only one tf.entry_function");
      }
      entry_function = func;
    }
  }
  if (entry_function == nullptr) {
    return errors::InvalidArgument("no tf.entry_function found");
  }

  // Get the list of input Op names from the function attribute.
  mlir::DictionaryAttr tf_attrs =
      entry_function->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
  llvm::SmallVector<llvm::StringRef, 4> function_input_names;
  function_input_names.reserve(model_flags.input_arrays().size());
  auto input_attr = tf_attrs.get("inputs");
  if (!input_attr) {
    return errors::InvalidArgument("no inputs attribute found");
  }
  auto input_names = input_attr.cast<mlir::StringAttr>().getValue();
  input_names.split(function_input_names, ",");
  const int function_input_names_size = function_input_names.size();
  if (function_input_names_size != model_flags.input_arrays().size()) {
    return errors::InvalidArgument(
        "input array size mismatch: got ", function_input_names.size(),
        ", expected: ", model_flags.input_arrays().size());
  }
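  // Every input array named in the model flags must appear among the entry
  // function's declared inputs.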
  llvm::StringSet<> function_input_names_set;
  function_input_names_set.insert(function_input_names.begin(),
                                  function_input_names.end());
  for (const auto& input_array : model_flags.input_arrays()) {
    if (function_input_names_set.count(input_array.name()) == 0) {
      return errors::InvalidArgument("input array name (", input_array.name(),
                                     ") does not exist in the given graph");
    }
  }

  // Get the list of output Op names from the function attribute.
  llvm::SmallVector<llvm::StringRef, 4> function_output_names;
  function_output_names.reserve(model_flags.output_arrays().size());
  auto output_attr = tf_attrs.get("outputs");
  if (!output_attr) {
    return errors::InvalidArgument("no outputs attribute found");
  }
  auto output_names = output_attr.cast<mlir::StringAttr>().getValue();
  output_names.split(function_output_names, ",");
  const int function_output_names_size = function_output_names.size();
  if (function_output_names_size != model_flags.output_arrays().size()) {
    return errors::InvalidArgument(
        "output array size mismatch: got ", function_output_names.size(),
        ", expected: ", model_flags.output_arrays().size());
  }
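  // Likewise, every output array named in the model flags must appear among
  // the entry function's declared outputs.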
  llvm::StringSet<> function_output_names_set;
  function_output_names_set.insert(function_output_names.begin(),
                                   function_output_names.end());
  for (const auto& output_array : model_flags.output_arrays()) {
    if (function_output_names_set.count(output_array) == 0) {
      return errors::InvalidArgument("output array name (", output_array,
                                     ") does not exist in the given graph");
    }
  }
  return Status::OK();
}

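// Converts the SavedModel named in `model_flags` into a TFLite FlatBuffer and
// writes the serialized model into `*result`.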
Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
                                           const toco::TocoFlags& toco_flags,
                                           string* result) {
  mlir::MLIRContext context;
  mlir::TFL::QuantizationSpecs quant_specs;

  // Parse input arrays.
  std::vector<string> node_names;
  std::vector<string> node_dtypes;
  std::vector<llvm::Optional<std::vector<int>>> node_shapes;
  std::vector<llvm::Optional<double>> node_mins;
  std::vector<llvm::Optional<double>> node_maxs;

  // Populate quantization specs.
  TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(
      model_flags, toco_flags, &quant_specs, &node_names, &node_dtypes,
      &node_shapes, &node_mins, &node_maxs));

  internal::WarningUnusedFlags(model_flags, toco_flags);

  // Register all custom ops, including user-specified custom ops.
  TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags));

  auto& saved_model_tags = model_flags.saved_model_tags();
  auto& saved_model_exported_names = model_flags.saved_model_exported_names();
  std::unordered_set<std::string> tags(saved_model_tags.begin(),
                                       saved_model_tags.end());
  auto exported_names_in_vector = std::vector<std::string>(
      saved_model_exported_names.begin(), saved_model_exported_names.end());
  absl::Span<std::string> exported_names(exported_names_in_vector);

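  // The MLIR SavedModel conversion path currently handles exactly one
  // exported name.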
  if (exported_names.size() != 1) {
    return errors::Unimplemented("Only support a single exported name.");
  }

  tensorflow::GraphImportConfig specs;
  specs.upgrade_legacy = true;

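  // Import the SavedModel into an MLIR module. Any user-provided custom op
  // definitions are forwarded so the importer can resolve those ops.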
  std::vector<std::string> custom_opdefs(toco_flags.custom_opdefs().begin(),
                                         toco_flags.custom_opdefs().end());
  TF_ASSIGN_OR_RETURN(auto module,
                      ImportSavedModel(model_flags.saved_model_dir(),
                                       model_flags.saved_model_version(), tags,
                                       absl::MakeSpan(custom_opdefs),
                                       exported_names, specs, &context));

  if (!model_flags.input_arrays().empty() ||
      !model_flags.output_arrays().empty()) {
    TF_RETURN_IF_ERROR(HandleInputOutputArraysWithModule(model_flags, &module));
  }

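  // Configure the TF-to-TFLite lowering pass pipeline from the converter
  // flags.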
  mlir::TFL::PassConfig pass_config(quant_specs);
  bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
  pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
  pass_config.lower_tensor_list_ops = true;
  pass_config.enable_tflite_variables =
      toco_flags.enable_tflite_resource_variables();
  // Disable the unfolding of the 16x16 TF::BatchMatMulOp to avoid the
  // conversion to an unsupported 16x16 TFL::FullyConnectedOp.
  if (toco_flags.inference_type() == toco::IODataType::QUANTIZED_INT16) {
    pass_config.unfold_batch_matmul = false;
  }

  // TODO(b/153507667): Pass the session object when importing logic is
  // removed.
  auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
      toco_flags, std::move(module), pass_config, tags, result,
      /*session=*/llvm::None);
  return status;
}

}  // namespace tensorflow