/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"

#include "absl/memory/memory.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Identifier.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/Parser.h"  // from @llvm-project
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"

namespace tensorflow {

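// Converts a GraphDef protobuf held in `input` (binary or text form) to an
// MLIR module. When `prune_unused_nodes` is set, the graph is first pruned to
// the transitive fanin of the requested input, output, and control-output
// nodes before conversion.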
static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
    llvm::StringRef input, absl::string_view debug_info_file,
    const std::vector<std::string>& input_arrays,
    const std::vector<std::string>& input_dtypes,
    const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
    const std::vector<std::string>& output_arrays,
    const std::vector<std::string>& control_output_arrays,
    bool prune_unused_nodes, bool convert_legacy_fed_inputs,
    bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference,
    mlir::MLIRContext* context) {
  GraphDef graphdef;
  TF_RETURN_IF_ERROR(
      tensorflow::LoadProtoFromBuffer({input.data(), input.size()}, &graphdef));

  GraphDebugInfo debug_info;
  if (!debug_info_file.empty()) {
    TF_RETURN_IF_ERROR(LoadProtoFromFile(debug_info_file, &debug_info));
  }

  GraphImportConfig specs;
  specs.prune_unused_nodes = prune_unused_nodes;
  specs.convert_legacy_fed_inputs = convert_legacy_fed_inputs;
  specs.graph_as_function = graph_as_function;
  specs.upgrade_legacy = upgrade_legacy;
  specs.enable_shape_inference = enable_shape_inference;
  TF_RETURN_IF_ERROR(ParseInputArrayInfo(input_arrays, input_dtypes,
                                         input_shapes, &specs.inputs));
  TF_RETURN_IF_ERROR(ParseOutputArrayInfo(output_arrays, &specs.outputs));
  TF_RETURN_IF_ERROR(
      ParseOutputArrayInfo(control_output_arrays, &specs.control_outputs));
  // TODO(b/142828368): Pruning should not be needed when TF import
  // supports importing graphs w/ unregistered ops natively.
  GraphDef pruned_graph_def;
  if (specs.prune_unused_nodes) {
    std::vector<std::string> terminal_nodes;
    terminal_nodes.reserve(specs.outputs.size() + specs.inputs.size());
    for (const auto& output : specs.outputs) {
      terminal_nodes.push_back(std::string(ParseTensorName(output).node()));
    }
    for (const auto& control_output : specs.control_outputs) {
      terminal_nodes.push_back(std::string(control_output));
    }
    for (const auto& input : specs.inputs) {
      terminal_nodes.push_back(input.first);
    }
    TF_RETURN_IF_ERROR(tensorflow::grappler::SetTransitiveFaninGraph(
        graphdef, &pruned_graph_def, terminal_nodes));
    // TODO(ashwinm): Add a separate utility in grappler utils that abstracts
    // both SetTransitiveFaninGraph and restoring the missing contents from the
    // original graph like function def library and version.
    pruned_graph_def.mutable_library()->Swap(graphdef.mutable_library());
    pruned_graph_def.mutable_versions()->Swap(graphdef.mutable_versions());
  }
  return ConvertGraphdefToMlir(
      specs.prune_unused_nodes ? pruned_graph_def : graphdef, debug_info, specs,
      context);
}

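// Translates a GraphDef to an MLIR module, taking the array names, dtypes,
// and shapes as pre-parsed vectors. Import errors are logged and then
// propagated to the caller.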
StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
    llvm::StringRef input, absl::string_view debug_info_file,
    const std::vector<std::string>& input_arrays,
    const std::vector<std::string>& input_dtypes,
    const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
    const std::vector<std::string>& output_arrays,
    const std::vector<std::string>& control_output_arrays,
    bool prune_unused_nodes, bool convert_legacy_fed_inputs,
    bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference,
    mlir::MLIRContext* context) {
  auto module_or = GraphdefToMlirImport(
      input, debug_info_file, input_arrays, input_dtypes, input_shapes,
      output_arrays, control_output_arrays, prune_unused_nodes,
      convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
      enable_shape_inference, context);
  if (!module_or.status().ok()) {
    LOG(ERROR) << "Graph import failed: " << module_or.status();
  }
  return module_or;
}

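// Overload that accepts the node names, dtypes, and shapes as comma-separated
// strings, parses them, and delegates to the vector-based overload above.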
StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
    llvm::StringRef input, absl::string_view debug_info_file,
    absl::string_view input_arrays, absl::string_view input_dtypes,
    absl::string_view input_shapes, absl::string_view output_arrays,
    absl::string_view control_output_arrays, bool prune_unused_nodes,
    bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
    bool enable_shape_inference, mlir::MLIRContext* context) {
  std::vector<std::string> input_array_vector;
  std::vector<std::string> input_dtype_vector;
  std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
  std::vector<std::string> output_array_vector;
  std::vector<std::string> control_output_array_vector;
  TF_RETURN_IF_ERROR(ParseNodeNames(input_arrays, input_array_vector));
  TF_RETURN_IF_ERROR(ParseNodeDataTypes(input_dtypes, input_dtype_vector));
  TF_RETURN_IF_ERROR(ParseNodeNames(output_arrays, output_array_vector));
  TF_RETURN_IF_ERROR(ParseNodeShapes(input_shapes, input_shapes_vector));
  TF_RETURN_IF_ERROR(
      ParseNodeNames(control_output_arrays, control_output_array_vector));
  return GraphdefToMlirTranslateFunction(
      input, debug_info_file, input_array_vector, input_dtype_vector,
      input_shapes_vector, output_array_vector, control_output_array_vector,
      prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function,
      upgrade_legacy, enable_shape_inference, context);
}

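// Loads a TF2 SavedModel (object graph) from `saved_model_dir` and converts
// it to an MLIR module, restricting the conversion to `exported_names` when
// non-empty.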
StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphToMlirImport(
    absl::string_view saved_model_dir,
    const std::unordered_set<std::string>& tags,
    absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
  tensorflow::SavedModelV2Bundle bundle;
  auto load_status = tensorflow::SavedModelV2Bundle::Load(
      std::string(saved_model_dir.data(), saved_model_dir.length()), &bundle);
  if (!load_status.ok()) {
    LOG(ERROR) << "Failed to load saved model '" << saved_model_dir
               << "': " << load_status;
    return load_status;
  }

  auto module_or = ConvertSavedModelToMlir(&bundle, context, exported_names);
  if (!module_or.status().ok()) {
    LOG(ERROR) << "SavedModel import failed: " << module_or.status();
  }
  return module_or;
}

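// Loads a TF1 SavedModel (SignatureDefs) through a session and converts it to
// an MLIR module. If `saved_model_bundle` is non-null, the loaded bundle is
// returned to the caller through it; otherwise a local bundle is used.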
StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImport(
    absl::string_view saved_model_dir,
    const std::unordered_set<std::string>& tags,
    absl::Span<std::string> exported_names, mlir::MLIRContext* context,
    MLIRImportOptions options, bool lift_variables,
    std::unique_ptr<tensorflow::SavedModelBundle>* saved_model_bundle) {
  // Create a local bundle if the caller did not provide one to fill in.
  std::unique_ptr<tensorflow::SavedModelBundle> bundle;
  if (saved_model_bundle == nullptr) {
    bundle = std::make_unique<tensorflow::SavedModelBundle>();
  } else if (*saved_model_bundle == nullptr) {
    *saved_model_bundle = std::make_unique<tensorflow::SavedModelBundle>();
  }
  SavedModelBundle* bundle_ptr =
      saved_model_bundle ? saved_model_bundle->get() : bundle.get();
  tensorflow::SessionOptions session_options;

  // Force the saved model's state to be restored on CPU by hiding all GPUs
  // from the session.
  (*session_options.config.mutable_device_count())["GPU"] = 0;
  auto load_status = tensorflow::LoadSavedModel(
      session_options, /*run_options=*/{}, std::string(saved_model_dir),
      tags, bundle_ptr);
  if (!load_status.ok()) {
    LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir
               << "': " << load_status;
    return load_status;
  }

  auto module_or = ConvertSavedModelV1ToMlir(*bundle_ptr, exported_names,
                                             context, options, lift_variables);
  if (!module_or.status().ok()) {
    LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
  }
  return module_or;
}

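// Session-free variant of the above: reads only the MetaGraphDef from the
// SavedModel and converts it to an MLIR module without creating a session or
// restoring variables.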
StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImportLite(
    absl::string_view saved_model_dir,
    const std::unordered_set<std::string>& tags,
    absl::Span<std::string> exported_names, mlir::MLIRContext* context,
    MLIRImportOptions options) {
  MetaGraphDef meta_graph_def;
  auto status = ReadMetaGraphDefFromSavedModel(std::string(saved_model_dir),
                                               tags, &meta_graph_def);
  if (!status.ok()) {
    LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir
               << "': " << status;
    return status;
  }

  absl::optional<absl::Span<const std::string>> optional_exported_names;
  if (!exported_names.empty()) optional_exported_names = exported_names;

  // TODO(b/186898924): The debug info stored in the SavedModel should not be
  // ignored; it should be passed here instead of an empty GraphDebugInfo.
  auto module_or =
      ConvertSavedModelV1ToMlirLite(meta_graph_def, /*debug_info=*/{},
                                    optional_exported_names, context, options);
  if (!module_or.status().ok()) {
    LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
  }
  return module_or;
}

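// Like GraphdefToMlirTranslateFunction, but afterwards rewrites every dense
// `value` attribute into a splat of a single pseudo-random scalar, preserving
// the graph structure while discarding the original constant data.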
StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
    llvm::StringRef input, absl::string_view debug_info_file,
    const std::vector<std::string>& input_arrays,
    const std::vector<std::string>& input_dtypes,
    const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
    const std::vector<std::string>& output_arrays,
    const std::vector<std::string>& control_output_arrays,
    bool prune_unused_nodes, bool convert_legacy_fed_inputs,
    bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference,
    mlir::MLIRContext* context) {
  auto module_or = GraphdefToMlirImport(
      input, debug_info_file, input_arrays, input_dtypes, input_shapes,
      output_arrays, control_output_arrays, prune_unused_nodes,
      convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
      enable_shape_inference, context);
  if (!module_or.status().ok()) {
    LOG(ERROR) << "Graph import failed: " << module_or.status();
    return module_or.status();
  }
  auto& module = module_or.ValueOrDie();
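  // Seed the RNG deterministically, then splat each dense `value` attribute
  // with one random scalar of the attribute's element type; element types
  // other than integers and f16/f32/f64 floats are skipped with a warning.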
  std::srand(0);
  for (auto fn : module->getOps<mlir::FuncOp>()) {
    for (auto& bb : fn) {
      for (auto& inst : bb) {
        auto attr_id = mlir::Identifier::get("value", context);
        if (auto attr = inst.getAttrOfType<mlir::ElementsAttr>(attr_id)) {
          mlir::Attribute rand_val;
          mlir::Type element_type = attr.getType().getElementType();
          if (element_type.isa<mlir::IntegerType>()) {
            rand_val = mlir::IntegerAttr::get(element_type, std::rand());
          } else if (element_type.isF16() || element_type.isF32() ||
                     element_type.isF64()) {
            rand_val = mlir::FloatAttr::get(element_type,
                                            std::rand() * 1.0 / RAND_MAX);
          } else {
            inst.emitWarning()
                << "Skipping splat conversion for "
                << "an unsupported attribute type " << element_type;
            continue;
          }
          auto new_attr =
              mlir::DenseElementsAttr::get(attr.getType(), rand_val);
          inst.setAttr(attr_id, new_attr);
        }
      }
    }
  }
  return module_or;
}

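// String-parsing overload of the splatted translation, mirroring the
// GraphdefToMlirTranslateFunction overload above.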
StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
    llvm::StringRef input, absl::string_view debug_info_file,
    absl::string_view input_arrays, absl::string_view input_dtypes,
    absl::string_view input_shapes, absl::string_view output_arrays,
    absl::string_view control_output_arrays, bool prune_unused_nodes,
    bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
    bool enable_shape_inference, mlir::MLIRContext* context) {
  std::vector<std::string> input_array_vector;
  std::vector<std::string> input_dtype_vector;
  std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
  std::vector<std::string> output_array_vector;
  std::vector<std::string> control_output_array_vector;
  TF_RETURN_IF_ERROR(ParseNodeNames(input_arrays, input_array_vector));
  TF_RETURN_IF_ERROR(ParseNodeDataTypes(input_dtypes, input_dtype_vector));
  TF_RETURN_IF_ERROR(ParseNodeNames(output_arrays, output_array_vector));
  TF_RETURN_IF_ERROR(ParseNodeShapes(input_shapes, input_shapes_vector));
  TF_RETURN_IF_ERROR(
      ParseNodeNames(control_output_arrays, control_output_array_vector));
  return GraphdefToSplattedMlirTranslateFunction(
      input, debug_info_file, input_array_vector, input_dtype_vector,
      input_shapes_vector, output_array_vector, control_output_array_vector,
      prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function,
      upgrade_legacy, enable_shape_inference, context);
}

}  // namespace tensorflow