/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"

#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/types.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {

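// Verifies that the input and output arrays in `model_flags` match the
// "inputs"/"outputs" lists recorded in the module's tf.entry_function
// attribute. Returns an error on any mismatch.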
Status HandleInputOutputArraysWithModule(const toco::ModelFlags& model_flags,
                                         mlir::OwningModuleRef* module) {
  mlir::FuncOp entry_function = nullptr;
  for (auto func : module->get().getOps<mlir::FuncOp>()) {
    if (auto tf_attrs =
            func->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function")) {
      // TODO(jaesung): There could be multiple entry functions. Handle such
      // cases if the need arises.
      if (entry_function != nullptr) {
        return errors::InvalidArgument(
            "There should be only one tf.entry_function");
      }
      entry_function = func;
    }
  }
  if (entry_function == nullptr) {
    return errors::InvalidArgument("no tf.entry_function found");
  }

  // Get the list of input Op names from the function attribute.
  mlir::DictionaryAttr tf_attrs =
      entry_function->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
  llvm::SmallVector<llvm::StringRef, 4> function_input_names;
  function_input_names.reserve(model_flags.input_arrays().size());
  auto input_attr = tf_attrs.get("inputs");
  if (!input_attr) {
    return errors::InvalidArgument("no inputs attribute found");
  }
  auto input_names = input_attr.cast<mlir::StringAttr>().getValue();
  input_names.split(function_input_names, ",");
  const int function_input_names_size = function_input_names.size();
  if (function_input_names_size != model_flags.input_arrays().size()) {
    return errors::InvalidArgument(
        "input array size mismatch: got ", function_input_names.size(),
        ", expected: ", model_flags.input_arrays().size());
  }
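  // Verify that every input array requested by the caller exists in the
  // entry function.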
  llvm::StringSet<> function_input_names_set;
  function_input_names_set.insert(function_input_names.begin(),
                                  function_input_names.end());
  for (const auto& input_array : model_flags.input_arrays()) {
    if (function_input_names_set.count(input_array.name()) == 0) {
      return errors::InvalidArgument("input array name (", input_array.name(),
                                     ") does not exist in the given graph");
    }
  }

  // Get the list of output Op names from the function attribute.
  llvm::SmallVector<llvm::StringRef, 4> function_output_names;
  function_output_names.reserve(model_flags.output_arrays().size());
  auto output_attr = tf_attrs.get("outputs");
  if (!output_attr) {
    return errors::InvalidArgument("no outputs attribute found");
  }
  auto output_names = output_attr.cast<mlir::StringAttr>().getValue();
  output_names.split(function_output_names, ",");
  const int function_output_names_size = function_output_names.size();
  if (function_output_names_size != model_flags.output_arrays().size()) {
    return errors::InvalidArgument(
        "output array size mismatch: got ", function_output_names.size(),
        ", expected: ", model_flags.output_arrays().size());
  }
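  // Verify that every output array requested by the caller exists in the
  // entry function.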
  llvm::StringSet<> function_output_names_set;
  function_output_names_set.insert(function_output_names.begin(),
                                   function_output_names.end());
  for (const auto& output_array : model_flags.output_arrays()) {
    if (function_output_names_set.count(output_array) == 0) {
      return errors::InvalidArgument("output array name (", output_array,
                                     ") does not exist in the given graph");
    }
  }
  return Status::OK();
}
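// Converts the SavedModel specified by `model_flags` into a TFLite FlatBuffer
// (serialized into `*result`) by importing it to an MLIR module and running
// the TF-to-TFLite lowering passes.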
Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
                                           const toco::TocoFlags& toco_flags,
                                           string* result) {
  mlir::MLIRContext context;
  mlir::TFL::QuantizationSpecs quant_specs;

  // Parse input arrays.
  std::vector<string> node_names;
  std::vector<string> node_dtypes;
  std::vector<llvm::Optional<std::vector<int>>> node_shapes;
  std::vector<llvm::Optional<double>> node_mins;
  std::vector<llvm::Optional<double>> node_maxs;

  // Populate quantization specs.
  TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(
      model_flags, toco_flags, &quant_specs, &node_names, &node_dtypes,
      &node_shapes, &node_mins, &node_maxs));

  internal::WarningUnusedFlags(model_flags, toco_flags);

  // Register all custom ops, including user-specified custom ops.
  TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags));

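  // Collect the SavedModel tags and exported names requested by the caller.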
  auto& saved_model_tags = model_flags.saved_model_tags();
  auto& saved_model_exported_names = model_flags.saved_model_exported_names();
  std::unordered_set<std::string> tags(saved_model_tags.begin(),
                                       saved_model_tags.end());
  auto exported_names_in_vector = std::vector<std::string>(
      saved_model_exported_names.begin(), saved_model_exported_names.end());
  absl::Span<std::string> exported_names(exported_names_in_vector);

  if (exported_names.size() != 1) {
    return errors::Unimplemented("Only support a single exported name.");
  }

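  // Configure the graph importer. `upgrade_legacy` enables upgrading legacy
  // graph features during import (e.g. functionalizing control flow).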
  tensorflow::GraphImportConfig specs;
  specs.upgrade_legacy = true;

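  // Import the SavedModel into an MLIR module, making any user-provided
  // custom OpDefs visible to the importer.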
  std::vector<std::string> custom_opdefs(toco_flags.custom_opdefs().begin(),
                                         toco_flags.custom_opdefs().end());
  TF_ASSIGN_OR_RETURN(auto module,
                      ImportSavedModel(model_flags.saved_model_dir(),
                                       model_flags.saved_model_version(), tags,
                                       absl::MakeSpan(custom_opdefs),
                                       exported_names, specs, &context));

  if (!model_flags.input_arrays().empty() ||
      !model_flags.output_arrays().empty()) {
    TF_RETURN_IF_ERROR(HandleInputOutputArraysWithModule(model_flags, &module));
  }

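  // Configure the TF-to-TFLite lowering pass pipeline from the conversion
  // flags.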
  mlir::TFL::PassConfig pass_config(quant_specs);
  bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
  pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
  pass_config.lower_tensor_list_ops = true;
  pass_config.enable_tflite_variables =
      toco_flags.enable_tflite_resource_variables();
  // Disable the unfolding of the 16x16 TF::BatchMatMulOp to avoid the
  // conversion to an unsupported 16x16 TFL::FullyConnectedOp.
  if (toco_flags.inference_type() == toco::IODataType::QUANTIZED_INT16) {
    pass_config.unfold_batch_matmul = false;
  }

  // TODO(b/153507667): Pass the session object when importing logic is removed.
  auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
      toco_flags, std::move(module), pass_config, tags, result,
      /*session=*/llvm::None);
  return status;
}

}  // namespace tensorflow