1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
17 
18 #include "absl/memory/memory.h"
19 #include "llvm/Support/raw_ostream.h"
20 #include "mlir/IR/Attributes.h"  // from @llvm-project
21 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
23 #include "mlir/IR/Identifier.h"  // from @llvm-project
24 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
25 #include "mlir/IR/Operation.h"  // from @llvm-project
26 #include "mlir/Parser.h"  // from @llvm-project
27 #include "tensorflow/cc/saved_model/bundle_v2.h"
28 #include "tensorflow/cc/saved_model/reader.h"
29 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
30 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
31 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
32 #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
33 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
34 #include "tensorflow/core/framework/graph.pb.h"
35 #include "tensorflow/core/framework/versions.pb.h"
36 #include "tensorflow/core/graph/tensor_id.h"
37 #include "tensorflow/core/grappler/utils/transitive_fanin.h"
38 #include "tensorflow/core/platform/errors.h"
39 #include "tensorflow/core/platform/protobuf.h"
40 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
41 
42 namespace tensorflow {
43 
GraphdefToMlirImport(llvm::StringRef input,absl::string_view debug_info_file,const std::vector<std::string> & input_arrays,const std::vector<std::string> & input_dtypes,const std::vector<llvm::Optional<std::vector<int>>> & input_shapes,const std::vector<std::string> & output_arrays,const std::vector<std::string> & control_output_arrays,bool prune_unused_nodes,bool convert_legacy_fed_inputs,bool graph_as_function,bool upgrade_legacy,bool enable_shape_inference,mlir::MLIRContext * context)44 static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
45     llvm::StringRef input, absl::string_view debug_info_file,
46     const std::vector<std::string>& input_arrays,
47     const std::vector<std::string>& input_dtypes,
48     const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
49     const std::vector<std::string>& output_arrays,
50     const std::vector<std::string>& control_output_arrays,
51     bool prune_unused_nodes, bool convert_legacy_fed_inputs,
52     bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference,
53     mlir::MLIRContext* context) {
54   GraphDef graphdef;
55   TF_RETURN_IF_ERROR(
56       tensorflow::LoadProtoFromBuffer({input.data(), input.size()}, &graphdef));
57 
58   GraphDebugInfo debug_info;
59   if (!debug_info_file.empty()) {
60     TF_RETURN_IF_ERROR(LoadProtoFromFile(debug_info_file, &debug_info));
61   }
62 
63   GraphImportConfig specs;
64   specs.prune_unused_nodes = prune_unused_nodes;
65   specs.convert_legacy_fed_inputs = convert_legacy_fed_inputs;
66   specs.graph_as_function = graph_as_function;
67   specs.upgrade_legacy = upgrade_legacy;
68   specs.enable_shape_inference = enable_shape_inference;
69   TF_RETURN_IF_ERROR(ParseInputArrayInfo(input_arrays, input_dtypes,
70                                          input_shapes, &specs.inputs));
71   TF_RETURN_IF_ERROR(ParseOutputArrayInfo(output_arrays, &specs.outputs));
72   TF_RETURN_IF_ERROR(
73       ParseOutputArrayInfo(control_output_arrays, &specs.control_outputs));
74   // TODO(b/142828368): Pruning should not be needed when TF import
75   // supports importing graphs w/ unregistered ops natively.
76   GraphDef pruned_graph_def;
77   if (specs.prune_unused_nodes) {
78     std::vector<std::string> terminal_nodes;
79     terminal_nodes.reserve(specs.outputs.size() + specs.inputs.size());
80     for (const auto& output : specs.outputs) {
81       terminal_nodes.push_back(std::string(ParseTensorName(output).node()));
82     }
83     for (const auto& control_output : specs.control_outputs) {
84       terminal_nodes.push_back(std::string(control_output));
85     }
86     for (const auto& input : specs.inputs) {
87       terminal_nodes.push_back(input.first);
88     }
89     TF_RETURN_IF_ERROR(tensorflow::grappler::SetTransitiveFaninGraph(
90         graphdef, &pruned_graph_def, terminal_nodes));
91     // TODO(ashwinm): Add a separate utility in grappler utils that abstracts
92     // both SetTransitiveFaninGraph and restoring the missing contents from the
93     // original graph like function def library and version.
94     pruned_graph_def.mutable_library()->Swap(graphdef.mutable_library());
95     pruned_graph_def.mutable_versions()->Swap(graphdef.mutable_versions());
96   }
97   return ConvertGraphdefToMlir(
98       specs.prune_unused_nodes ? pruned_graph_def : graphdef, debug_info, specs,
99       context);
100 }
101 
GraphdefToMlirTranslateFunction(llvm::StringRef input,absl::string_view debug_info_file,const std::vector<std::string> & input_arrays,const std::vector<std::string> & input_dtypes,const std::vector<llvm::Optional<std::vector<int>>> & input_shapes,const std::vector<std::string> & output_arrays,const std::vector<std::string> & control_output_arrays,bool prune_unused_nodes,bool convert_legacy_fed_inputs,bool graph_as_function,bool upgrade_legacy,bool enable_shape_inference,mlir::MLIRContext * context)102 StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
103     llvm::StringRef input, absl::string_view debug_info_file,
104     const std::vector<std::string>& input_arrays,
105     const std::vector<std::string>& input_dtypes,
106     const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
107     const std::vector<std::string>& output_arrays,
108     const std::vector<std::string>& control_output_arrays,
109     bool prune_unused_nodes, bool convert_legacy_fed_inputs,
110     bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference,
111     mlir::MLIRContext* context) {
112   auto module_or = GraphdefToMlirImport(
113       input, debug_info_file, input_arrays, input_dtypes, input_shapes,
114       output_arrays, control_output_arrays, prune_unused_nodes,
115       convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
116       enable_shape_inference, context);
117   if (!module_or.status().ok()) {
118     LOG(ERROR) << "Graph import failed: " << module_or.status();
119   }
120   return module_or;
121 }
122 
GraphdefToMlirTranslateFunction(llvm::StringRef input,absl::string_view debug_info_file,absl::string_view input_arrays,absl::string_view input_dtypes,absl::string_view input_shapes,absl::string_view output_arrays,absl::string_view control_output_arrays,bool prune_unused_nodes,bool convert_legacy_fed_inputs,bool graph_as_function,bool upgrade_legacy,bool enable_shape_inference,mlir::MLIRContext * context)123 StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
124     llvm::StringRef input, absl::string_view debug_info_file,
125     absl::string_view input_arrays, absl::string_view input_dtypes,
126     absl::string_view input_shapes, absl::string_view output_arrays,
127     absl::string_view control_output_arrays, bool prune_unused_nodes,
128     bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
129     bool enable_shape_inference, mlir::MLIRContext* context) {
130   std::vector<std::string> input_array_vector;
131   std::vector<std::string> input_dtype_vector;
132   std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
133   std::vector<std::string> output_array_vector;
134   std::vector<std::string> control_output_array_vector;
135   TF_RETURN_IF_ERROR(ParseNodeNames(input_arrays, input_array_vector));
136   TF_RETURN_IF_ERROR(ParseNodeDataTypes(input_dtypes, input_dtype_vector));
137   TF_RETURN_IF_ERROR(ParseNodeNames(output_arrays, output_array_vector));
138   TF_RETURN_IF_ERROR(ParseNodeShapes(input_shapes, input_shapes_vector));
139   TF_RETURN_IF_ERROR(
140       ParseNodeNames(control_output_arrays, control_output_array_vector));
141   return GraphdefToMlirTranslateFunction(
142       input, debug_info_file, input_array_vector, input_dtype_vector,
143       input_shapes_vector, output_array_vector, control_output_array_vector,
144       prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function,
145       upgrade_legacy, enable_shape_inference, context);
146 }
147 
SavedModelObjectGraphToMlirImport(absl::string_view saved_model_dir,const std::unordered_set<std::string> & tags,absl::Span<std::string> exported_names,mlir::MLIRContext * context)148 StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphToMlirImport(
149     absl::string_view saved_model_dir,
150     const std::unordered_set<std::string>& tags,
151     absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
152   tensorflow::SavedModelV2Bundle bundle;
153   auto load_status = tensorflow::SavedModelV2Bundle::Load(
154       std::string(saved_model_dir.data(), saved_model_dir.length()), &bundle);
155   if (!load_status.ok()) {
156     LOG(ERROR) << "Failed to load saved model '" << saved_model_dir
157                << "': " << load_status;
158     return load_status;
159   }
160 
161   auto module_or = ConvertSavedModelToMlir(&bundle, context, exported_names);
162   if (!module_or.status().ok()) {
163     LOG(ERROR) << "SavedModel import failed: " << module_or.status();
164   }
165   return module_or;
166 }
167 
SavedModelSignatureDefsToMlirImport(absl::string_view saved_model_dir,const std::unordered_set<std::string> & tags,absl::Span<std::string> exported_names,mlir::MLIRContext * context,MLIRImportOptions options,bool lift_variables,std::unique_ptr<tensorflow::SavedModelBundle> * saved_model_bundle)168 StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImport(
169     absl::string_view saved_model_dir,
170     const std::unordered_set<std::string>& tags,
171     absl::Span<std::string> exported_names, mlir::MLIRContext* context,
172     MLIRImportOptions options, bool lift_variables,
173     std::unique_ptr<tensorflow::SavedModelBundle>* saved_model_bundle) {
174   // Create local bundle if no one is provided to use.
175   std::unique_ptr<tensorflow::SavedModelBundle> bundle;
176   if (saved_model_bundle == nullptr) {
177     bundle = std::make_unique<tensorflow::SavedModelBundle>();
178   } else if (*saved_model_bundle == nullptr) {
179     *saved_model_bundle = std::make_unique<tensorflow::SavedModelBundle>();
180   }
181   SavedModelBundle* bundle_ptr =
182       saved_model_bundle ? saved_model_bundle->get() : bundle.get();
183   tensorflow::SessionOptions session_options;
184 
185   // Force saved model states to be restored to CPU.
186   (*session_options.config.mutable_device_count())["GPU"] = 0;
187   auto load_status = tensorflow::LoadSavedModel(
188       session_options, /* run_options = */ {}, std::string(saved_model_dir),
189       tags, bundle_ptr);
190   if (!load_status.ok()) {
191     LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir
192                << "': " << load_status;
193     return load_status;
194   }
195 
196   auto module_or = ConvertSavedModelV1ToMlir(*bundle_ptr, exported_names,
197                                              context, options, lift_variables);
198   if (!module_or.status().ok()) {
199     LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
200   }
201   return module_or;
202 }
203 
SavedModelSignatureDefsToMlirImportLite(absl::string_view saved_model_dir,const std::unordered_set<std::string> & tags,absl::Span<std::string> exported_names,mlir::MLIRContext * context,MLIRImportOptions options)204 StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImportLite(
205     absl::string_view saved_model_dir,
206     const std::unordered_set<std::string>& tags,
207     absl::Span<std::string> exported_names, mlir::MLIRContext* context,
208     MLIRImportOptions options) {
209   MetaGraphDef meta_graph_def;
210   auto status = ReadMetaGraphDefFromSavedModel(std::string(saved_model_dir),
211                                                tags, &meta_graph_def);
212   if (!status.ok()) {
213     LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir
214                << "': " << status;
215     return status;
216   }
217 
218   absl::optional<absl::Span<const std::string>> optional_exported_names;
219   if (!exported_names.empty()) optional_exported_names = exported_names;
220 
221   // TODO(b/186898924): debug info in the savedmodel should not be ignored and
222   // should be passed here.
223   auto module_or =
224       ConvertSavedModelV1ToMlirLite(meta_graph_def, /*debug_info=*/{},
225                                     optional_exported_names, context, options);
226   if (!module_or.status().ok()) {
227     LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
228   }
229   return module_or;
230 }
231 
GraphdefToSplattedMlirTranslateFunction(llvm::StringRef input,absl::string_view debug_info_file,const std::vector<std::string> & input_arrays,const std::vector<std::string> & input_dtypes,const std::vector<llvm::Optional<std::vector<int>>> & input_shapes,const std::vector<std::string> & output_arrays,const std::vector<std::string> & control_output_arrays,bool prune_unused_nodes,bool convert_legacy_fed_inputs,bool graph_as_function,bool upgrade_legacy,bool enable_shape_inference,mlir::MLIRContext * context)232 StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
233     llvm::StringRef input, absl::string_view debug_info_file,
234     const std::vector<std::string>& input_arrays,
235     const std::vector<std::string>& input_dtypes,
236     const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
237     const std::vector<std::string>& output_arrays,
238     const std::vector<std::string>& control_output_arrays,
239     bool prune_unused_nodes, bool convert_legacy_fed_inputs,
240     bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference,
241     mlir::MLIRContext* context) {
242   auto module_or = GraphdefToMlirImport(
243       input, debug_info_file, input_arrays, input_dtypes, input_shapes,
244       output_arrays, control_output_arrays, prune_unused_nodes,
245       convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
246       enable_shape_inference, context);
247   if (!module_or.status().ok()) {
248     LOG(ERROR) << "Graph import failed: " << module_or.status();
249     return module_or.status();
250   }
251   auto& module = module_or.ValueOrDie();
252   std::srand(0);
253   for (auto fn : module->getOps<mlir::FuncOp>()) {
254     for (auto& bb : fn) {
255       for (auto& inst : bb) {
256         auto attr_id = mlir::Identifier::get("value", context);
257         if (auto attr = inst.getAttrOfType<mlir::ElementsAttr>(attr_id)) {
258           mlir::Attribute rand_val;
259           mlir::Type element_type = attr.getType().getElementType();
260           if (element_type.isa<mlir::IntegerType>()) {
261             rand_val = mlir::IntegerAttr::get(element_type, std::rand());
262           } else if (element_type.isF16() || element_type.isF32() ||
263                      element_type.isF64()) {
264             rand_val = mlir::FloatAttr::get(element_type,
265                                             std::rand() * 1.0 / RAND_MAX);
266 
267           } else {
268             inst.emitWarning()
269                 << "Skipping splat conversion for "
270                 << "an unsupported attribute type " << element_type;
271             continue;
272           }
273           auto new_attr =
274               mlir::DenseElementsAttr::get(attr.getType(), rand_val);
275           inst.setAttr(attr_id, new_attr);
276         }
277       }
278     }
279   }
280   return module_or;
281 }
282 
GraphdefToSplattedMlirTranslateFunction(llvm::StringRef input,absl::string_view debug_info_file,absl::string_view input_arrays,absl::string_view input_dtypes,absl::string_view input_shapes,absl::string_view output_arrays,absl::string_view control_output_arrays,bool prune_unused_nodes,bool convert_legacy_fed_inputs,bool graph_as_function,bool upgrade_legacy,bool enable_shape_inference,mlir::MLIRContext * context)283 StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
284     llvm::StringRef input, absl::string_view debug_info_file,
285     absl::string_view input_arrays, absl::string_view input_dtypes,
286     absl::string_view input_shapes, absl::string_view output_arrays,
287     absl::string_view control_output_arrays, bool prune_unused_nodes,
288     bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
289     bool enable_shape_inference, mlir::MLIRContext* context) {
290   std::vector<std::string> input_array_vector;
291   std::vector<std::string> input_dtype_vector;
292   std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
293   std::vector<std::string> output_array_vector;
294   std::vector<std::string> control_output_array_vector;
295   TF_RETURN_IF_ERROR(ParseNodeNames(input_arrays, input_array_vector));
296   TF_RETURN_IF_ERROR(ParseNodeDataTypes(input_dtypes, input_dtype_vector));
297   TF_RETURN_IF_ERROR(ParseNodeNames(output_arrays, output_array_vector));
298   TF_RETURN_IF_ERROR(ParseNodeShapes(input_shapes, input_shapes_vector));
299   TF_RETURN_IF_ERROR(
300       ParseNodeNames(control_output_arrays, control_output_array_vector));
301   return GraphdefToSplattedMlirTranslateFunction(
302       input, debug_info_file, input_array_vector, input_dtype_vector,
303       input_shapes_vector, output_array_vector, control_output_array_vector,
304       prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function,
305       upgrade_legacy, enable_shape_inference, context);
306 }
307 
308 }  // namespace tensorflow
309