/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"

#include <atomic>
#include <functional>
#include <iterator>
#include <memory>
#include <queue>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/escaping.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Identifier.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Verifier.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader_util.h"
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/crash_analysis.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/stream_executor/lib/statusor.h"

static inline absl::string_view StringRefToView(llvm::StringRef ref) {
  return {ref.data(), ref.size()};
}

namespace tensorflow {

const char kImportModelDefaultGraphFuncName[] = "main";

using mlir::NamedAttrList;
using mlir::TensorType;
using mlir::tf_saved_model::AssetOp;
using mlir::tf_saved_model::GlobalTensorOp;
using mlir::tf_saved_model::SessionInitializerOp;
using stream_executor::port::StatusOr;

namespace {

bool IsOutputShapesAttribute(const AttrValue& attr_value,
                             llvm::StringRef attr_name) {
  return attr_name.compare("_output_shapes") == 0 &&
         attr_value.value_case() == AttrValue::kList;
}

bool IsResourceOutputShapesAttribute(const AttrValue& attr_value,
                                     llvm::StringRef attr_name) {
  if (attr_name == "_handle_dtypes" || attr_name == "_handle_shapes")
    return attr_value.value_case() == AttrValue::kList;
  return false;
}

void LoadImporterDialects(mlir::MLIRContext& context) {
  // Load dialects involved in the conversion
  mlir::DialectRegistry registry;
  mlir::RegisterAllTensorFlowDialects(registry);
  context.appendDialectRegistry(registry);
  for (llvm::StringRef name : registry.getDialectNames())
    context.getOrLoadDialect(name);
}

// This class is used to generate new MLIR function name strings that are both
// unique in the TF function library `flib_` and unique among the name strings
// generated by the class object during its lifetime.
//
// In theory, this class is not necessary because we should simply take
// the TF function name and use it as the MLIR function name. However, for
// some unknown reason (callout for investigation in b/142268695), keeping the
// function names unchanged in an MLIR roundtrip causes test failures.
// TODO(b/142268695): Re-evaluate whether we need this class vs. directly
// using the TF function name as the MLIR function name after b/142268695 is
// root caused.
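//
// For example, if the library already defines a function named "foo", a
// request for "foo" is disambiguated with a numeric suffix (e.g. "foo0");
// the exact suffixing scheme is inherited from OpOrArgNameMapper.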
class NameUniquifier : public OpOrArgNameMapper {
 public:
  explicit NameUniquifier(const FunctionLibraryDefinition& flib)
      : flib_(flib) {}

 private:
  bool IsUnique(llvm::StringRef name) override {
    return !flib_.Contains(std::string(name));
  }

  std::string GetName(OpOrVal op_or_val) override {
    DCHECK(false) << "Unimplemented";
    return "";
  }

  const FunctionLibraryDefinition& flib_;
};

// Stateful helper class to import a TensorFlow model into an MLIR Module.
//
// This is the base class that contains common utilities shared between the
// GraphDef importer and SavedModel importer.
//
// A subclass is expected to call `PrepareConvert` first to perform necessary
// preparation over the graph and also certain internal bookkeeping data.
// Afterwards the other protected methods can be called.
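//
// A sketch of typical usage, mirroring what ConvertDeferredFunctions() below
// does (names here are illustrative):
//
//   ImporterBase importer(flib, debug_info, specs, module,
//                         &tf_name_to_mlir_name, &name_uniquifier);
//   TF_RETURN_IF_ERROR(importer.PrepareConvert(graph));
//   TF_RETURN_IF_ERROR(importer.Convert(func_name, func_type, arg_nodes,
//                                       ret_nodes, control_ret_nodes, attrs));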
class ImporterBase {
 protected:
  explicit ImporterBase(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier,
      llvm::StringRef function_name_for_debug_info = "")
      : builder_(module.getContext()),
        module_(module),
        context_(module.getContext()),
        tf_name_to_mlir_name_(tf_name_to_mlir_name),
        graph_flib_(flib),
        specs_(specs),
        debug_info_(debug_info),
        function_name_for_debug_info_(function_name_for_debug_info),
        function_name_uniquifier_(function_name_uniquifier),
        error_handler_(module.getContext()) {
    // Log import config.
    if (VLOG_IS_ON(1)) {
      LOG(INFO) << "Importing with: " << specs.str();
      for (auto& it : *tf_name_to_mlir_name) {
        LOG(INFO) << "\t" << it.first << " -> " << it.second;
      }
    }
  }

  // Returns the inferred function signature of the given function body. Input
  // types are unranked tensors of the respective data type in the function,
  // and result types are inferred by the shape_refiner_. Result types need
  // not be unranked tensors and could be ranked tensors in cases where the
  // result type depends on an op with a static output shape, like tf.Const.
  StatusOr<mlir::FunctionType> InferLibFunctionType(const FunctionBody& fbody);

  // Extracts arg and ret nodes from FunctionBody.
  void GetArgsAndRetsFromFunctionBody(
      const FunctionBody& fbody,
      absl::InlinedVector<OutputTensor, 4>* arg_nodes,
      absl::InlinedVector<OutputTensor, 4>* ret_nodes,
      absl::InlinedVector<Node*, 4>* control_ret_nodes);

  // Prepares converting the graph to an MLIR module. This step removes the
  // backedges of the graph, orders the nodes and infers the shapes.
  // PrepareConvert needs to ensure that the original `graph` is cloned prior
  // to execution. The cloning procedure relies on the roundtrip through the
  // GraphDef. Graph-to-GraphDef conversion is heavy; if `graph_def` was
  // obtained previously, provide it to PrepareConvert for reuse.
  Status PrepareConvert(const Graph& graph,
                        std::unique_ptr<GraphDef> graph_def = nullptr);

  // Converts the prepared graph to a Function and adds it to the module. A
  // set of nodes from the graph is given to be converted to the arguments and
  // returns of the function.
  Status Convert(llvm::StringRef func_name, mlir::FunctionType func_type,
                 const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
                 const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
                 const absl::InlinedVector<Node*, 4>& control_ret_nodes,
                 llvm::ArrayRef<mlir::NamedAttribute> attrs);

  // Finds the function definition for the given function name from the
  // graph and converts it to a function of the module. This method is called
  // on demand because the graph flib_def does not provide an iterator
  // interface.
  Status ConvertLibFunction(llvm::StringRef func_name);

  // Returns the list of nodes in the graph. Nodes are presented in the reverse
  // order of a post-order depth-first visit starting from the graph's source
  // nodes.
  llvm::ArrayRef<Node*> GetOrderedNodes() const { return ordered_nodes_; }

  // Returns the inferred input type at index `idx` of the `node` in the
  // context.
  StatusOr<mlir::Type> InferInputType(const Node& node, int idx,
                                      mlir::Builder builder);

  // Returns the inferred output type at index `idx` of the `node` in the
  // context.
  StatusOr<mlir::Type> InferOutputType(const Node& node, int idx,
                                       mlir::Builder builder);

  // Converts deferred TF functions to the MLIR representation.
  // Conversion is deferred for efficiency reasons, e.g., to limit the depth
  // of recursion and reduce stack size pressure.
  Status ConvertDeferredFunctions();

 private:
  // Most types with subtypes have only one subtype.
  using ElementSubtypes = llvm::SmallVector<TensorType, 1>;

  // Metadata used for deferred function conversion.
  struct DeferredConversionMetaData {
    DeferredConversionMetaData(
        const std::string& function_name,
        const std::vector<mlir::NamedAttribute>& attributes)
        : function_name(function_name), attributes(attributes) {}

    std::string function_name;
    std::vector<mlir::NamedAttribute> attributes;
  };

  // Adds all the nodes in `ordered_nodes_` to the shape refiner
  // `shape_refiner_`. All data type and shape information is then maintained
  // by the shape_refiner_.
  // TODO(jpienaar): Remove once shape inference on import is removed.
  Status AddNodesToShapeRefiner(
      std::unordered_map<string, Node*>* node_name_map);

  // Prunes nodes that do not feed into fetch nodes.
  Status PruneUnreachableNodes(
      std::unordered_map<string, Node*>* node_name_map);

  // Converts feeds to Placeholder nodes.
  Status ConvertFeedsToPlaceholders(
      std::unordered_map<string, Node*>* node_name_map);

  // Converts the inferred shape referred to by 'handle' in 'context', with
  // the given data type, and returns an MLIR tensor type.
  StatusOr<TensorType> ConvertDataTypeAndShape(
      DataType dtype, const shape_inference::ShapeHandle& handle,
      const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
      shape_inference::InferenceContext* context, mlir::Builder builder);

  // Converts the inferred shape referred to by 'handle' in 'context', with
  // the given element type, and returns an MLIR tensor type.
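  // For example, an inferred rank-2 shape {?, 4} with element type f32 maps
  // to tensor<?x4xf32>, and an unknown rank maps to tensor<*xf32>.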
  StatusOr<TensorType> ConvertElementTypeAndShape(
      mlir::Type element_type, const shape_inference::ShapeHandle& handle,
      shape_inference::InferenceContext* context, mlir::Builder builder);

  // Converts the inferred subtypes for an element type to corresponding MLIR
  // types in 'context'.
  StatusOr<ElementSubtypes> ConvertSubtypes(
      const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
      shape_inference::InferenceContext* context, mlir::Builder builder);

  // Converts the tensor proto into an MLIR elements attribute.
  StatusOr<mlir::ElementsAttr> ConvertTensorProto(const TensorProto& value) {
    return ::tensorflow::ConvertTensorProto(value, &builder_);
  }

  // Converts a function name in the GraphDef to an MLIR symbol reference
  // attribute.
  StatusOr<mlir::FlatSymbolRefAttr> ConvertFunctionCallName(
      const std::string& func_name);

  // Converts the given non-function-call AttrValue to an MLIR Attribute.
  StatusOr<mlir::Attribute> ConvertAttributeValue(const AttrValue& value);

  // Converts the given function-call AttrValue to MLIR Attributes and pushes
  // them to the given attributes list. For example, if there is a kFunc
  // AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it
  // to a list of MLIR Attributes: [{base_name : foo}, {base_name.k1 : bar},
  // {base_name.k2 : rfc}].
  Status ConvertFunctionCallAttribute(const std::string& base_name,
                                      const AttrValue& value,
                                      NamedAttrList* attributes);

  // Helper to create either a tf_executor operation or a TF operation wrapped
  // in an island.
  mlir::Operation* CreateOperation(
      const Node& node, llvm::StringRef node_type_name,
      const mlir::OperationState& result,
      const llvm::SmallVectorImpl<mlir::Value>& control_operands);

  // Converts one NodeDef from the input GraphDef into an Operation and
  // inserts it into the MLIR module using builder_.
  Status ConvertNode(const Node& node);

  // If the input graph represents a while-loop, the edges pointing from a
  // "NextIteration" node to a "Merge" node add cyclic dependencies and make
  // topological sorting impossible. We need to remove these edges from the
  // input graph to infer shapes and construct a Function. For each
  // "NextIteration" node, two operations, "NextIteration.source" and
  // "NextIteration.sink", are added to the MLIR module.
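  //
  // Conceptually, the cyclic edge
  //   Merge <- ... <- NextIteration
  // is imported as an acyclic pair (an illustrative sketch, not exact
  // tf_executor syntax; see AddBackedge() below for the actual wiring):
  //   %src = tf_executor.NextIteration.Source        // feeds the Merge op
  //   tf_executor.NextIteration.Sink (%loop_value)   // closes the loop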
  using BackEdge = BackEdgeHelper::BackEdge;

  // Removes backedges from the input graph. The removed edges are added back
  // to the OpBuilder after the remaining graph is converted to the Function.
  Status RemoveBackedges();

  // Restores backedges removed during shape inference to the final Function.
  Status AddBackedges();

  // Restores a single backedge in the Function by adding a replicated
  // operation before the dst operation.
  Status AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
                     int dst_input);

  // Adds the input arguments and return operation to the function. The
  // arguments are added as basic block arguments. The argument types and the
  // ids of the nodes from the input graph also need to be specified.
  Status ConvertFunctionArgAndRets(
      mlir::FuncOp func, mlir::tf_executor::GraphOp graph_op,
      llvm::ArrayRef<mlir::Type> arg_types,
      const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
      const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
      const absl::InlinedVector<Node*, 4>& control_ret_nodes);

  // Gets the location information of the given node. It uses the
  // "original_node_name" in the NodeDef to get the corresponding file location
  // (FileLineColLoc) from the input DebugInfo and returns a CallSiteLoc. If
  // there are multiple "original_node_names", a FusedLoc is returned. If the
  // node name cannot be found in the input DebugInfo, a NameLoc is used as
  // the location.
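  //
  // For example, a node "foo" recorded at file.py:10:3 produces, roughly,
  //   loc(callsite("foo" at "file.py":10:3))
  // while a node missing from the DebugInfo produces loc("foo").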
  mlir::Location GetLocation(const Node& node);

  // Appends the location string for the node to the error message and returns
  // the combined error status.
  Status EmitErrorWithLocationStr(const Node& node, const Status& error_status);

  // Inserts a placeholder node in the graph to replace a feed output tensor,
  // and returns the new placeholder node and a boolean indicating if the
  // original input node was removed from the graph. Uses of the feed output
  // tensor are replaced with this placeholder node. If the feed output tensor
  // is of a single output node, the control dependencies are forwarded to
  // the placeholder node, and the original node will be removed.
  // Note: This modifies the graph, and so any list of ordered nodes needs to
  // be reconstructed.
  StatusOr<std::pair<Node*, bool>> CreatePlaceholderNodeForFeed(
      const TensorShapeProto& shape, DataType dtype, Node* node, int index,
      const std::unordered_map<string, Node*>& node_name_map);

  // Gets the input and output nodes corresponding to the specified input and
  // output nodes in specs_. If there are no input or output nodes specified,
  // nodes will be empty.
  Status GetInputOutputNodes(
      const std::unordered_map<string, Node*>& node_name_map,
      std::unordered_set<const Node*>* nodes);

  llvm::StringSet<>& GetUnmodelledOpTypes() {
    // All the TF ops encountered that aren't modelled in the dialect.
    static auto* unmodelled_op_types = new llvm::StringSet<>();
    return *unmodelled_op_types;
  }

  // The input graph with backedges removed. The removed backedges are stored
  // in the back_edge_helper.
  BackEdgeHelper back_edge_helper_;
  // A map between node and output index, for each backedge.
  absl::flat_hash_map<const Node*, int> back_edge_node_output_;
  absl::flat_hash_map<const Node*, BackEdge> back_edge_dst_inputs_;
  // A map between the sink and source operation of each NextIteration.
  absl::flat_hash_map<mlir::Operation*, mlir::Operation*>
      next_iteration_sink_source_;

  // All nodes and version information about the (copied) imported graph.
  std::unique_ptr<Graph> graph_;
  std::vector<Node*> ordered_nodes_;

  // Maps from a Node ID to an MLIR Operation.
  using NodeValueMap = absl::flat_hash_map<int, mlir::Operation*>;

  mlir::OpBuilder builder_;
  mlir::ModuleOp module_;
  mlir::MLIRContext* context_;
  std::unordered_map<std::string, std::string>* tf_name_to_mlir_name_;
  const FunctionLibraryDefinition& graph_flib_;
  const GraphImportConfig& specs_;
  const GraphDebugInfo& debug_info_;
  llvm::StringRef function_name_for_debug_info_;
  NodeValueMap node_values_;
  // TODO(jpienaar): Remove once shape inference on import is removed.
  // The shape_refiner_ will be nullptr if shape inference on import is
  // not enabled.
  std::unique_ptr<ShapeRefiner> shape_refiner_ = nullptr;
  NameUniquifier* function_name_uniquifier_;
  mlir::StatusScopedDiagnosticHandler error_handler_;

 protected:
  // Maps a feed, as a TensorId, to the name of its new Placeholder node.
  absl::flat_hash_map<TensorId, absl::string_view> remapped_feeds_;
  // Keeps track of functions requiring deferred conversion.
  std::queue<DeferredConversionMetaData> deferred_functions_;
};

// Returns true if the node with the given name has a non-primary output that
// is used by some other node as an input. Returns false if no outputs are in
// use or only the first output is in use.
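// For instance, an input reference "node:1" (or any index other than 0) in
// some NodeDef marks a non-primary output of "node" as in use, whereas
// "node" and "node:0" both refer to the primary output.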
bool HasNonPrimaryOutputInUse(const GraphDef& graph_def,
                              const std::string& node) {
  for (const auto& node_def : graph_def.node()) {
    for (const auto& input : node_def.input()) {
      if (absl::StartsWith(input, node + ":") && input != node + ":0") {
        return true;
      }
    }
  }
  return false;
}

// Updates the given LegacyFedInput node with a Placeholder node if it is one
// of the inputs. Returns an error if a non-primary output of the
// LegacyFedInput node is in use and therefore cannot be replaced by the
// Placeholder node, which only has a single output.
Status UpdateLegacyFedInputNode(const GraphDef& graph_def,
                                const GraphImportConfig::InputArrays& inputs,
                                NodeDef* node) {
  const std::string& node_name = node->name();
  auto it = inputs.find(node_name);

  // Node is not an input.
  if (it == inputs.end()) return Status::OK();

  if (HasNonPrimaryOutputInUse(graph_def, node_name)) {
    return errors::InvalidArgument(
        "LegacyFedInput node ", node->name(),
        " has non primary output in use and can not be replaced with "
        "Placeholder node");
  }

  DataType dtype = it->second.imported_dtype;
  // Uses the existing output type if it isn't specified by the user.
  if (dtype == DT_INVALID) {
    dtype = node->attr().at("output_types").list().type(0);
  }
  // Update op name, drop inputs and set attributes required by the Placeholder
  // op.
  *node->mutable_op() = "Placeholder";
  node->clear_attr();
  node->clear_input();
  AddNodeAttr("dtype", dtype, node);
  AddNodeAttr("shape", it->second.shape, node);
  return Status::OK();
}

// Preprocesses GraphDef before it can be converted to Graph by:
// - Adding the default attributes to each node def if they are missing from
//   the GraphDef.
// - Replacing LegacyFedInput nodes with Placeholder nodes if the
//   convert_legacy_fed_inputs option is enabled.
Status PreprocessGraphDef(const GraphImportConfig* specs, GraphDef* graph_def) {
  for (auto& node_def : *graph_def->mutable_node()) {
    // TODO(hinsu): Completely deprecate support for LegacyFedInput ops. One
    // solution could be to have a tool that lets users upgrade old serialized
    // graphs.
    if (specs && specs->convert_legacy_fed_inputs &&
        node_def.op() == "LegacyFedInput") {
      TF_RETURN_IF_ERROR(
          UpdateLegacyFedInputNode(*graph_def, specs->inputs, &node_def));
    }

    const tensorflow::OpRegistrationData* op_reg_data =
        tensorflow::OpRegistry::Global()->LookUp(node_def.op());
    if (!op_reg_data) {
      // This is likely a function call node, so we should continue.
      continue;
    }
    ::tensorflow::AddDefaultsToNodeDef(op_reg_data->op_def, &node_def);
  }
  return Status::OK();
}

// Mapping from node name to feed (index and ArrayInfo). Node names must
// outlive this map.
using FeedsByNode = absl::flat_hash_map<
    absl::string_view,
    absl::flat_hash_map<int, const std::pair<std::string, ArrayInfo>*>>;

// Creates, from a `GraphImportConfig::InputArrays`, a mapping from a feed's
// output tensor name to its index and ArrayInfo. Keys and values are backed
// by `GraphImportConfig::InputArrays`.
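// For example (illustrative values), the feeds {"x:0": a, "x:1": b} map to
// {"x" -> {0 -> &a, 1 -> &b}}; a control reference such as "^x" parses to a
// negative index and is rejected below.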
StatusOr<FeedsByNode> GetFeedsByNode(
    const GraphImportConfig::InputArrays& inputs) {
  FeedsByNode feeds_by_node;
  feeds_by_node.reserve(inputs.size());

  for (const auto& input : inputs) {
    TensorId tensor = ParseTensorName(input.first);
    if (tensor.index() < 0)
      return errors::FailedPrecondition(
          "Feed output tensor must be a data output '", tensor.ToString(), "'");

    auto& node = feeds_by_node[tensor.node()];
    if (!node.insert({tensor.index(), &input}).second)
      return errors::FailedPrecondition(
          "Multiple feeds for the same output tensor '", tensor.ToString(),
          "'");
  }

  return feeds_by_node;
}

// Creates a unique name for a node that will be replacing a feed output tensor.
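// For example, feed index 1 of node "foo" yields "foo_1"; if "foo_1" is
// already taken, a counter is appended: "foo_1_0", "foo_1_1", and so on.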
std::string GetUniqueNodeName(
    absl::string_view node_name, int index,
    const std::unordered_map<string, Node*>& node_name_map) {
  std::string new_node_name_base = absl::StrCat(node_name, "_", index);
  int count = 0;
  std::string new_node_name = new_node_name_base;
  while (node_name_map.find(new_node_name) != node_name_map.end()) {
    new_node_name = absl::StrCat(new_node_name_base, "_", count++);
  }
  return new_node_name;
}

Status ImporterBase::ConvertDeferredFunctions() {
  while (!deferred_functions_.empty()) {
    auto conversion_metadata = deferred_functions_.front();
    deferred_functions_.pop();

    const FunctionDef* func_def =
        graph_flib_.Find(conversion_metadata.function_name);
    // Converts the graph to an MLIR function and adds it to the module.
    // We populate the NodeSpec so that all the _Arg ops get their shape
    // added correctly.
    GraphImportConfig specs;
    specs.enable_shape_inference = specs_.enable_shape_inference;
    for (const auto& name_and_value : func_def->attr()) {
      if (name_and_value.first == "_input_shapes") {
        auto& list = name_and_value.second.list();
        auto& signature = func_def->signature();
        // Some models have an "_input_shapes" attribute, but with an empty
        // value.
        if (list.shape_size() > 0 &&
            list.shape_size() != signature.input_arg_size()) {
          return errors::FailedPrecondition(
              "Number of input arguments must be equal to the length of "
              "_input_shapes attribute in function '",
              StringRefToView(conversion_metadata.function_name), "'.");
        }
        for (int i = 0, e = signature.input_arg_size(); i < e; i++) {
          auto& input_arg = signature.input_arg(i);
          auto& array_info = specs.inputs[input_arg.name()];
          array_info.imported_dtype = input_arg.type();
          // Set the shape to unranked for an empty "_input_shapes" attribute.
          if (list.shape_size() > 0)
            array_info.shape = list.shape(i);
          else
            array_info.shape.set_unknown_rank(true);
        }
      }
    }

    ImporterBase importer(graph_flib_, debug_info_, specs, module_,
                          tf_name_to_mlir_name_, function_name_uniquifier_,
                          conversion_metadata.function_name);

    std::unique_ptr<FunctionBody> fbody;
    TF_RETURN_IF_ERROR(
        FunctionDefToBodyHelper(*func_def, AttrSlice(), &graph_flib_, &fbody));
    TF_RETURN_IF_ERROR(importer.PrepareConvert(*fbody->graph));

    TF_ASSIGN_OR_RETURN(auto func_type, importer.InferLibFunctionType(*fbody));

    absl::InlinedVector<OutputTensor, 4> arg_nodes;
    absl::InlinedVector<OutputTensor, 4> ret_nodes;
    absl::InlinedVector<Node*, 4> control_ret_nodes;
    importer.GetArgsAndRetsFromFunctionBody(*fbody, &arg_nodes, &ret_nodes,
                                            &control_ret_nodes);
    const std::string& mlir_func_name =
        (*tf_name_to_mlir_name_)[conversion_metadata.function_name];

    TF_RETURN_IF_ERROR(importer.Convert(mlir_func_name, func_type, arg_nodes,
                                        ret_nodes, control_ret_nodes,
                                        conversion_metadata.attributes));

    // Additional function bodies could be discovered during the deferred
    // loading of the current function. Add them to the working queue.
    while (!importer.deferred_functions_.empty()) {
      deferred_functions_.push(importer.deferred_functions_.front());
      importer.deferred_functions_.pop();
    }
  }

  return Status::OK();
}

Status ImporterBase::RemoveBackedges() {
  // Remove all the backedges so that the nodes can be added to the shape
  // refiner.
  TF_RETURN_IF_ERROR(back_edge_helper_.Remove(graph_.get()));
  VLOG(1) << "Found " << (back_edge_helper_.RemovedEdges().size())
          << " backedges.";

  // Creates a map for quickly identifying whether a node output is a backedge.
  for (const auto& edge : back_edge_helper_.RemovedEdges()) {
    if (back_edge_node_output_.find(edge.src) != back_edge_node_output_.end() &&
        back_edge_node_output_[edge.src] != edge.src_output) {
      return errors::FailedPrecondition(
          "More than one of the src node outputs are backedges!");
    }
    back_edge_node_output_[edge.src] = edge.src_output;
    // We expect a merge to receive a single backedge (multiple NextIteration
    // nodes feeding into the same merge is unexpected here).
    DCHECK(!back_edge_dst_inputs_.contains(edge.dst));
    back_edge_dst_inputs_[edge.dst] = edge;
  }

  // Obtains an RPO ordering, using node names as a tiebreak for stable
  // sorting.
  GetReversePostOrder(
      *graph_, &ordered_nodes_,
      [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
  return Status::OK();
}

Status CopyStackTraces(const Graph& from, Graph* to) {
  // Copy over the stack traces.
  // TODO(jpienaar): This really shouldn't be needed, copying the Graph above
  // and then needing these traversals is unfortunate.
  std::unordered_map<string, Node*> node_map = from.BuildNodeNameIndex();
  for (Node* node : to->nodes()) {
    if (const Node* old_node = node_map[node->name()]) {
      if (const std::shared_ptr<AbstractStackTrace>& stack =
              old_node->GetStackTrace()) {
        DVLOG(2) << "Stack for " << node->name() << " "
                 << old_node->GetStackTrace()->ToString(
                        AbstractStackTrace::TracePrintingOptions());
        node->SetStackTrace(stack);
      } else {
        DVLOG(1) << "No stack for " << node->name() << " (" << node
                 << ") in Graph " << &from;
      }
    } else {
      DVLOG(1) << "No stack for " << node->name() << " (" << node
               << ") in Graph " << &from;
    }
  }

  return Status::OK();
}

StatusOr<std::pair<Node*, bool>> ImporterBase::CreatePlaceholderNodeForFeed(
    const TensorShapeProto& shape, DataType dtype, Node* node, int index,
    const std::unordered_map<string, Node*>& node_name_map) {
  DCHECK_LT(index, node->num_outputs());
  const bool update_inplace = node->num_outputs() == 1 && index == 0;
  std::string new_node_name =
      update_inplace ? node->name()
                     : GetUniqueNodeName(node->name(), index, node_name_map);

  Node* placeholder_node;
  NodeBuilder builder(new_node_name, "Placeholder");
  builder.Attr("shape", shape);
  builder.Attr("dtype", dtype);
  TF_RETURN_IF_ERROR(builder.Finalize(graph_.get(), &placeholder_node));

  // Update edges from the original feed with the Placeholder node.
  std::vector<const Edge*> data_edges;
  std::vector<const Edge*> control_edges;
  for (const tensorflow::Edge* edge : node->out_edges()) {
    if (edge->src_output() == index) {
      data_edges.push_back(edge);
    } else if (update_inplace && edge->IsControlEdge()) {
      control_edges.push_back(edge);
    }
  }

  for (const auto* edge : data_edges) {
    TF_RETURN_IF_ERROR(graph_->UpdateEdge(placeholder_node, 0, edge->dst(),
                                          edge->dst_input()));
  }

  // TODO(lyandy): Preserve control dependencies properly by not forwarding
  // control dependencies to data outputs and not removing single output nodes.
  // When a data output is replaced as a feed, unless there is another non-feed
  // data output or an explicit control output used by the same node,
  // transitive control dependencies are not to be executed. For single output
  // nodes, Placeholders can be converted to a NoOp if there are no uses, and
  // PlaceholderWithDefault can be converted to an Identity.
  for (const auto* edge : control_edges) {
    graph_->AddControlEdge(placeholder_node, edge->dst());
    graph_->RemoveControlEdge(edge);
  }

  if (update_inplace) {
    graph_->RemoveNode(node);
  }

  return std::pair<Node*, bool>(placeholder_node, update_inplace);
}

Status ImporterBase::GetInputOutputNodes(
    const std::unordered_map<string, Node*>& node_name_map,
    std::unordered_set<const Node*>* nodes) {
  auto add_node = [&](absl::string_view name) {
    auto it = node_name_map.find(std::string(name));
    if (it == node_name_map.end()) {
      return errors::FailedPrecondition(
          absl::StrCat("Graph does not contain node: ", name));
    }
    nodes->insert(it->second);
    return Status::OK();
  };

  // Remap feeds and fetches to newly created Placeholder nodes.
  for (const auto& input : specs_.inputs) {
    TensorId tensor = ParseTensorName(input.first);
    auto remapped_it = remapped_feeds_.find(tensor);
    if (remapped_it != remapped_feeds_.end()) {
      TF_RETURN_IF_ERROR(add_node(remapped_it->second));
    } else {
      TF_RETURN_IF_ERROR(add_node(tensor.node()));
    }
  }

  for (const auto& output : specs_.outputs) {
    TensorId tensor = ParseTensorName(output);
    auto remapped_it = remapped_feeds_.find(tensor);
    if (remapped_it != remapped_feeds_.end()) {
      TF_RETURN_IF_ERROR(add_node(remapped_it->second));
    } else {
      TF_RETURN_IF_ERROR(add_node(tensor.node()));
    }
  }

  for (const auto& control_output : specs_.control_outputs)
    TF_RETURN_IF_ERROR(add_node(control_output));

  return Status::OK();
}

// TODO(jpienaar): Remove this once the shape inference on import flag is
// removed.
Status ImporterBase::AddNodesToShapeRefiner(
    std::unordered_map<string, Node*>* node_name_map) {
  shape_refiner_ = absl::make_unique<ShapeRefiner>(graph_->versions(),
                                                   graph_->op_registry());
  // Some operations (for example "TPUExecute") don't have a shape inference
  // function defined, so we set this to false to allow adding nodes with such
  // operation types.
  shape_refiner_->set_require_shape_inference_fns(false);
  shape_refiner_->set_function_library_for_shape_inference(&graph_flib_);

  TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs));

  // First add all nodes to the refiner.
  for (Node* node : ordered_nodes_) {
    // We need to use a TensorFlow node to teach the shape refiner that the
    // user specified a certain data type and shape for the inputs in
    // `specs_`. This node shouldn't have any inputs, should have exactly one
    // output, and its output type/shape should be determined only by its
    // "named" attributes. (The attributes should have fixed names so we can
    // use the info from `specs_` to set their values.) `Placeholder` satisfies
    // these constraints.
    //
    // Therefore, if the input node isn't a `Placeholder`, we create one and
    // use it to replace the original input node, so the shape refiner can
    // successfully propagate the user's input type and shape to the rest of
    // the graph.
    bool node_added_to_shape_refiner = false;
    auto it = feeds_by_node.find(node->name());
    if (it != feeds_by_node.end()) {
      auto op_name = node->op_def().name();
      if (op_name != "Placeholder" && op_name != "LegacyFedInput" &&
          op_name != FunctionLibraryDefinition::kArgOp) {
        for (const auto& output_tensor : it->second) {
          const int index = output_tensor.first;
          const ArrayInfo& array_info = output_tensor.second->second;

          DataType dtype = array_info.imported_dtype;
          // Uses the existing output type if it isn't specified by the user.
          if (dtype == DT_INVALID) {
            dtype = node->output_type(index);
          }

          TF_ASSIGN_OR_RETURN(
              auto placeholder_node_and_removed,
              CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index,
                                           *node_name_map));

          Node* placeholder_node = placeholder_node_and_removed.first;
          if (placeholder_node_and_removed.second) {
            // Original node has been removed from the graph.
            node = placeholder_node;
            node_added_to_shape_refiner = true;
          }
          remapped_feeds_[{it->first, index}] = placeholder_node->name();
          (*node_name_map)[placeholder_node->name()] = placeholder_node;
          // Add the new placeholder node to the shape refiner.
          Status status = shape_refiner_->AddNode(placeholder_node);
          if (!status.ok()) {
            return EmitErrorWithLocationStr(*placeholder_node, status);
          }
        }
      } else {
        auto index_it = it->second.find(0);
        if (index_it == it->second.end()) {
          return errors::FailedPrecondition(
              "Missing feed output tensor at index 0 for node '", node->name(),
              "'");
        }
        node->AddAttr("shape", index_it->second->second.shape);
        DataType dtype = index_it->second->second.imported_dtype;
        // Uses the existing output type if it isn't specified by the user.
        if (dtype == DT_INVALID) {
          dtype = node->output_type(0);
        }
        node->AddAttr("dtype", dtype);
      }
    }
    if (!node_added_to_shape_refiner) {
      // Add the node to the shape refiner if the node hasn't been removed.
      Status status = shape_refiner_->AddNode(node);
      if (!status.ok()) {
        return EmitErrorWithLocationStr(*node, status);
      }
    }

    auto set_shape_from_list_attr = [&](const AttrValue* attr) {
      auto& list = attr->list();
      for (auto shape : llvm::enumerate(list.shape())) {
        auto* node_context = shape_refiner_->GetContext(node);
        shape_inference::ShapeHandle handle;
        Status status =
            node_context->MakeShapeFromShapeProto(shape.value(), &handle);
        if (!status.ok()) {
          return EmitErrorWithLocationStr(*node, status);
        }
        node_context->set_output(shape.index(), handle);
      }
      return Status::OK();
    };

    // We currently have no other way to get shapes from ReadVariableOp's.
    // Some graphs seem to have _output_shapes attributes on them, so use that
    // if possible.
    // Note: _output_shapes is optionally set when the user exports the graph,
    // and it is not guaranteed (nor an error if missing). There is no promised
    // contract here, so this is effectively a heuristic.
    if (node->op_def().name() == "ReadVariableOp") {
      if (const AttrValue* attr = node->attrs().Find("_output_shapes"))
        TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr));
    }

    // If it is an argument node, the shape handle is set explicitly, so it
    // can be propagated to the body nodes of the function.
    if (StringPiece(node->type_string()) == FunctionLibraryDefinition::kArgOp) {
      auto* node_context = shape_refiner_->GetContext(node);
      DCHECK(node_context != nullptr);
      if (const AttrValue* attr = node->attrs().Find("shape")) {
        shape_inference::ShapeHandle handle;
        Status status =
            node_context->MakeShapeFromShapeProto(attr->shape(), &handle);
        if (!status.ok()) {
          return EmitErrorWithLocationStr(*node, status);
        }
        node_context->set_output(0, handle);
      } else if (const AttrValue* attr = node->attrs().Find("_output_shapes")) {
        TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr));
      } else {
        node_context->set_output(0, node_context->UnknownShape());
      }
    }
  }

  // Since we might have inserted and removed nodes from the graph, fix
  // source/sink edges and reconstruct the RPO ordering of nodes.
  FixupSourceAndSinkEdges(graph_.get());

  // Prune nodes in the graph that are not reachable from the output.
  if (specs_.prune_unused_nodes) {
    std::unordered_set<const Node*> prune_start;
    TF_RETURN_IF_ERROR(GetInputOutputNodes(*node_name_map, &prune_start));
    if (!prune_start.empty()) {
      if (PruneForReverseReachability(graph_.get(), prune_start)) {
        VLOG(1) << "Pruned unused nodes in graphdef";
      } else {
        VLOG(1) << "No unused nodes in graphdef to prune";
      }
    } else {
      VLOG(1) << "No output nodes specified, skipping pruning";
    }
  } else {
    VLOG(1) << "Pruning unused nodes in graphdef is disabled";
  }

  // Re-initialize ordered_nodes_ since we might have modified the graph.
  GetReversePostOrder(
      *graph_, &ordered_nodes_,
      [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });

  VLOG(1) << "Inferring graph shapes to fixpoint";

  // The "changed" information from UpdateNode can give false positives, so we
  // use a dedicated predicate to verify that the shapes did not change before
  // and after the shape refinement.
  auto same_inferred_shape = [](shape_inference::InferenceContext* c,
                                shape_inference::ShapeHandle s0,
                                shape_inference::ShapeHandle s1) -> bool {
    if (s0.SameHandle(s1) || (!c->RankKnown(s0) && !c->RankKnown(s1))) {
      return true;
    }
    if (c->Rank(s0) != c->Rank(s1)) {
      return false;
    }
    for (int i = 0; i < c->Rank(s0); ++i) {
      if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
        int64_t val0 = c->Value(c->Dim(s0, i));
        int64_t val1 = c->Value(c->Dim(s1, i));
        // A negative value is treated as unknown, so any two negative values
        // indicate the same (unknown) dimension.
        if (val0 >= 0 && val1 >= 0 && val0 != val1) return false;
      }
    }
    return true;
  };

  bool changed = true;
  int i = 0;
  const int kMaxIterationCount = 2;
  while (changed && i != kMaxIterationCount) {
    changed = false;
    for (const Node* node : ordered_nodes_) {
      auto* shape_context = shape_refiner_->GetContext(node);
      DCHECK(shape_context != nullptr);
      absl::InlinedVector<shape_inference::ShapeHandle, 4> existing;
      existing.reserve(shape_context->num_outputs());
      for (int o = 0; o < shape_context->num_outputs(); ++o) {
        existing.push_back(shape_context->output(o));
      }
      bool inferred = false;
      shape_inference::ShapeHandle handle;
      Status status =
          shape_refiner_->UpdateNode(node, /*relax=*/false, &inferred);
      if (!status.ok()) {
        return EmitErrorWithLocationStr(*node, status);
      }
      for (int o = 0; o < shape_context->num_outputs(); ++o) {
        if (!same_inferred_shape(shape_context, shape_context->output(o),
                                 existing[o])) {
          changed = true;
          break;
        }
      }
    }
    ++i;
  }
  if (i >= kMaxIterationCount) {
    LOG(WARNING) << "Graph shapes did not converge to a fixpoint within "
                 << kMaxIterationCount
                 << " iterations. Graph shapes may be conservative.";
  }
  VLOG(1) << "Graph shapes were inferred with " << (i - 1)
          << " extra rounds of analysis to reach a fixpoint.";
  return Status::OK();
}

StatusOr<mlir::Type> ImporterBase::InferInputType(const Node& node, int idx,
                                                  mlir::Builder builder) {
  if (specs_.enable_shape_inference) {
    // TODO(jpienaar): Remove this once the shape inference on import flag is
    // removed.
    ExtendedInferenceContext* shape_context =
        shape_refiner_->GetExtendedContext(&node);
    DataType dtype = shape_context->input_type(idx);
    auto* context = shape_context->get_context();
    return ConvertDataTypeAndShape(dtype, context->input(idx),
                                   context->input_handle_shapes_and_types(idx),
                                   context, builder);
  }
  DataType dtype = node.properties()->input_types[idx];
  mlir::Type element_type;
  TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &element_type));
  return mlir::UnrankedTensorType::get(element_type);
}

StatusOr<mlir::Type> ImporterBase::InferOutputType(const Node& node, int idx,
                                                   mlir::Builder builder) {
  DataType dtype = node.properties()->output_types[idx];

  // Returns output type given inference context.
  auto shape_ic = [&](shape_inference::InferenceContext* c) {
    return ConvertDataTypeAndShape(dtype, c->output(idx),
                                   c->output_handle_shapes_and_types(idx), c,
                                   builder);
  };

  if (specs_.enable_shape_inference) {
    // TODO(jpienaar): Remove this once the shape inference on import flag is
    // removed.
    ExtendedInferenceContext* shape_context =
        shape_refiner_->GetExtendedContext(&node);
    return shape_ic(shape_context->get_context());
  }

  // Treat TensorList init ops specially here as the op requires knowing its
  // element dtype.
  // TODO(jpienaar): Reconsider post refactoring shape functions.
  if (node.type_string() == "TensorListReserve" ||
      node.type_string() == "EmptyTensorList") {
    mlir::Type etype;
    if (auto element_dtype = node.attrs().Find("element_dtype")) {
      TF_RETURN_IF_ERROR(
          ConvertDataType(element_dtype->type(), builder, &etype));
    }
    return mlir::RankedTensorType::get(
        {}, mlir::TF::VariantType::get({mlir::UnrankedTensorType::get(etype)},
                                       etype.getContext()));
  }

  if (node.IsWhileNode()) {
    auto* output_shapes = node.attrs().Find("output_shapes");
    auto* element_types = node.attrs().Find("T");
    if (output_shapes && !output_shapes->list().shape().empty()) {
      const auto& output_shape = output_shapes->list().shape(idx);
      const auto& element_type = element_types->list().type(idx);
      return ConvertToMlirTensorType(output_shape, element_type, &builder);
    }
  }

  auto type_from_array_attr = [&node, &idx, &builder](
                                  absl::string_view output_shape_attr,
                                  absl::string_view element_type_attr) {
    auto* output_shapes = node.attrs().Find(output_shape_attr);
    auto* element_types = node.attrs().Find(element_type_attr);
    const auto& output_shape = output_shapes->list().shape(idx);
    const auto& element_type = element_types->list().type(idx);
    return ConvertToMlirTensorType(output_shape, element_type, &builder);
  };

  if (node.type_string() == "IteratorGetNext" ||
      node.type_string() == "IteratorGetNextSync" ||
      node.type_string() == "MultiDeviceIteratorGetNextFromShard")
    return type_from_array_attr("output_shapes", "output_types");

  if (node.type_string() == "InfeedDequeueTuple")
    return type_from_array_attr("shapes", "dtypes");

  if (node.type_string() == "InfeedDequeue") {
    assert(idx == 0);
    const auto& output_shape = node.attrs().Find("shape")->shape();
    const auto& element_type = node.attrs().Find("dtype")->type();
    return ConvertToMlirTensorType(output_shape, element_type, &builder);
  }

  // Returns a simple, more conservative unranked tensor type.
  auto default_type = [&]() -> StatusOr<mlir::Type> {
    mlir::Type element_type;
    TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &element_type));
    return mlir::UnrankedTensorType::get(element_type);
  };

  // Below we only try to do some shape inference for "source" ops which have
  // no inputs.
  if (node.num_inputs() > 0) return default_type();

  // Do some simple inference here to get the function arguments correct for
  // this common case.
  // TODO(jpienaar): Reconsider post refactoring shape functions.
  if (node.IsArg()) {
    if (dtype == DT_RESOURCE) {
      const AttrValue* dtype_attr = node.attrs().Find("_handle_dtypes");
      const AttrValue* shape_attr = node.attrs().Find("_handle_shapes");
      if (dtype_attr && shape_attr) {
        if (dtype_attr->list().type().empty()) {
          return errors::InvalidArgument(
              "Invalid \"_handle_dtypes\" attribute value for _Arg node: ",
              shape_attr->DebugString());
        }
        if (shape_attr->list().shape().empty()) {
          return errors::InvalidArgument(
              "Invalid \"_handle_shapes\" attribute value for _Arg node: ",
              shape_attr->DebugString());
        }
        DataType dtype = dtype_attr->list().type(0);
        const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
        TF_ASSIGN_OR_RETURN(
            auto etype, ConvertToMlirTensorType(shape_proto, dtype, &builder));
        return mlir::UnrankedTensorType::get(mlir::TF::ResourceType::get(
            {etype.cast<TensorType>()}, builder.getContext()));
      } else {
        return mlir::UnrankedTensorType::get(
            mlir::TF::ResourceType::get(builder.getContext()));
      }
    } else if (auto shape = node.attrs().Find("_output_shapes")) {
      if (shape->has_list() && shape->list().shape_size() == 1) {
        return ConvertToMlirTensorType(shape->list().shape().at(0), dtype,
                                       &builder);
      }
    }
  }

  const tensorflow::OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(
      graph_->op_registry()->LookUp(node.type_string(), &op_reg_data));
  if (!op_reg_data) {
    DVLOG(1) << "Skipping inference for unregistered op " << node.type_string();
    return default_type();
  }
  if (op_reg_data->shape_inference_fn == nullptr) {
    DVLOG(1) << "Skipping inference for op without shape function "
             << node.type_string();
    return default_type();
  }
  shape_inference::InferenceContext c(graph_->versions().producer(),
                                      node.attrs(), op_reg_data->op_def,
                                      std::vector<PartialTensorShape>{}, {},
                                      /*input_tensors_as_shapes=*/{}, {});
  TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn));
  return shape_ic(&c);
}

StatusOr<TensorType> ImporterBase::ConvertDataTypeAndShape(
    DataType dtype, const shape_inference::ShapeHandle& handle,
    const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
    shape_inference::InferenceContext* context, mlir::Builder builder) {
  TF_ASSIGN_OR_RETURN(auto subtypes,
                      ConvertSubtypes(handle_subtypes, context, builder));

  mlir::Type element_type;
  if (dtype == DT_VARIANT)
    element_type = mlir::TF::VariantType::get(subtypes, context_);
  else if (dtype == DT_RESOURCE)
    element_type = mlir::TF::ResourceType::get(subtypes, context_);
  else
    TF_RETURN_IF_ERROR(
        ::tensorflow::ConvertDataType(dtype, builder, &element_type));

  return ConvertElementTypeAndShape(element_type, handle, context, builder);
}

StatusOr<TensorType> ImporterBase::ConvertElementTypeAndShape(
    mlir::Type element_type, const shape_inference::ShapeHandle& handle,
    shape_inference::InferenceContext* context, mlir::Builder builder) {
  if (!context->RankKnown(handle)) {
    return mlir::UnrankedTensorType::get(element_type);
  }

  // Sentinel for an unknown dimension size. getTensorType interprets any
  // negative value as an unknown dimension.
  // TODO(jmolloy): Ideally this shouldn't be a local sentinel.
  const int64_t kUnknownDim = -1;

  absl::InlinedVector<int64_t, 4> dimensions;
  int32_t rank = context->Rank(handle);
  dimensions.reserve(rank);
  for (int i = 0; i < rank; ++i) {
    auto dim_handle = context->Dim(handle, i);
    if (!context->ValueKnown(dim_handle))
      dimensions.push_back(kUnknownDim);
    else
      dimensions.push_back(context->Value(dim_handle));
  }

  return mlir::RankedTensorType::get(
      llvm::makeArrayRef(dimensions.begin(), dimensions.end()), element_type);
}

StatusOr<ImporterBase::ElementSubtypes> ImporterBase::ConvertSubtypes(
    const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
    shape_inference::InferenceContext* context, mlir::Builder builder) {
  ElementSubtypes subtypes;
  if (!handle_subtypes) return subtypes;

  subtypes.reserve(handle_subtypes->size());
  for (const auto& subtype : *handle_subtypes) {
    mlir::Type element_type;
    TF_RETURN_IF_ERROR(
        ::tensorflow::ConvertDataType(subtype.dtype, builder, &element_type));
    TF_ASSIGN_OR_RETURN(TensorType type,
                        ConvertElementTypeAndShape(element_type, subtype.shape,
                                                   context, builder));
    subtypes.push_back(type);
  }
  return subtypes;
}

Status ImporterBase::ConvertFunctionCallAttribute(const std::string& base_name,
                                                  const AttrValue& value,
                                                  NamedAttrList* attributes) {
  TF_ASSIGN_OR_RETURN(auto func_attr,
                      ConvertFunctionCallName(value.func().name()));
  if (!func_attr) return Status::OK();
  attributes->push_back(builder_.getNamedAttr(base_name, func_attr));

  for (const auto& it : value.func().attr()) {
    auto name = absl::StrCat(base_name, ".", it.first);
    TF_ASSIGN_OR_RETURN(auto value, ConvertAttributeValue(it.second));
    attributes->push_back(builder_.getNamedAttr(name, value));
  }
  return Status::OK();
}

StatusOr<mlir::FlatSymbolRefAttr> ImporterBase::ConvertFunctionCallName(
    const std::string& func_name) {
  // Some ops, like the XlaHostCompute op, use an empty value to represent
  // missing functions. Such attribute values should be defined as optional in
  // the MLIR definition.
  if (func_name.empty()) return mlir::FlatSymbolRefAttr();

  TF_RETURN_IF_ERROR(ConvertLibFunction(func_name));
  auto mlir_func_name = (*tf_name_to_mlir_name_)[func_name];
  return builder_.getSymbolRefAttr(mlir_func_name);
}

StatusOr<mlir::Attribute> ImporterBase::ConvertAttributeValue(
    const AttrValue& value) {
  switch (value.value_case()) {
    case AttrValue::kFunc: {
      // TODO(b/156546237): Unify kFunc/NameAttrList attribute representation.
      // Currently kFunc/NameAttrList attributes in a kList/repeated AttrValue
      // will not use this representation. This also doesn't handle empty
      // function values like the ConvertFunctionCallName method does.
      NamedAttrList attrs;
      for (const auto& func_attr : value.func().attr()) {
        TF_ASSIGN_OR_RETURN(
            auto attr, ImporterBase::ConvertAttributeValue(func_attr.second));
        attrs.push_back(builder_.getNamedAttr(func_attr.first, attr));
      }
      auto func_attrs = builder_.getDictionaryAttr(attrs);
      return mlir::TF::FuncAttr::get(context_, value.func().name(), func_attrs);
    }
    case AttrValue::kList: {
      if (!value.list().func().empty()) {
        absl::InlinedVector<mlir::Attribute, 8> attrs;
        for (const auto& item : value.list().func()) {
          TF_ASSIGN_OR_RETURN(auto attr, ConvertFunctionCallName(item.name()));
          if (item.attr_size() != 0)
            return errors::Unimplemented(
                "func attributes with non-zero attr.size()");
          if (attr) attrs.push_back(attr);
        }
        return builder_.getArrayAttr(
            llvm::makeArrayRef(attrs.begin(), attrs.end()));
      }
      return ConvertNonFuncAttributeValue(value, &builder_);
    }
    default:
      return ConvertNonFuncAttributeValue(value, &builder_);
  }
}
1314
GetArgsAndRetsFromFunctionBody(const FunctionBody & fbody,absl::InlinedVector<OutputTensor,4> * arg_nodes,absl::InlinedVector<OutputTensor,4> * ret_nodes,absl::InlinedVector<Node *,4> * control_ret_nodes)1315 void ImporterBase::GetArgsAndRetsFromFunctionBody(
1316 const FunctionBody& fbody, absl::InlinedVector<OutputTensor, 4>* arg_nodes,
1317 absl::InlinedVector<OutputTensor, 4>* ret_nodes,
1318 absl::InlinedVector<Node*, 4>* control_ret_nodes) {
1319 arg_nodes->reserve(fbody.arg_nodes.size());
1320 ret_nodes->reserve(fbody.ret_nodes.size());
1321 for (auto arg : fbody.arg_nodes) {
1322 arg_nodes->emplace_back(arg, 0);
1323 }
1324 for (auto ret : fbody.ret_nodes) {
1325 ret_nodes->emplace_back(ret, 0);
1326 }
1327 *control_ret_nodes = fbody.control_ret_nodes;
1328 }
1329
ConvertLibFunction(llvm::StringRef func_name)1330 Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
1331 // If the library function has been converted already, nothing needs to be
1332 // done.
1333 if (tf_name_to_mlir_name_->find(std::string(func_name)) !=
1334 tf_name_to_mlir_name_->end())
1335 return Status::OK();
1336
1337 std::string mlir_func_name(
1338 function_name_uniquifier_->GetUniqueName(func_name));
1339 (*tf_name_to_mlir_name_)[std::string(func_name)] = mlir_func_name;
1340
1341 const auto& func_lib = graph_flib_;
1342 const auto* func_def = func_lib.Find(std::string(func_name));
1343 if (func_def == nullptr) {
1344 return errors::FailedPrecondition(
1345 absl::StrCat("Failed to find function '", StringRefToView(func_name),
1346 "'. The imported TensorFlow GraphDef is ill-formed."));
1347 }
1348
1349 // Converts the argument and return types to MLIR types.
1350 std::vector<mlir::NamedAttribute> attributes;
1351 attributes.reserve(func_def->attr_size());
1352 for (const auto& name_and_value : func_def->attr()) {
1353 // This is a function definition attribute, so it shouldn't contain
1354 // kFunc attribute and it is treated as normal one.
1355 TF_ASSIGN_OR_RETURN(auto attr,
1356 ConvertAttributeValue(name_and_value.second));
1357 std::string attr_name =
1358 mangling_util::MangleAttributeName(name_and_value.first);
1359 attributes.push_back(builder_.getNamedAttr(attr_name, attr));
1360 }
1361
1362 // Checks opdef stateful attribute and import that as Function Attribute
1363 if (func_def->signature().is_stateful()) {
1364 auto stateful_str = mlir::TF::TensorFlowDialect::GetStatefulAttrName();
1365 attributes.push_back(
1366 builder_.getNamedAttr(stateful_str, builder_.getUnitAttr()));
1367 }
1368
1369 // Checks for an associated custom gradient function. Adds it to the attribute
1370 // list of this function.
1371 auto grad_func_name = func_lib.FindGradient(std::string(func_name));
1372 if (!grad_func_name.empty()) {
1373 TF_RETURN_IF_ERROR(ConvertLibFunction(grad_func_name));
1374 auto mlir_grad_func_name = (*tf_name_to_mlir_name_)[grad_func_name];
1375 auto gradient_attr = builder_.getSymbolRefAttr(mlir_grad_func_name);
1376 auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
1377 attributes.push_back(builder_.getNamedAttr(grad_string, gradient_attr));
1378 }
1379
1380 deferred_functions_.emplace(func_name.str(), attributes);
1381 return Status::OK();
1382 }
1383
PruneUnreachableNodes(std::unordered_map<string,Node * > * node_name_map)1384 Status ImporterBase::PruneUnreachableNodes(
1385 std::unordered_map<string, Node*>* node_name_map) {
1386 std::unordered_set<const Node*> prune_start;
1387 TF_RETURN_IF_ERROR(GetInputOutputNodes(*node_name_map, &prune_start));
1388
1389 if (!prune_start.empty()) {
1390 if (PruneForReverseReachability(graph_.get(), prune_start)) {
1391 VLOG(1) << "Pruned unused nodes in graphdef";
1392 } else {
1393 VLOG(1) << "No unused nodes in graphdef to prune";
1394 }
1395 } else {
1396 VLOG(1) << "No output nodes specified, skipping pruning";
1397 }
1398 return Status::OK();
1399 }
1400
ConvertFeedsToPlaceholders(std::unordered_map<string,Node * > * node_name_map)1401 Status ImporterBase::ConvertFeedsToPlaceholders(
1402 std::unordered_map<string, Node*>* node_name_map) {
1403 // Feeds (edges) are converted into single-output placeholder nodes to
1404 // simplify the conversion process.
1405 TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs));
1406 for (const auto& it : feeds_by_node) {
1407 TensorId tensor = ParseTensorName(it.first);
1408 auto jt = node_name_map->find(std::string(tensor.node()));
1409 if (jt == node_name_map->end()) {
1410 return errors::FailedPrecondition(
1411 absl::StrCat("Graph does not contain node: ", tensor.node()));
1412 }
1413
1414 Node* node = jt->second;
1415 auto op_name = node->op_def().name();
1416 if (op_name != "Placeholder" && op_name != "LegacyFedInput" &&
1417 op_name != FunctionLibraryDefinition::kArgOp) {
1418 for (const auto& output_tensor : it.second) {
1419 const int index = output_tensor.first;
1420 const ArrayInfo& array_info = output_tensor.second->second;
1421
1422 DataType dtype = array_info.imported_dtype;
1423 // Uses the existing output type if it isn't specified by the user.
1424 if (dtype == DT_INVALID) {
1425 dtype = node->output_type(index);
1426 }
1427
1428 TF_ASSIGN_OR_RETURN(
1429 auto placeholder_node_and_removed,
1430 CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index,
1431 *node_name_map));
1432
1433 Node* placeholder_node = placeholder_node_and_removed.first;
1434 if (placeholder_node->in_edges().empty()) {
1435 graph_->AddControlEdge(graph_->source_node(), placeholder_node,
1436 true /* skip test for duplicates */);
1437 }
1438 if (placeholder_node->out_edges().empty()) {
1439 graph_->AddControlEdge(placeholder_node, graph_->sink_node(),
1440 true /* skip test for duplicates */);
1441 }
1442 remapped_feeds_[{it.first, index}] = placeholder_node->name();
1443 (*node_name_map)[placeholder_node->name()] = placeholder_node;
1444 }
1445 }
1446 }
1447 return Status::OK();
1448 }
1449
PrepareConvert(const Graph & graph,std::unique_ptr<GraphDef> graph_def)1450 Status ImporterBase::PrepareConvert(const Graph& graph,
1451 std::unique_ptr<GraphDef> graph_def) {
1452 // TODO(fengliuai): Converting to GraphDef and back is the easiest way to
1453 // clone a graph.
1454 // TODO(fengliuai): clone the graph without going to graph_def first.
1455 if (graph_def == nullptr) {
1456 graph_def = std::make_unique<GraphDef>();
1457 graph.ToGraphDef(graph_def.get());
1458 }
1459 graph_ = absl::make_unique<Graph>(graph.flib_def());
1460 GraphConstructorOptions opts;
1461 opts.allow_internal_ops = true;
1462 opts.add_default_attributes = true;
1463 TF_RETURN_IF_ERROR(::tensorflow::ConvertGraphDefToGraph(
1464 opts, std::move(*graph_def), graph_.get()));
1465
1466 TF_RETURN_IF_ERROR(RemoveBackedges());
1467
1468 TF_RETURN_IF_ERROR(CopyStackTraces(graph, graph_.get()));
1469
1470 auto node_name_map = graph_->BuildNodeNameIndex();
1471
1472 if (specs_.enable_shape_inference) {
1473 // TODO(jpienaar): Remove once infer shapes on import flag is removed.
1474 TF_RETURN_IF_ERROR(AddNodesToShapeRefiner(&node_name_map));
1475 } else {
1476 TF_RETURN_IF_ERROR(ConvertFeedsToPlaceholders(&node_name_map));
1477 }
1478
1479 // Prune nodes in the graph that are not reachable from the output.
1480 if (specs_.prune_unused_nodes) {
1481 TF_RETURN_IF_ERROR(PruneUnreachableNodes(&node_name_map));
1482 }
1483
1484 if (!specs_.enable_shape_inference) {
1485 // Re-initialize ordered_nodes_ since we might have modified the graph.
1486 GetReversePostOrder(
1487 *graph_, &ordered_nodes_,
1488 [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
1489 }
1490
1491 return Status::OK();
1492 }
1493
Convert(llvm::StringRef func_name,mlir::FunctionType func_type,const absl::InlinedVector<OutputTensor,4> & arg_nodes,const absl::InlinedVector<OutputTensor,4> & ret_nodes,const absl::InlinedVector<Node *,4> & control_ret_nodes,llvm::ArrayRef<mlir::NamedAttribute> attrs)1494 Status ImporterBase::Convert(
1495 llvm::StringRef func_name, mlir::FunctionType func_type,
1496 const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
1497 const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
1498 const absl::InlinedVector<Node*, 4>& control_ret_nodes,
1499 llvm::ArrayRef<mlir::NamedAttribute> attrs) {
1500 // TODO(b/122040776): Uses debug info for FunctionDef.
1501 auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
1502 func_name, func_type, attrs);
1503
1504 module_.push_back(function);
1505 // Seeds the builder with an initial block.
1506 function.addEntryBlock();
1507 builder_ = mlir::OpBuilder(function.getBody());
1508
1509 // Create the graph operation in which we will convert the individual nodes.
1510 auto graph = builder_.create<mlir::tf_executor::GraphOp>(
1511 function.getLoc(), func_type.getResults());
1512 builder_.createBlock(&graph.body());
1513
1514 for (const Node* node : ordered_nodes_) {
1515 TF_RETURN_IF_ERROR(ConvertNode(*node));
1516 }
1517
1518 // Adds the backedges back to the function by creating the source and sink
1519 // pairs.
1520 TF_RETURN_IF_ERROR(AddBackedges());
1521
1522 TF_RETURN_IF_ERROR(ConvertFunctionArgAndRets(function, graph,
1523 func_type.getInputs(), arg_nodes,
1524 ret_nodes, control_ret_nodes));
1525
1526 // TODO(jpienaar): Update post removing shape_refinier_.
1527 if (!specs_.enable_shape_inference) {
1528 // Refine graph's type given more precise fetch.
1529 auto fetch = graph.GetFetch();
1530 bool all_equal = true;
1531 for (auto it :
1532 llvm::zip_first(graph.getResults(), fetch.getOperandTypes())) {
1533 auto rt = std::get<1>(it);
1534 if (rt == std::get<0>(it).getType()) continue;
1535 std::get<0>(it).setType(rt);
1536 all_equal = false;
1537 }
1538 if (!all_equal) {
1539 function.setType(mlir::FunctionType::get(function.getContext(),
1540 func_type.getInputs(),
1541 graph.getResultTypes()));
1542 }
1543 }
1544
1545 return Status::OK();
1546 }
1547
ConvertFunctionArgAndRets(mlir::FuncOp func,mlir::tf_executor::GraphOp graph_op,llvm::ArrayRef<mlir::Type> arg_types,const absl::InlinedVector<OutputTensor,4> & arg_nodes,const absl::InlinedVector<OutputTensor,4> & ret_nodes,const absl::InlinedVector<Node *,4> & control_ret_nodes)1548 Status ImporterBase::ConvertFunctionArgAndRets(
1549 mlir::FuncOp func, mlir::tf_executor::GraphOp graph_op,
1550 llvm::ArrayRef<mlir::Type> arg_types,
1551 const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
1552 const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
1553 const absl::InlinedVector<Node*, 4>& control_ret_nodes) {
1554 // Store the arg/return attributes as a list rather than uniqueuing during
1555 // construction.
1556 llvm::SmallVector<mlir::NamedAttrList, 4> arg_attrs;
1557 arg_attrs.resize(func.getNumArguments());
1558 llvm::SmallVector<mlir::NamedAttrList, 4> ret_attrs;
1559 ret_attrs.resize(func.getNumResults());
1560
1561 auto set_attributes_on_func = [&](Node* node, int64_t index, bool is_arg) {
1562 for (const auto& node_attr : node->attrs()) {
1563 const auto& key = node_attr.first;
1564 // Only import optional attributes (e.g., those starting with an
1565 // underscore).
1566 if (key.empty() || key[0] != '_') continue;
1567 // Ignore shape inference attributes as shape information is already
1568 // populated in the result type.
1569 if (IsOutputShapesAttribute(node_attr.second, key) ||
1570 IsResourceOutputShapesAttribute(node_attr.second, key))
1571 continue;
1572 TF_ASSIGN_OR_RETURN(auto converted_attr,
1573 ConvertAttributeValue(node_attr.second));
1574 std::string dialect_attribute = "tf." + key;
1575 if (is_arg) {
1576 arg_attrs[index].set(dialect_attribute, converted_attr);
1577 } else {
1578 func.setResultAttr(index, dialect_attribute, converted_attr);
1579 ret_attrs[index].set(dialect_attribute, converted_attr);
1580 }
1581 }
1582 return Status::OK();
1583 };
1584
1585 auto* bb = &func.front();
1586 llvm::SmallDenseMap<std::pair<Node*, int>, mlir::Value, 4>
1587 arg_nodes_to_values;
1588 for (int i = 0, e = arg_types.size(); i < e; ++i) {
1589 auto& arg_node = arg_nodes[i];
1590 // The lookup can't fail here: otherwise some nodes in the function haven't
1591 // be converted to mlir operations and don't have a mapping.
1592 mlir::Operation* island = node_values_.find(arg_node.node->id())->second;
1593
1594 auto bb_arg = bb->getArgument(i);
1595 mlir::Value arg_def = bb_arg;
1596
1597 if (island->getNumResults() != 2)
1598 return errors::InvalidArgument(
1599 "Only feed output tensors of single output nodes are supported");
1600
1601 // Collect mapping of OutputTensor to associated block arg.
1602 arg_nodes_to_values.try_emplace({arg_node.node, arg_node.index}, arg_def);
1603 island->getResult(0).replaceAllUsesWith(arg_def);
1604 // Erase control outputs from feed.
1605 auto control_uses = island->getResult(1).getUses();
1606 for (auto& control_use : llvm::make_early_inc_range(control_uses))
1607 control_use.getOwner()->eraseOperand(control_use.getOperandNumber());
1608
1609 if (!arg_node.node->requested_device().empty())
1610 arg_attrs[i].set("tf.device", builder_.getStringAttr(
1611 arg_node.node->requested_device()));
1612
1613 if (arg_node.node->IsArg()) {
1614 TF_RETURN_IF_ERROR(
1615 set_attributes_on_func(arg_node.node, i, /*is_arg=*/true));
1616 }
1617
1618 island->dropAllReferences();
1619 island->erase();
1620 }
1621
1622 llvm::SmallVector<mlir::Value, 8> inst_to_return;
1623 for (auto ret_and_idx : llvm::enumerate(ret_nodes)) {
1624 const auto& ret = ret_and_idx.value();
1625 auto* inst = node_values_[ret.node->id()];
1626 if (ret.node->IsRetval()) {
1627 if (!ret.node->requested_device().empty())
1628 ret_attrs[ret_and_idx.index()].set(
1629 "tf.device", builder_.getStringAttr(ret.node->requested_device()));
1630 TF_RETURN_IF_ERROR(set_attributes_on_func(ret.node, ret_and_idx.index(),
1631 /*is_arg=*/false));
1632 // Lookup the instruction inside the island
1633 auto island_op = llvm::cast<mlir::tf_executor::IslandOp>(inst);
1634 mlir::Operation* inner_op = &island_op.GetBody().front();
1635 // Remove kRetOp or kDeviceRetOp operation and return its operand.
1636 // kRetOp and kDeviceRetOp should have just one operand unless they have
1637 // control dependencies.
1638 if (inner_op->getNumOperands() != 1)
1639 return errors::Unimplemented("Return node with multiple inputs.");
1640 inst_to_return.push_back(inner_op->getOperand(0));
1641 inst->dropAllReferences();
1642 inst->erase();
1643 } else {
1644 // Lookup and use block arg if fetch is a feed.
1645 auto it = arg_nodes_to_values.find({ret.node, ret.index});
1646 if (it != arg_nodes_to_values.end())
1647 inst_to_return.push_back(it->second);
1648 else
1649 inst_to_return.push_back(inst->getResult(ret.index));
1650 }
1651 }
1652
1653 for (Node* control_ret : control_ret_nodes) {
1654 auto* inst = node_values_[control_ret->id()];
1655 inst_to_return.push_back(*std::prev(inst->result_end()));
1656 }
1657
1658 // Terminate the function by adding a Fetch operation to terminate the graph
1659 // and a return operation to return the Graph results.
1660 builder_.setInsertionPointToEnd(&graph_op.body().front());
1661 builder_.create<mlir::tf_executor::FetchOp>(graph_op.getLoc(),
1662 inst_to_return);
1663 builder_.setInsertionPointToEnd(bb);
1664 builder_.create<mlir::ReturnOp>(mlir::UnknownLoc::get(context_),
1665 graph_op.getResults());
1666
1667 func.setAllArgAttrs(
1668 llvm::to_vector<4>(llvm::map_range(arg_attrs, [&](NamedAttrList& list) {
1669 return list.getDictionary(context_);
1670 })));
1671 func.setAllResultAttrs(
1672 llvm::to_vector<4>(llvm::map_range(ret_attrs, [&](NamedAttrList& list) {
1673 return list.getDictionary(context_);
1674 })));
1675
1676 return Status::OK();
1677 }
1678
GetLocation(const Node & node)1679 mlir::Location ImporterBase::GetLocation(const Node& node) {
1680 DVLOG(1) << "Getting location for " << node.name() << " " << &node;
1681 // TODO(b/142400497): What is the semantic contract for locations?
1682 const auto& debug_info = debug_info_.traces();
1683
1684 // Create a location for node `name` in function `function_name`.
1685 auto create_location = [&](llvm::StringRef name,
1686 llvm::StringRef function_name) -> mlir::Location {
1687 // Use the catenation of function and node names as the lookup key into the
1688 // debug info. This matches the way that the key is formed on the python
1689 // side.
1690 //
1691 // We also use this as the name for the NameLoc for ops in function, since
1692 // otherwise our names could collide across functions.
1693 // For ops in the main graph, we omit the "@function_name" (which, would be
1694 // just "@" since function_name would be empty) because some code seems to
1695 // depend on the name being this way for correctness.
1696 std::string debug_info_key = (name + "@" + function_name).str();
1697 std::string name_for_name_loc =
1698 function_name.empty() ? name.str() : (name + "@" + function_name).str();
1699 auto name_loc_id = mlir::Identifier::get(name_for_name_loc, context_);
1700
1701 llvm::SmallVector<mlir::Location, 4> locations;
1702 // Prefer stack traces if available, fallback to debug info if not, and then
1703 // finally to just name.
1704 if (auto stack_trace = node.GetStackTrace()) {
1705 DVLOG(1) << "Stack available for " << node.name();
1706 absl::Span<const StackFrame> frames = stack_trace->ToFrames();
1707 locations.reserve(frames.size());
1708 for (const StackFrame& frame : llvm::reverse(frames)) {
1709 auto file_name = mlir::Identifier::get(frame.file_name, context_);
1710 // Use col 1 as there is no column info in StackTrace.
1711 auto file_line_loc =
1712 mlir::FileLineColLoc::get(file_name, frame.line_number, 1);
1713 locations.push_back(file_line_loc);
1714 }
1715 } else {
1716 DVLOG(1) << "No stack trace for " << node.name();
1717 const auto location_it = debug_info.find(debug_info_key);
1718 if (location_it != debug_info.end()) {
1719 DVLOG(1) << "Available serialized debug info for " << node.name();
1720 // Convert the stack trace to a chain of mlir::CallSiteLocs.
1721 const auto& trace = location_it->second;
1722 locations.reserve(trace.file_line_cols_size());
1723 for (const auto& location : trace.file_line_cols()) {
1724 const auto& file = debug_info_.files(location.file_index());
1725 auto file_name = mlir::Identifier::get(file, context_);
1726 auto file_line_loc = mlir::FileLineColLoc::get(
1727 file_name, location.line(), location.col());
1728 locations.push_back(file_line_loc);
1729 }
1730 }
1731 }
1732
1733 // If there are no locations in the stack trace, fall back to just a
1734 // NameLoc with no child.
1735 if (locations.empty()) return mlir::NameLoc::get(name_loc_id);
1736
1737 // Use the front FileLineColLoc to generate a NameLoc.
1738 mlir::Location node_name_loc =
1739 mlir::NameLoc::get(name_loc_id, locations.front());
1740
1741 // If there are more locations then generate a stack trace, otherwise just
1742 // return the name loc.
1743 auto callsite_locs = llvm::makeArrayRef(locations).drop_front();
1744 return callsite_locs.empty()
1745 ? node_name_loc
1746 : mlir::CallSiteLoc::get(node_name_loc, callsite_locs);
1747 };
1748
1749 // For NextIteration nodes, location is used to pair source and sink nodes.
1750 // Hence, we use node name as location to keep it unique.
1751 // TODO(prakalps): In future the plan is to use tokens to pair source/sink
1752 // nodes. Then NextIteration nodes would not need to be handled separately.
1753 if (node.type_string() == "NextIteration")
1754 return create_location(node.name(), function_name_for_debug_info_);
1755
1756 if (node.GetStackTrace())
1757 return create_location(node.name(), function_name_for_debug_info_);
1758
1759 const auto& node_def = node.def();
1760 auto original_nodes =
1761 node_def.experimental_debug_info().original_node_names();
1762 auto original_funcs =
1763 node_def.experimental_debug_info().original_func_names();
1764
1765 if (original_nodes.empty()) {
1766 return create_location(node.name(), function_name_for_debug_info_);
1767 } else {
1768 // If the original nodes are defined, then we use them to get a list of
1769 // call sites, and then fuse them to a single fused location, with the name
1770 // of the node_def.
1771 llvm::SmallVector<mlir::Location, 4> node_locations;
1772 node_locations.reserve(original_nodes.size() + 1);
1773
1774 // store the names in the experimental_debug_info
1775 for (int i = 0, e = original_nodes.size(); i != e; ++i) {
1776 auto node_name = original_nodes[i];
1777 auto func_name = (i < original_funcs.size()) ? original_funcs[i] : "";
1778 node_locations.push_back(create_location(node_name, func_name));
1779 }
1780 // store the name of the node_def
1781 node_locations.push_back(
1782 create_location(node.name(), function_name_for_debug_info_));
1783 return mlir::FusedLoc::get(context_, node_locations);
1784 }
1785 }
1786
EmitErrorWithLocationStr(const Node & node,const Status & error_status)1787 Status ImporterBase::EmitErrorWithLocationStr(const Node& node,
1788 const Status& error_status) {
1789 const mlir::Location location = GetLocation(node);
1790 mlir::emitError(location);
1791 return error_handler_.Combine(error_status);
1792 }
1793
CreateOperation(const Node & node,llvm::StringRef node_type_name,const mlir::OperationState & result,const llvm::SmallVectorImpl<mlir::Value> & control_operands)1794 mlir::Operation* ImporterBase::CreateOperation(
1795 const Node& node, llvm::StringRef node_type_name,
1796 const mlir::OperationState& result,
1797 const llvm::SmallVectorImpl<mlir::Value>& control_operands) {
1798 // For the tf.executor specific operations (not wrapped in an island), we
1799 // have an extra returned value for the control result, and we concatenate
1800 // control and non-control operands.
1801 mlir::SmallVector<mlir::Type, 4> types(result.types);
1802 types.push_back(mlir::tf_executor::ControlType::get(builder_.getContext()));
1803 mlir::SmallVector<mlir::Value, 4> operands(result.operands);
1804 operands.append(control_operands.begin(), control_operands.end());
1805
1806 auto loc = result.location;
1807 // Dispatch based on the name and create the appropriate operation.
1808 if (node.IsSwitch()) {
1809 // Switch and _SwitchN both are in switch class, differentiate based on
1810 // op name.
1811 if (node.op_def().name() == "_SwitchN") {
1812 return builder_.create<mlir::tf_executor::SwitchNOp>(loc, types, operands,
1813 result.attributes);
1814 }
1815 return builder_.create<mlir::tf_executor::SwitchOp>(loc, types, operands,
1816 result.attributes);
1817 }
1818 if (node.IsMerge()) {
1819 return builder_.create<mlir::tf_executor::MergeOp>(loc, types, operands,
1820 result.attributes);
1821 }
1822 if (node.IsNextIteration()) {
1823 // NextIteration is a bit special, we create a pair of operations that are
1824 // linked together through a token returned by the source.
1825 // We make use of a separate builder to insert the source at the top of
1826 // the block.
1827 mlir::OpBuilder builder_at_begin(builder_.getBlock(),
1828 builder_.getBlock()->begin());
1829 auto source_op =
1830 builder_at_begin.create<mlir::tf_executor::NextIterationSourceOp>(
1831 loc, operands[0].getType(), result.attributes);
1832 return builder_.create<mlir::tf_executor::NextIterationSinkOp>(
1833 loc, source_op.token(), operands, result.attributes);
1834 }
1835 if (node.IsLoopCond()) {
1836 return builder_.create<mlir::tf_executor::LoopCondOp>(loc, types, operands,
1837 result.attributes);
1838 }
1839 if (node.IsEnter()) {
1840 return builder_.create<mlir::tf_executor::EnterOp>(loc, types, operands,
1841 result.attributes);
1842 }
1843 if (node.IsExit()) {
1844 return builder_.create<mlir::tf_executor::ExitOp>(loc, types, operands,
1845 result.attributes);
1846 }
1847 if (node.IsControlTrigger()) {
1848 return builder_.create<mlir::tf_executor::ControlTriggerOp>(
1849 loc, operands, result.attributes);
1850 }
1851 // Regular TensorFlow operation are wrapped in a tf_executor.island.
1852 auto island = builder_.create<mlir::tf_executor::IslandOp>(
1853 result.location, types, control_operands,
1854 mlir::ArrayRef<mlir::NamedAttribute>{});
1855 island.body().push_back(new mlir::Block);
1856 mlir::OpBuilder island_builder =
1857 mlir::OpBuilder::atBlockEnd(&island.GetBody());
1858
1859 // Create the operation inside the island now.
1860 mlir::Operation* inner_op = island_builder.createOperation(result);
1861
1862 // Sets operand_segment_sizes or result_segment_sizes attribute to the op.
1863 const auto set_segment_sizes_attr =
1864 [&](const NameRangeMap& arg_ranges,
1865 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
1866 llvm::StringRef attr_name) {
1867 std::vector<mlir::Attribute> values;
1868 values.reserve(args.size());
1869 for (const auto& arg : args) {
1870 auto range = arg_ranges.at(arg.name());
1871 values.push_back(
1872 island_builder.getI32IntegerAttr(range.second - range.first));
1873 }
1874 auto attr_type =
1875 mlir::VectorType::get(args.size(), builder_.getIntegerType(32));
1876 auto attr_value = mlir::DenseElementsAttr::get(attr_type, values);
1877 inner_op->setAttr(attr_name, attr_value);
1878 };
1879
1880 if (inner_op->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>() ||
1881 inner_op->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
1882 // The op has multiple variadic operands or results.
1883 // Calculate operand and result segment sizes using the OpDef.
1884 NameRangeMap input_ranges, output_ranges;
1885 // This will fail only if the OpDef is syntactically invalid.
1886 // TODO(jpienaar): Convert this CHECK into a properly propagated error.
1887 TF_CHECK_OK(
1888 NameRangesForNode(node, node.op_def(), &input_ranges, &output_ranges));
1889 if (inner_op->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>()) {
1890 // Add derived "operand_segment_sizes" attr to the created operation.
1891 // TODO(b/146937733): Don't use <void> here.
1892 set_segment_sizes_attr(input_ranges, node.op_def().input_arg(),
1893 mlir::OpTrait::AttrSizedOperandSegments<
1894 void>::getOperandSegmentSizeAttr());
1895 }
1896
1897 if (inner_op->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
1898 // Add derived "result_segment_sizes" attr to the created operation.
1899 // TODO(b/146937733): Don't use <void> here.
1900 set_segment_sizes_attr(output_ranges, node.op_def().output_arg(),
1901 mlir::OpTrait::AttrSizedResultSegments<
1902 void>::getResultSegmentSizeAttr());
1903 }
1904 }
1905
1906 mlir::OperationName name = inner_op->getName();
1907 if (!name.getAbstractOperation() &&
1908 // Skip unmodelled ops that are handled differently.
1909 (node_type_name != "_Arg" && node_type_name != "_Retval") &&
1910 // Skip if warning already reported.
1911 (GetUnmodelledOpTypes().insert(name.getStringRef()).second)) {
1912 if (node.op_def().is_stateful()) {
1913 LOG(INFO) << "[potentially conservative] Op type `" << node.type_string()
1914 << "` is stateful but effects not modelled";
1915 } else {
1916 // See if any resource type is used.
1917 bool resource = false;
1918 std::function<bool(mlir::Type)> record_resource;
1919 record_resource = [&](mlir::Type type) {
1920 if (resource) return true;
1921 if (type.isa<mlir::TF::ResourceType>()) {
1922 resource = true;
1923 return true;
1924 }
1925 if (auto with_subtype =
1926 type.dyn_cast<mlir::SubElementTypeInterface>()) {
1927 with_subtype.walkSubTypes([&](mlir::Type t) { record_resource(t); });
1928 }
1929 return resource;
1930 };
1931
1932 for (mlir::Type t : inner_op->getResultTypes())
1933 if (record_resource(t)) break;
1934 for (mlir::Type t : inner_op->getOperandTypes())
1935 if (record_resource(t)) break;
1936 if (resource)
1937 LOG(INFO) << "[potentially conservative] Op type `"
1938 << node.type_string()
1939 << "` has resource operands/results but effects not modelled";
1940 }
1941 }
1942
1943 // Add the terminator for the island
1944 island_builder.create<mlir::tf_executor::YieldOp>(result.location,
1945 inner_op->getResults());
1946 return island.getOperation();
1947 }
1948
ConvertNode(const Node & node)1949 Status ImporterBase::ConvertNode(const Node& node) {
1950 if (!node.IsOp()) {
1951 // Don't import the pseudo-nodes _SOURCE or _SINK. These are added by
1952 // Graph and don't exist in GraphDef.
1953 return Status::OK();
1954 }
1955
1956 // If it is a custom OP, its definition should be found in the library. We
1957 // create the MLIR function and insert it to the module if it doesn't exist.
1958 std::string node_type_name = node.type_string();
1959 const auto* func_def = graph_flib_.Find(node_type_name);
1960 bool convert_to_legacy_call = false;
1961 if (func_def) {
1962 TF_RETURN_IF_ERROR(ConvertLibFunction(node_type_name));
1963 node_type_name = (*tf_name_to_mlir_name_)[node_type_name];
1964 convert_to_legacy_call = true;
1965 }
1966
1967 auto get_full_op_name = [&](const std::string& op_name) {
1968 const char* kTfPrefix = "tf.";
1969 return kTfPrefix + op_name;
1970 };
1971
1972 std::string op_name = get_full_op_name(node_type_name);
1973 if (back_edge_node_output_.contains(&node)) {
1974 op_name = op_name + ".sink";
1975 }
1976
1977 mlir::OperationState result(GetLocation(node), op_name);
1978 for (int i = 0; i < node.num_outputs(); ++i) {
1979 // The backedge has been removed, so we shouldn't count the corresponding
1980 // output from the src node when converting to an operation.
1981 if (back_edge_node_output_.contains(&node) &&
1982 back_edge_node_output_[&node] == i) {
1983 continue;
1984 }
1985 TF_ASSIGN_OR_RETURN(auto type, InferOutputType(node, i, builder_));
1986 result.types.push_back(type);
1987 }
1988
1989 // Surprisingly input edges can be nondeterministically ordered. This
1990 // particularly seems to be the case for the control edges between _SOURCE
1991 // and _SINK that the Graph constructor inserts. Copy the input edges and
1992 // sort the edges, but only the control edges, not data edges!
1993 // TODO(jmolloy): We should probably just ignore _SOURCE and _SINK nodes.
1994 // They'll break roundtripping anyway unless we strip them when converting
1995 // back to graphdef.
1996 absl::InlinedVector<const Edge*, 8> in_edges(node.in_edges().size());
1997 absl::c_copy(node.in_edges(), in_edges.begin());
1998 absl::c_stable_sort(in_edges, [](const Edge* e1, const Edge* e2) {
1999 if (e1->IsControlEdge() && !e2->IsControlEdge()) return false;
2000 if (!e1->IsControlEdge() && e2->IsControlEdge()) return true;
2001 if (e1->IsControlEdge() && e2->IsControlEdge())
2002 return e1->src()->id() < e2->src()->id();
2003 return e1->dst_input() < e2->dst_input();
2004 });
2005
2006 result.operands.reserve(in_edges.size());
2007
2008 // Collect the control operands separately, they will be held by the island.
2009 mlir::SmallVector<mlir::Value, 8> control_operands;
2010
2011 for (const auto* input_edge : in_edges) {
2012 const Node& input_node = *input_edge->src();
2013 if (input_node.IsSource()) {
2014 if (in_edges.size() != 1) {
2015 return errors::FailedPrecondition(
2016 "The node has other inputs besides the _Source node");
2017 }
2018 // We don't import the _SOURCE node.
2019 continue;
2020 }
2021 if (input_node.IsArg() && input_edge->IsControlEdge()) {
2022 // Currently we have not reached consensus as to what TF function
2023 // semantics are (b/133509504). Here we assume that all arguments to a
2024 // function should be available before we start execution of any internal
2025 // node. This makes the control dependencies between function arguments
2026 // and internal nodes redundant, and so we do not import them. The TF
2027 // inliner however assumes no such dependency between function args and
2028 // internal nodes exists, unless explicitly stated. Since we drop control
2029 // dependencies here, it leads to loss of information. If the function is
2030 // inlined later, the inliner would not know of these explicit control
2031 // dependencies present in the original graph.
2032 continue;
2033 }
2034 if (node_values_.find(input_node.id()) == node_values_.end())
2035 return errors::FailedPrecondition(
2036 "Graph not traversed in reverse post order; use seen before def!");
2037 mlir::Operation* inst = node_values_[input_node.id()];
2038 if (input_edge->IsControlEdge())
2039 control_operands.push_back(inst->getResult(inst->getNumResults() - 1));
2040 else
2041 result.operands.push_back(inst->getResult(input_edge->src_output()));
2042 }
2043
2044 using FuncPairType = std::pair<const std::string*, const AttrValue*>;
2045 std::vector<FuncPairType> funcs;
2046 result.attributes.reserve(node.attrs().size() + 2);
2047 auto abstract_op = result.name.getAbstractOperation();
2048 auto derived_op =
2049 abstract_op
2050 ? abstract_op->getInterface<mlir::DerivedAttributeOpInterface>()
2051 : nullptr;
2052 for (const auto& name_and_value : node.attrs()) {
2053 const auto& attr_name = name_and_value.first;
2054 // Skip adding derived attributes to the generated op.
2055 if (derived_op && derived_op->isDerivedAttribute(attr_name)) continue;
2056 const AttrValue& attr_value = name_and_value.second;
2057
2058 // Remove _output_shapes attribute that will be added by the exporter.
2059 if (IsOutputShapesAttribute(attr_value, attr_name)) continue;
2060
2061 if (attr_value.value_case() == AttrValue::kFunc) {
2062 // Attribute iteration order is not defined for protocol buffer Map.
2063 // Process function attributes separately in the lexicographical order to
2064 // have deterministic order of functions in the constructed IR.
2065 funcs.emplace_back(&attr_name, &attr_value);
2066 } else {
2067 TF_ASSIGN_OR_RETURN(auto attr, ConvertAttributeValue(attr_value));
2068 result.attributes.push_back(builder_.getNamedAttr(attr_name, attr));
2069 }
2070 }
2071
2072 auto comparator = [](const FuncPairType& a, const FuncPairType& b) {
2073 return *a.first < *b.first;
2074 };
2075 std::sort(funcs.begin(), funcs.end(), comparator);
2076 for (const auto& func : funcs) {
2077 TF_RETURN_IF_ERROR(ConvertFunctionCallAttribute(*func.first, *func.second,
2078 &result.attributes));
2079 }
2080
2081 const auto& node_def = node.def();
2082 // NodeDef can contain partial TF device names. In such cases, canonicalize
2083 // it. Note that in current TF, placer will place full device name to each
2084 // node.
2085 DeviceNameUtils::ParsedName parsed_name;
2086 if (!DeviceNameUtils::ParseFullName(node_def.device(), &parsed_name)) {
2087 return errors::InvalidArgument(
2088 "Op ", op_name, " has invalid device name: ", node_def.device());
2089 }
2090 // Keep the parsed name untouched if the device name is empty.
2091 if (!node_def.device().empty()) {
2092 if (!parsed_name.has_type) {
2093 parsed_name.type = "CPU";
2094 parsed_name.has_type = true;
2095 }
2096 if (!parsed_name.has_id) {
2097 parsed_name.id = 0;
2098 parsed_name.has_id = true;
2099 }
2100 }
2101 result.attributes.push_back(builder_.getNamedAttr(
2102 "device", builder_.getStringAttr(
2103 DeviceNameUtils::ParsedNameToString(parsed_name))));
2104
2105 // Map user function calls to LegacyCall ops and add the user function name
2106 // as an attribute.
2107 if (convert_to_legacy_call) {
2108 result.name = mlir::OperationName(get_full_op_name("LegacyCall"), context_);
2109 mlir::SymbolRefAttr val = builder_.getSymbolRefAttr(node_type_name);
2110 result.addAttribute("f", val);
2111
2112 if (!result.attributes.get("_disable_call_shape_inference")) {
2113 result.addAttribute("_disable_call_shape_inference",
2114 builder_.getBoolAttr(false));
2115 }
2116 }
2117
2118 auto composite_control_flow_op = [&](const std::string& name) {
2119 result.name = mlir::OperationName(get_full_op_name(name), context_);
2120 bool stateless = absl::StartsWith(node_type_name, "Stateless");
2121 mlir::BoolAttr val = builder_.getBoolAttr(stateless);
2122 result.attributes.push_back(builder_.getNamedAttr("is_stateless", val));
2123 };
2124
2125 // Map Case/If/While and StatelessCase/If/While op in TensorFlow to the common
2126 // Case/If/While op in MLIR and add the differentiating attribute.
2127 if (node.IsCaseNode()) composite_control_flow_op("Case");
2128 if (node.IsIfNode()) composite_control_flow_op("If");
2129 if (node.IsWhileNode()) {
2130 composite_control_flow_op("While");
2131 auto* output_shapes = node.attrs().Find("output_shapes");
2132 if (output_shapes && !output_shapes->list().shape().empty())
2133 result.attributes.push_back(
2134 builder_.getNamedAttr("shape_invariant", builder_.getUnitAttr()));
2135 }
2136
2137 // Register the mapping between the TF node and the newly created operation.
2138 node_values_[node.id()] =
2139 CreateOperation(node, node_type_name, result, control_operands);
2140 return Status::OK();
2141 }
2142
2143 // Add the backedges to the CFG. Given a backedge, we replace the original
2144 // source and destination operations by two new operations. Most of the
2145 // fields of the replacements are copied from the original operations.
2146 // However,
2147 // - for the src operation, one output is inserted to the front of the output
2148 // list. The type of the output is set to the type of the non-control result
2149 // of the dst operation, and
2150 // - for the dst operation, one operand is inserted to the front of the
2151 // operand list. This operand is using the first result of the src
2152 // operation.
2153 // TODO(fengliuai): Preserve the order of the results and operands if
2154 // necessary.
AddBackedges()2155 Status ImporterBase::AddBackedges() {
2156 for (auto it : back_edge_dst_inputs_) {
2157 BackEdge& edge = it.second;
2158 if (!edge.src->IsNextIteration() || !edge.dst->IsMerge()) {
2159 return errors::FailedPrecondition(
2160 "Invalid backedge; should be from NextIteration to Merge!");
2161 }
2162 auto* sink = node_values_[edge.src->id()];
2163 auto* dst = node_values_[edge.dst->id()];
2164 TF_RETURN_IF_ERROR(AddBackedge(sink, dst, edge.dst_input));
2165 }
2166 return Status::OK();
2167 }
2168
AddBackedge(mlir::Operation * sink,mlir::Operation * dst,int dst_input)2169 Status ImporterBase::AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
2170 int dst_input) {
2171 // Get the NextIteration.Source operation from the token operand of the sink.
2172 mlir::Operation* source = sink->getOperand(0).getDefiningOp();
2173
2174 // Adds the "source" to the operands of the dst by creating a new dst
2175 // operation.
2176 mlir::OperationState state(dst->getLoc(), dst->getName());
2177 auto num_operands = dst->getNumOperands();
2178 state.operands.reserve(num_operands + 1);
2179 for (int input = 0, e = num_operands + 1; input != e; ++input) {
2180 if (input < dst_input) {
2181 state.operands.push_back(dst->getOperand(input));
2182 } else if (input == dst_input) {
2183 state.operands.push_back(source->getResult(0));
2184 } else {
2185 state.operands.push_back(dst->getOperand(input - 1));
2186 }
2187 }
2188 state.attributes.assign(dst->getAttrs().begin(), dst->getAttrs().end());
2189 state.types.assign(dst->getResultTypes().begin(),
2190 dst->getResultTypes().end());
2191 builder_.setInsertionPoint(dst);
2192 auto* new_dst = builder_.createOperation(state);
2193
2194 // Replaces the output uses of the old operation by the corresponding
2195 // result of the new operation, and deletes the old operation.
2196 for (unsigned i = 0, e = dst->getNumResults(); i != e; ++i) {
2197 auto new_output = new_dst->getResult(i);
2198 dst->getResult(i).replaceAllUsesWith(new_output);
2199 }
2200 dst->dropAllReferences();
2201 dst->erase();
2202 return Status::OK();
2203 }
2204
InferLibFunctionType(const FunctionBody & fbody)2205 StatusOr<mlir::FunctionType> ImporterBase::InferLibFunctionType(
2206 const FunctionBody& fbody) {
2207 mlir::Builder builder(context_);
2208
2209 // The FunctionBody contains a graph with a single-output _Arg node for each
2210 // function argument and a single-input _Retval node for each function return
2211 // value.
2212 //
2213 // We already populated the ShapeRefiner with all the information about the
2214 // shapes of these graph edges, so we just query it to build the corresponding
2215 // MLIR function type signature.
2216
2217 llvm::SmallVector<mlir::Type, 4> arg_types;
2218 if (specs_.inputs.empty()) {
2219 arg_types.reserve(fbody.arg_types.size());
2220 for (auto arg : fbody.arg_nodes) {
2221 // Find node in the graph using the node id instead of using `arg`
2222 // directly because the graph has been cloned.
2223 auto* node = graph_->FindNodeId(arg->id());
2224 TF_ASSIGN_OR_RETURN(auto type,
2225 InferOutputType(*node, /*idx=*/0, builder));
2226 arg_types.push_back(type);
2227 }
2228 } else {
2229 arg_types.reserve(fbody.arg_types.size());
2230 for (const auto& it : llvm::enumerate(specs_.inputs)) {
2231 mlir::Type element_type;
2232 const auto& node_info = it.value().second;
2233 DataType dtype = node_info.imported_dtype;
2234 // Uses the existing output type of the arg node if the data type of the
2235 // the node isn't specified through the import configuration.
2236 if (dtype == DT_INVALID) {
2237 auto arg = fbody.arg_nodes[it.index()];
2238 auto* node = graph_->FindNodeId(arg->id());
2239 dtype = node->output_type(0);
2240 if (dtype == DT_INVALID) {
2241 return errors::InvalidArgument("Input ", it.index(),
2242 "has invalid data type");
2243 }
2244 }
2245 TF_RETURN_IF_ERROR(
2246 ::tensorflow::ConvertDataType(dtype, builder, &element_type));
2247 if (node_info.shape.unknown_rank()) {
2248 arg_types.push_back(mlir::UnrankedTensorType::get(element_type));
2249 } else {
2250 llvm::SmallVector<int64_t, 4> shape;
2251 TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape));
2252 arg_types.push_back(mlir::RankedTensorType::get(shape, element_type));
2253 }
2254 }
2255 }
2256
2257 llvm::SmallVector<mlir::Type, 4> ret_types;
2258 ret_types.reserve(fbody.ret_types.size());
2259 for (auto ret : fbody.ret_nodes) {
2260 // Find node in the graph using the node id instead of using `ret` directly
2261 // because the graph has been cloned.
2262 auto* node = graph_->FindNodeId(ret->id());
2263 TF_ASSIGN_OR_RETURN(auto type, InferInputType(*node, /*idx=*/0, builder));
2264 ret_types.push_back(type);
2265 }
2266
2267 return builder.getFunctionType(arg_types, ret_types);
2268 }
2269
2270 // Stateful helper class to import a TensorFlow model expressed in GraphDef into
2271 // an MLIR Module.
2272 //
2273 // The nodes defined in the graph are converted to a function called
2274 // 'func_name'. All library function definitions are converted to MLIR functions
2275 // in the module.
2276 class GraphDefImporter : public ImporterBase {
2277 public:
2278 // Main entry point: converts the given graph to an MLIR Module.
2279 static StatusOr<mlir::OwningModuleRef> Convert(
2280 mlir::MLIRContext* context, const Graph& graph,
2281 const GraphDebugInfo& debug_info,
2282 const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
2283 std::unordered_map<std::string, std::string>& tf_name_to_mlir_name);
2284
2285 private:
GraphDefImporter(const FunctionLibraryDefinition & flib,const GraphDebugInfo & debug_info,const GraphImportConfig & specs,mlir::ModuleOp module,std::unordered_map<std::string,std::string> * tf_name_to_mlir_name,NameUniquifier * function_name_uniquifier)2286 explicit GraphDefImporter(
2287 const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
2288 const GraphImportConfig& specs, mlir::ModuleOp module,
2289 std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
2290 NameUniquifier* function_name_uniquifier)
2291 : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name,
2292 function_name_uniquifier) {}
2293
2294 // Returns the function signature of the main function of converted MLIR
2295 // module, the input nodes and output nodes. The type and shape information
2296 // for the function arguments are read from `specs`, but the type and shape
2297 // information for the function returns are inferred by the shape refiner in
2298 // ImporterBase.
2299 StatusOr<mlir::FunctionType> InferMainFunctionType(
2300 const GraphImportConfig& specs, mlir::MLIRContext* context,
2301 absl::InlinedVector<OutputTensor, 4>* arg_nodes,
2302 absl::InlinedVector<OutputTensor, 4>* ret_nodes);
2303
2304 // Returns the function signature of the main function, alongside input and
2305 // output nodes, for function graphs. Arguments and return values are
2306 // determined by node op type. Type and shape information of the function are
2307 // inferred by the shape refiner in ImporterBase.
2308 StatusOr<mlir::FunctionType> GetArgsRetsAndTypesFromFunctionGraph(
2309 mlir::MLIRContext* context,
2310 absl::InlinedVector<OutputTensor, 4>* arg_nodes,
2311 absl::InlinedVector<OutputTensor, 4>* ret_nodes);
2312
2313 // Finds the graph's target nodes/function's control ret nodes based on
2314 // supplied node names in `control_outputs`. If `control_outputs` are not
2315 // unique or a control ret node is missing, an error will be returned.
2316 Status GetControlRetsFromGraph(
2317 llvm::ArrayRef<std::string> control_outputs,
2318 absl::InlinedVector<Node*, 4>* control_ret_nodes);
2319 };
2320
Convert(mlir::MLIRContext * context,const Graph & graph,const GraphDebugInfo & debug_info,const FunctionLibraryDefinition & flib_def,const GraphImportConfig & specs,std::unordered_map<std::string,std::string> & tf_name_to_mlir_name)2321 StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
2322 mlir::MLIRContext* context, const Graph& graph,
2323 const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
2324 const GraphImportConfig& specs,
2325 std::unordered_map<std::string, std::string>& tf_name_to_mlir_name) {
2326 LoadImporterDialects(*context);
2327 mlir::OwningModuleRef module =
2328 mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
2329 NameUniquifier function_name_uniquifier(flib_def);
2330
2331 // importer.PrepareConvert below will attemp to clone the original `graph`
2332 // via conversion to the graph def first. Convert graph to graph_def here
2333 // first and avoid extra copies later.
2334 auto graph_def = std::make_unique<GraphDef>();
2335 graph.ToGraphDef(graph_def.get());
2336
2337 static std::atomic<uint32> counter(0);
2338 uint32 current_file_prefix = counter++;
2339 const auto* graph_crash_handle = crash_analysis::ReportProtoDataOnCrash(
2340 absl::StrCat(current_file_prefix, "_mlir_import_graph.pbtxt"),
2341 *graph_def);
2342 auto reachable_flib = flib_def.ReachableDefinitions(*graph_def);
2343 const auto* flib_crash_handle = crash_analysis::ReportProtoDataOnCrash(
2344 absl::StrCat(current_file_prefix, "_mlir_import_flib.pbtxt"),
2345 reachable_flib.ToProto());
2346
2347 auto scope_exit = llvm::make_scope_exit([&]() {
2348 crash_analysis::RemoveReportData(graph_crash_handle);
2349 crash_analysis::RemoveReportData(flib_crash_handle);
2350 });
2351
2352 VLOG(1) << "Importing: "
2353 << ::tensorflow::DumpGraphToFile("tf_mlir_importer_base", graph,
2354 &flib_def);
2355
2356 GraphDefImporter importer(flib_def, debug_info, specs, module.get(),
2357 &tf_name_to_mlir_name, &function_name_uniquifier);
2358
2359 TF_RETURN_IF_ERROR(importer.PrepareConvert(graph, std::move(graph_def)));
2360
2361 mlir::FunctionType func_type;
2362 absl::InlinedVector<OutputTensor, 4> arg_nodes;
2363 absl::InlinedVector<OutputTensor, 4> ret_nodes;
2364 absl::InlinedVector<Node*, 4> control_ret_nodes;
2365 llvm::SmallVector<mlir::NamedAttribute, 1> attrs;
2366 if (specs.graph_as_function) {
2367 if (specs.prune_unused_nodes || !specs.inputs.empty() ||
2368 !specs.outputs.empty())
2369 return errors::InvalidArgument(
2370 "Pruning of graph is currently unsupported when the main graph is "
2371 "converted to a function.");
2372
2373 TF_ASSIGN_OR_RETURN(func_type,
2374 importer.GetArgsRetsAndTypesFromFunctionGraph(
2375 context, &arg_nodes, &ret_nodes));
2376
2377 TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
2378 &control_ret_nodes));
2379
2380 mlir::Builder b(context);
2381 std::string s;
2382 llvm::raw_string_ostream ss(s);
2383 auto node_name = [&](const OutputTensor& tensor) {
2384 ss << tensor.node->name();
2385 };
2386 llvm::interleave(arg_nodes, ss, node_name, ",");
2387 auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
2388 s.clear();
2389 llvm::interleave(ret_nodes, ss, node_name, ",");
2390 auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
2391 s.clear();
2392 llvm::interleave(specs.control_outputs, ss, ",");
2393 auto control_outputs =
2394 b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));
2395
2396 // Under `graph_as_function` mode, `tf.entry_function` is always set as it
2397 // is assumed feed, fetch, and target nodes are set correctly.
2398 attrs.push_back(b.getNamedAttr(
2399 "tf.entry_function",
2400 b.getDictionaryAttr({inputs, outputs, control_outputs})));
2401 } else {
2402 // Collects the argument and return nodes by looking up the node names
2403 // specified by the user.
2404 TF_ASSIGN_OR_RETURN(func_type, importer.InferMainFunctionType(
2405 specs, context, &arg_nodes, &ret_nodes));
2406
2407 TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
2408 &control_ret_nodes));
2409
2410 // TODO(prakalps): Refactor to keep tf.entry_function attribute encoding and
2411 // decoding in a centralized place.
2412 // Record the input and output mapping.
2413 if (!specs.inputs.empty() || !specs.outputs.empty() ||
2414 !specs.control_outputs.empty()) {
2415 mlir::Builder b(context);
2416 std::string s;
2417 llvm::raw_string_ostream ss(s);
2418 llvm::interleave(
2419 specs.inputs, ss,
2420 [&](const std::pair<std::string, ArrayInfo>& v) { ss << v.first; },
2421 ",");
2422 auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
2423 s.clear();
2424 llvm::interleave(specs.outputs, ss, ",");
2425 auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
2426 s.clear();
2427 llvm::interleave(specs.control_outputs, ss, ",");
2428 auto control_outputs =
2429 b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));
2430
2431 attrs.push_back(b.getNamedAttr(
2432 "tf.entry_function",
2433 b.getDictionaryAttr({inputs, outputs, control_outputs})));
2434 }
2435 }
2436
2437 // Record version info.
2438 PopulateTfVersions(module.get(), graph.versions());
2439
2440 const auto& graph_func_name = specs.graph_func_name.empty()
2441 ? kImportModelDefaultGraphFuncName
2442 : specs.graph_func_name;
2443 TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(graph_func_name, func_type,
2444 arg_nodes, ret_nodes,
2445 control_ret_nodes, attrs));
2446 TF_RETURN_IF_ERROR(importer.ImporterBase::ConvertDeferredFunctions());
2447
2448 // Mark main function public, others private.
2449 for (auto function : module.get().getOps<mlir::FuncOp>()) {
2450 auto visibility = function.getName() == graph_func_name
2451 ? mlir::FuncOp::Visibility::Public
2452 : mlir::FuncOp::Visibility::Private;
2453 function.setVisibility(visibility);
2454 }
2455 return module;
2456 }
2457
InferMainFunctionType(const GraphImportConfig & specs,mlir::MLIRContext * context,absl::InlinedVector<OutputTensor,4> * arg_nodes,absl::InlinedVector<OutputTensor,4> * ret_nodes)2458 StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
2459 const GraphImportConfig& specs, mlir::MLIRContext* context,
2460 absl::InlinedVector<OutputTensor, 4>* arg_nodes,
2461 absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
2462 // Find all the input nodes and output nodes.
2463 // Feeds have been remapped to single output nodes (Placeholder), so an exact
2464 // name match is sufficient.
2465 absl::flat_hash_map<absl::string_view, int> inputs;
2466 for (auto input_and_idx : llvm::enumerate(specs.inputs)) {
2467 TensorId tensor = ParseTensorName(input_and_idx.value().first);
2468 auto remapped_it = remapped_feeds_.find(tensor);
2469 if (remapped_it != remapped_feeds_.end()) {
2470 inputs.insert({remapped_it->second, input_and_idx.index()});
2471 } else {
2472 inputs.insert({tensor.node(), input_and_idx.index()});
2473 }
2474 }
2475
2476 absl::flat_hash_set<absl::string_view> output_node_names;
2477 std::vector<TensorId> outputs;
2478 output_node_names.reserve(specs.outputs.size());
2479 for (const auto& output : specs.outputs) {
2480 TensorId tensor = ParseTensorName(output);
2481 auto remapped_it = remapped_feeds_.find(tensor);
2482 if (remapped_it != remapped_feeds_.end()) {
2483 output_node_names.insert(remapped_it->second);
2484 outputs.push_back({remapped_it->second, 0});
2485 } else {
2486 output_node_names.insert(tensor.node());
2487 outputs.push_back(tensor);
2488 }
2489 }
2490
2491 if (!inputs.empty() || !outputs.empty()) {
2492 arg_nodes->resize(inputs.size());
2493 ret_nodes->resize(outputs.size());
2494
2495 for (Node* n : GetOrderedNodes()) {
2496 // Handle inputs/arguments.
2497 auto input_it = inputs.find(n->name());
2498 if (input_it != inputs.end()) {
2499 (*arg_nodes)[input_it->second] = {n, 0};
2500 }
2501
2502 // Handle outputs/returns.
2503 if (output_node_names.contains(n->name())) {
2504 for (int i = 0, e = outputs.size(); i != e; ++i) {
2505 TensorId tensor = outputs[i];
2506 if (n->name() != tensor.node()) continue;
2507 (*ret_nodes)[i] = {n, tensor.index()};
2508 }
2509 }
2510 }
2511 }
2512
2513 // Starts to construct the function type.
2514 mlir::Builder builder(context);
2515 llvm::SmallVector<mlir::Type, 4> arg_types;
2516 arg_types.reserve(specs.inputs.size());
2517 int i = 0;
2518 for (const auto& it : specs.inputs) {
2519 Node* arg_node = arg_nodes->at(i).node;
2520 if (arg_node == nullptr) {
2521 return errors::InvalidArgument("Input ", it.first,
2522 " was not found in graph");
2523 }
2524 mlir::Type element_type;
2525 const auto& node_info = it.second;
2526 DataType imported_dtype = node_info.imported_dtype;
2527 // Uses the existing output type of the arg node if the data type of the
2528 // the node isn't specified through the import configuration.
2529 if (imported_dtype == DT_INVALID) {
2530 imported_dtype = arg_node->output_type(0);
2531 if (imported_dtype == DT_INVALID) {
2532 return errors::InvalidArgument("Input ", i, "has invalid data type");
2533 }
2534 }
2535 TF_RETURN_IF_ERROR(
2536 ::tensorflow::ConvertDataType(imported_dtype, builder, &element_type));
2537 if (node_info.shape.unknown_rank()) {
2538 arg_types.push_back(mlir::UnrankedTensorType::get(element_type));
2539 } else {
2540 llvm::SmallVector<int64_t, 4> shape;
2541 TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape));
2542 arg_types.push_back(mlir::RankedTensorType::get(shape, element_type));
2543 }
2544 i++;
2545 }
2546
2547 llvm::SmallVector<mlir::Type, 4> ret_types;
2548 ret_types.reserve(specs.outputs.size());
2549 for (int i = 0, e = specs.outputs.size(); i != e; ++i) {
2550 if (ret_nodes->at(i).node == nullptr) {
2551 return errors::InvalidArgument("Output ", specs.outputs[i],
2552 " was not found in graph");
2553 }
2554 }
2555 for (const auto& ret : *ret_nodes) {
2556 if (ret.node->num_outputs() <= ret.index) {
2557 return errors::InvalidArgument("Invalid output index ", ret.index,
2558 " specified for node: ", ret.node->name());
2559 }
2560 TF_ASSIGN_OR_RETURN(auto type,
2561 InferOutputType(*ret.node, ret.index, builder));
2562 ret_types.push_back(type);
2563 }
2564
2565 return builder.getFunctionType(arg_types, ret_types);
2566 }
2567
2568 StatusOr<mlir::FunctionType>
2569 GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph(
2570 mlir::MLIRContext* context, absl::InlinedVector<OutputTensor, 4>* arg_nodes,
2571 absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
2572 auto add_node = [](Node* node, absl::InlinedVector<OutputTensor, 4>* nodes) {
2573 auto* attr = node->attrs().Find("index");
2574 if (!attr)
2575 return errors::InvalidArgument(node->type_string(), " node '",
2576 node->name(),
2577 "' is missing attribute 'index'");
2578
2579 auto index = attr->i();
2580 const int num_nodes = nodes->size();
2581 if (num_nodes < index + 1) nodes->resize(index + 1);
2582
2583 if ((*nodes)[index].node != nullptr)
2584 return errors::InvalidArgument(node->type_string(), " node '",
2585 node->name(), "' has attribute 'index' ",
2586 index, " that conflicts with node '",
2587 (*nodes)[index].node->name(), "'");
2588 (*nodes)[index] = {node, 0};
2589
2590 return Status::OK();
2591 };
2592
2593 // Collect arg and ret nodes from graph.
2594 for (auto* node : GetOrderedNodes())
2595 if (node->IsArg())
2596 TF_RETURN_IF_ERROR(add_node(node, arg_nodes));
2597 else if (node->IsRetval())
2598 TF_RETURN_IF_ERROR(add_node(node, ret_nodes));
2599
2600 // Collect arg and ret types and create function type.
2601 mlir::Builder builder(context);
2602 llvm::SmallVector<mlir::Type, 4> arg_types;
2603 arg_types.reserve(arg_nodes->size());
2604 for (auto arg_node_and_idx : llvm::enumerate(*arg_nodes)) {
2605 auto& arg_node = arg_node_and_idx.value();
2606 if (arg_node.node == nullptr)
2607 return errors::InvalidArgument("Graph missing _Arg at index ",
2608 arg_node_and_idx.index());
2609
2610 TF_ASSIGN_OR_RETURN(auto type,
2611 InferOutputType(*arg_node.node, /*idx=*/0, builder));
2612 arg_types.push_back(type);
2613 }
2614
2615 llvm::SmallVector<mlir::Type, 4> ret_types;
2616 ret_types.reserve(ret_nodes->size());
2617 for (auto ret_node_and_idx : llvm::enumerate(*ret_nodes)) {
2618 auto& ret_node = ret_node_and_idx.value();
2619 if (ret_node.node == nullptr)
2620 return errors::InvalidArgument("Graph missing _Retval at index ",
2621 ret_node_and_idx.index());
2622
2623 TF_ASSIGN_OR_RETURN(auto type,
2624 InferInputType(*ret_node.node, /*idx=*/0, builder));
2625 ret_types.push_back(type);
2626 }
2627
2628 return builder.getFunctionType(arg_types, ret_types);
2629 }
2630
2631 Status GraphDefImporter::GetControlRetsFromGraph(
2632 llvm::ArrayRef<std::string> control_outputs,
2633 absl::InlinedVector<Node*, 4>* control_ret_nodes) {
2634 if (control_outputs.empty()) return Status::OK();
2635
2636 llvm::SmallDenseMap<llvm::StringRef, int32_t> controls_to_idx;
2637 for (auto control_and_idx : llvm::enumerate(control_outputs))
2638 controls_to_idx.insert({control_and_idx.value(), control_and_idx.index()});
2639
2640 if (controls_to_idx.size() != control_outputs.size())
2641 return errors::InvalidArgument("Control outputs must be unique");
2642
2643 control_ret_nodes->resize(controls_to_idx.size());
2644
2645 for (auto* node : GetOrderedNodes()) {
2646 auto it = controls_to_idx.find(node->name());
2647 if (it != controls_to_idx.end()) (*control_ret_nodes)[it->second] = node;
2648 }
2649
2650 for (auto node_and_name : llvm::zip(*control_ret_nodes, control_outputs))
2651 if (std::get<0>(node_and_name) == nullptr)
2652 return errors::InvalidArgument(
2653 "Control output '", std::get<1>(node_and_name), "' is missing");
2654
2655 return Status::OK();
2656 }
2657
2658 // Stateful helper class to import a TensorFlow model expressed in SavedModel
2659 // into an MLIR Module.
2660 class SavedModelObjectGraphImporter : public ImporterBase {
2661 public:
2662 // Main entry point: converts all functions in the given meta graph to an MLIR
2663 // Module.
2664 static StatusOr<mlir::OwningModuleRef> Convert(
2665 SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
2666 mlir::MLIRContext* context, bool add_default_attributes);
2667
2668 private:
2669   explicit SavedModelObjectGraphImporter(
2670 const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
2671 const GraphImportConfig& specs, mlir::ModuleOp module,
2672 std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
2673 NameUniquifier* function_name_uniquifier)
2674 : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name,
2675 function_name_uniquifier) {}
2676 };
2677
2678 // Determines the names used to reference objects in the SavedObjectGraph.
2679 class ObjectNames {
2680 public:
2681 explicit ObjectNames(const SavedObjectGraph& object_graph,
2682 absl::Span<std::string> exported_names);
2683
2684 // Gets the names that external users of the SavedModel can use to refer to
2685 // this node.
2686 llvm::ArrayRef<llvm::StringRef> GetExportedNames(int node_id) const;
2687
2688 // Gets the name in the module symbol table for this node.
2689 // This name is only used for internal IR references.
2690 llvm::StringRef GetSymbolTableName(int node_id) const;
2691
2692 private:
2693 // In the absence of any other information, use this name as the symbol table
2694 // name for this node.
2695 std::string GetDefaultSymbolTableName(int node_id) const;
2696 // Determines if a name is exported.
2697 bool IsExported(const std::string& name);
2698 // Main object graph traversal function.
2699 void RecursivelyVisitObjectGraph(int node_id);
2700 // Gets a stable StringRef from a std::string.
2701 llvm::StringRef SaveString(const std::string& s) const;
2702
2703 // The object graph we are traversing.
2704 const SavedObjectGraph& object_graph_;
2705 // The set of names to export. Empty means "export all".
2706 std::unordered_set<std::string> names_to_export_;
2707
2708 // When we recursively follow the object graph tree structure from the root,
2709 // we track its path in the object graph by pushing and popping from here
2710 // during traversal.
2711 llvm::SmallVector<std::string, 8> path_segments_;
2712 // The set of node_id's that are on the current DFS stack.
2713 // For cyclic object graphs, this prevents infinite recursion.
2714 std::unordered_set<int> on_stack_nodes_;
2715
2716 // Key: node_id.
2717 // Value: all object names that node_id appears as.
2718 // Each object name corresponds to a unique path from the root of the object
2719 // graph.
2720 // The common intuitive case is when there is only one name for a given
2721 // object, which corresponds to the object graph being a tree.
2722 //
2723   // But there are cases where the object graph is a general graph. For
2724 // example, this happens commonly in Keras models, where `foo.bar` is
2725 // also reachable via the name `keras_api.foo.bar`.
2726 // Cycles are possible too.
2727 absl::flat_hash_map<int, std::vector<std::string>> object_names_;
2728
2729 // Key: node_id
2730 // Value: all names that this object is exported as
2731 absl::flat_hash_map<int, llvm::SmallVector<llvm::StringRef, 1>>
2732 exported_names_;
2733 // Key: node_id
2734 // Value: pretty symbol table name to use for internal references to this
2735 // object.
2736 absl::flat_hash_map<int, llvm::StringRef> pretty_symbol_table_name_;
2737
2738 // Stable strings we can take StringRef's into. Used only by the SaveString
2739 // method.
2740 mutable std::unordered_set<std::string> saved_strings_;
2741 };
2742
2743 ObjectNames::ObjectNames(const SavedObjectGraph& object_graph,
2744 absl::Span<std::string> exported_names)
2745 : object_graph_(object_graph),
2746 names_to_export_(exported_names.begin(), exported_names.end()) {
2747 // Visit all reachable nodes from the root of the object graph.
2748 // This builds up object_names_ to contain all names like `foo.bar` that a
2749 // particular node in the graph can be reached from.
2750 RecursivelyVisitObjectGraph(/*node_id=*/0);
2751
2752 // Populate the exported_names_ map.
2753 // TODO(silvasean): Diagnose typos in exported names?
2754 for (auto& kv : object_names_) {
2755 // Make object names map independent of our particular choice of object
2756 // graph traversal.
2757 std::sort(kv.second.begin(), kv.second.end(),
2758 [](absl::string_view a, absl::string_view b) {
2759 // The sort order here influences the "pretty name" we assign
2760 // below. We want the most debuggable name to be first.
2761 //
2762 // Debuggability heuristics:
2763 // 1. Names that end in digits are likely to be internal aliases
2764 // to the "real" names.
2765 // 2. Longer names are more likely to be internal aliases.
2766 //
2767 // Example set of object names created by Keras for the weight
2768 // matrix of a fully connected layer on a trivial FC mnist
2769 // model:
2770 // - `model.layer-1.kernel` (this is the "best" name)
2771 // - `model.keras_api.layers.1.kernel`
2772 // - `model.variables.0`
2773 // - `model.keras_api.layers.1.keras_api.trainable_variables.0`
2774 // - ... 10 more long aliases ending in digits ...
2775 return std::make_tuple(isdigit(a.back()), a.size(), a) <
2776 std::make_tuple(isdigit(b.back()), b.size(), b);
2777 });
2778 for (const std::string& name : kv.second) {
2779 if (IsExported(name)) {
2780 exported_names_[kv.first].push_back(SaveString(name));
2781 }
2782 }
2783 }
2784 // Create "pretty" symbol table names for nodes where that is applicable.
2785 // We could make all symbol table names use the default, which is basically
2786 // just the node id. But for debugging purposes, it's nicer if we can mix in
2787 // a recognizable object name if we have the information to do so.
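  // For example (an illustrative sketch, not a name from any real model):
  // node 42 exported as "my_fn" would get the symbol table name
  // "__sm_node42__my_fn".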
2788 for (auto& kv : object_names_) {
2789 int node_id = kv.first;
2790 std::string internal_name =
2791 absl::StrCat(GetDefaultSymbolTableName(node_id), "__");
2792 // If the object has an exported name, we prefer that since it is probably
2793 // the most recognizable. Otherwise, we grab some non-exported name of the
2794 // object.
2795 if (exported_names_.find(node_id) != exported_names_.end()) {
2796 internal_name += exported_names_[node_id][0].str();
2797 } else {
2798 internal_name += object_names_[node_id][0];
2799 }
2800 pretty_symbol_table_name_[node_id] = SaveString(internal_name);
2801 }
2802 }
2803
2804 llvm::ArrayRef<llvm::StringRef> ObjectNames::GetExportedNames(
2805 int node_id) const {
2806 auto it = exported_names_.find(node_id);
2807 if (it != exported_names_.end()) {
2808 return it->second;
2809 }
2810 return {};
2811 }
2812
2813 llvm::StringRef ObjectNames::GetSymbolTableName(int node_id) const {
2814 auto it = pretty_symbol_table_name_.find(node_id);
2815 if (it != pretty_symbol_table_name_.end()) {
2816 return it->second;
2817 }
2818 return SaveString(GetDefaultSymbolTableName(node_id));
2819 }
2820
2821 std::string ObjectNames::GetDefaultSymbolTableName(int node_id) const {
2822 return absl::StrCat("__sm_node", node_id);
2823 }
2824
2825 bool ObjectNames::IsExported(const std::string& name) {
2826 if (names_to_export_.empty()) {
2827 return true;
2828 }
2829 return names_to_export_.find(name) != names_to_export_.end();
2830 }
2831
2832 void ObjectNames::RecursivelyVisitObjectGraph(int node_id) {
2833 const SavedObject& object = object_graph_.nodes(node_id);
2834
2835 switch (object.kind_case()) {
2836 case SavedObject::kConstant:
2837 case SavedObject::kFunction:
2838 case SavedObject::kVariable: {
2839 object_names_[node_id].push_back(absl::StrJoin(path_segments_, "."));
2840 break;
2841 }
2842 default:
2843 break;
2844 }
2845
2846 for (const auto& child_ref : object.children()) {
2847 bool on_stack = !on_stack_nodes_.insert(child_ref.node_id()).second;
2848 if (on_stack) {
2849 // This is a backedge. Don't traverse it.
2850 continue;
2851 }
2852
2853 path_segments_.push_back(child_ref.local_name());
2854 RecursivelyVisitObjectGraph(child_ref.node_id());
2855 path_segments_.pop_back();
2856
2857 on_stack_nodes_.erase(child_ref.node_id());
2858 }
2859 }
2860
2861 llvm::StringRef ObjectNames::SaveString(const std::string& s) const {
2862 return llvm::StringRef(*saved_strings_.insert(s).first);
2863 }
2864
2865 // Extracts a TensorProto for a Const op from a GraphDef, given an op_name.
2866 // Returns nullptr if the node is not found or on any other mismatch.
2867 // This returns a pointer to the actual node within the graph_def so as to
2868 // avoid expensive copies.
2869 const TensorProto* ExtractConstTensorFromGraph(const GraphDef& graph_def,
2870 const std::string& op_name) {
2871 const NodeDef* match_node = nullptr;
2872 for (const auto& node : graph_def.node()) {
2873 if (node.name() == op_name) {
2874 match_node = &node;
2875 }
2876 }
2877
2878 if (!match_node) {
2879 return nullptr;
2880 }
2881
2882 auto value_it = match_node->attr().find("value");
2883 if (value_it == match_node->attr().end()) {
2884 return nullptr;
2885 }
2886
2887 if (!value_it->second.has_tensor()) {
2888 return nullptr;
2889 }
2890
2891 return &value_it->second.tensor();
2892 }
2893
2894 const TrackableObjectGraph::TrackableObject::SerializedTensor*
2895 FindSerializedTensorInTrackable(
2896 const TrackableObjectGraph::TrackableObject& trackable_object,
2897 StringPiece name) {
2898 for (const auto& maybe_serialized_tensor : trackable_object.attributes()) {
2899 if (maybe_serialized_tensor.name() == name) {
2900 return &maybe_serialized_tensor;
2901 }
2902 }
2903 return nullptr;
2904 }
2905
2906 Status DiagnoseMultipleConcreteFunctions(const SavedObjectGraph& object_graph,
2907 const ObjectNames& object_names) {
2908 for (int node_id = 0; node_id < object_graph.nodes_size(); node_id++) {
2909 const SavedObject& object = object_graph.nodes(node_id);
2910 if (object_names.GetExportedNames(node_id).empty()) {
2911 continue;
2912 }
2913 if (object.kind_case() == SavedObject::kFunction) {
2914 // We only allow a single input signature to each SavedFunction.
2915 // This assumption means we have a 1:1 correspondence between
2916 // tf.function <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef
2917 // This makes defining the ABI easier (or even well-defined at all).
2918 // TODO(silvasean): How to detect a function that doesn't have an
2919 // explicitly user-provided input signature, but happens to have been
2920 // traced exactly once?
2921 if (object.function().concrete_functions_size() != 1) {
2922 llvm::SmallVector<std::string, 4> names;
2923 for (llvm::StringRef s : object_names.GetExportedNames(node_id)) {
2924 names.push_back("'" + s.str() + "'");
2925 }
2926 return errors::InvalidArgument(
2927 "Exported function with exported name(s) ",
2928 absl::StrJoin(names, ", "),
2929 " with multiple concrete functions. Add "
2930 "@tf.function(input_signature=[...]) on this function, or use a "
2931 "narrower list of exported names that excludes this function.");
2932 }
2933 }
2934 }
2935 return Status::OK();
2936 }
2937
2938 // Recursively traverses a StructuredValue, linearizing all the leaves.
2939 //
2940 // This currently only handles the subset of StructuredValue that is needed for
2941 // signatures.
2942 //
2943 // Given a StructuredValue with structure [{"x": leaf0}], the "index path"
2944 // needed to reach leaf0 is `[0, "x"]`, as it would be if you were operating on
2945 // a Python object (`obj[0]["x"] is leaf0`). Each leaf corresponds to a
2946 // linearized function argument or return on a FunctionDef, and hence to an
2947 // mlir::FuncOp argument / return.
2948 //
2949 // This must match the linearization that happens in `tf.nest.flatten`.
2950 // In particular, dict values should be linearized in sorted key order.
2951 //
2952 // The linearized index paths can be mapped back to a structured
2953 // representation (e.g. to emit C structs matching a signature) with a simple
2954 // algorithm that recurses on each run of index paths with identical first
2955 // elements.
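// A minimal sketch of the linearization, using hypothetical leaves:
//
//   StructuredValue:  {"b": leaf0, "a": (leaf1, leaf2)}
//   Leaf index paths: ["a", 0], ["a", 1], ["b"]
//
// Dict keys are visited in sorted order ("a" before "b"), and tuple/list
// elements are visited positionally, matching `tf.nest.flatten`.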
2956 class StructuredValueLinearizer {
2957 public:
2958 StructuredValueLinearizer(const StructuredValue& value,
2959 mlir::MLIRContext* context);
2960
2961 // Returns the list of index paths to each leaf of the StructuredValue,
2962 // in a linearized order matching `tf.nest.flatten`.
2963 //
2964 // If an error occurred during the linearization process, an error message
2965 // with `error_context` prepended will be included in the returned status.
2966 StatusOr<llvm::ArrayRef<mlir::ArrayAttr>> GetLeafIndexPaths(
2967 llvm::StringRef error_context) const;
2968
2969 private:
2970 // Main function that recursively traverses the StructuredValue.
2971 void RecursivelyFindLeaves(const StructuredValue& value);
2972
2973 mlir::Builder builder_;
2974 // The current index path. We push/pop this during recursive traversal of the
2975 // StructuredValue.
2976 llvm::SmallVector<mlir::Attribute, 4> current_index_path_;
2977 // The list of leaf index paths we have discovered so far.
2978 llvm::SmallVector<mlir::ArrayAttr, 4> leaf_index_paths_;
2979 // If non-empty, an error message to report.
2980 std::string error_message_;
2981 };
2982
2983 StructuredValueLinearizer::StructuredValueLinearizer(
2984 const StructuredValue& value, mlir::MLIRContext* context)
2985 : builder_(context) {
2986 RecursivelyFindLeaves(value);
2987 }
2988
2989 StatusOr<llvm::ArrayRef<mlir::ArrayAttr>>
2990 StructuredValueLinearizer::GetLeafIndexPaths(
2991 llvm::StringRef error_context) const {
2992 if (error_message_.empty()) {
2993 return llvm::makeArrayRef(leaf_index_paths_);
2994 }
2995 return errors::InvalidArgument(
2996 error_context.str(), error_message_,
2997 "This likely means that you have @tf.function "
2998 "on an exported function instead of "
2999 "@tf.function(input_signature=[...]). Consider annotating an "
3000 "input_signature or narrowing your set of "
3001 "exported names to not include this function.");
3002 }
3003
3004 void StructuredValueLinearizer::RecursivelyFindLeaves(
3005 const StructuredValue& value) {
3006 switch (value.kind_case()) {
3007 case StructuredValue::kDictValue: {
3008 // Dict values must be linearized in sorted order of keys.
3009 const DictValue& dict = value.dict_value();
3010 using FieldTy = protobuf::MapPair<std::string, StructuredValue>;
3011 llvm::SmallVector<const FieldTy*, 4> fields;
3012 for (auto& field : dict.fields()) {
3013 fields.push_back(&field);
3014 }
3015 llvm::sort(fields, [](const FieldTy* a, const FieldTy* b) {
3016 return a->first < b->first;
3017 });
3018 for (auto& field : fields) {
3019 current_index_path_.push_back(builder_.getStringAttr(field->first));
3020 RecursivelyFindLeaves(field->second);
3021 current_index_path_.pop_back();
3022 }
3023 return;
3024 }
3025 case StructuredValue::kTupleValue: {
3026 const TupleValue& tuple = value.tuple_value();
3027 for (int i = 0, e = tuple.values_size(); i < e; i++) {
3028 current_index_path_.push_back(builder_.getI64IntegerAttr(i));
3029 RecursivelyFindLeaves(tuple.values(i));
3030 current_index_path_.pop_back();
3031 }
3032 return;
3033 }
3034 // We don't differentiate between tuples and lists.
3035 case StructuredValue::kListValue: {
3036 const ListValue& list = value.list_value();
3037 for (int i = 0, e = list.values_size(); i < e; i++) {
3038 current_index_path_.push_back(builder_.getI64IntegerAttr(i));
3039 RecursivelyFindLeaves(list.values(i));
3040 current_index_path_.pop_back();
3041 }
3042 return;
3043 }
3044 case StructuredValue::kTensorSpecValue: {
3045 // Base case: record the current path stack as the index path needed to
3046 // get to this leaf.
3047 leaf_index_paths_.push_back(builder_.getArrayAttr(current_index_path_));
3048 return;
3049 }
3050 case StructuredValue::kNoneValue: {
3051 // Base case: do nothing.
3052 // This arises, for example, as the top-level object of an output
3053 // signature when there are no return values.
3054 return;
3055 }
3056 default: {
3057 llvm::raw_string_ostream os(error_message_);
3058 // TODO(silvasean): Use an enumerant name string instead of a number.
3059 os << "Unhandled structured value kind " << value.kind_case()
3060 << " at index path: <value>";
3061 for (auto path_element : current_index_path_) {
3062 os << ".";
3063 if (auto integer = path_element.dyn_cast<mlir::IntegerAttr>()) {
3064 os << integer.getValue();
3065 } else {
3066 auto str = path_element.cast<mlir::StringAttr>();
3067 os << str.getValue();
3068 }
3069 }
3070 os << "\n";
3071 }
3072 }
3073 }
3074
3075 // For exported functions with bound inputs, rewrite the function
3076 // signature to match the requirements of tf_saved_model bound input args.
3077 //
3078 // The raw imported functions have `tensor<*x!tf_type.resource>` as the type for
3079 // mutable bound inputs and `tensor<...>` as the type for immutable
3080 // bound inputs. Here we canonicalize both of them into
3081 // `tensor<!tf_type.resource<tensor<...>>>`.
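// An illustrative sketch (the element types and shapes are hypothetical):
//
//   mutable bound input:   tensor<*x!tf_type.resource>
//   immutable bound input: tensor<2xf32>
//   both become:           tensor<!tf_type.resource<tensor<2xf32>>>
//
// A CastOp (mutable case) or ReadVariableOp (immutable case) is inserted so
// existing uses of the argument still see its original type.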
3082 void AdjustBoundInputArgTypes(mlir::ModuleOp module) {
3083 mlir::SymbolTable symbol_table(module);
3084 for (auto func : module.getOps<mlir::FuncOp>()) {
3085 if (!mlir::tf_saved_model::IsExported(func)) continue;
3086 mlir::OpBuilder builder(func.getBody());
3087 llvm::SmallVector<mlir::Type, 4> new_input_types;
3088 for (int i = 0, e = func.getNumArguments(); i < e; i++) {
3089 auto arg = func.getArgument(i);
3090 auto global_tensor = mlir::tf_saved_model::LookupBoundInputOfType<
3091 mlir::tf_saved_model::GlobalTensorOp>(func, i, symbol_table);
3092 if (global_tensor) {
3093 auto old_type = arg.getType();
3094 auto new_type =
3095 mlir::tf_saved_model::GetBoundInputArgTypeFor(global_tensor);
3096 arg.setType(new_type);
3097 if (global_tensor.is_mutable()) {
3098 auto arg_with_original_type = builder.create<mlir::TF::CastOp>(
3099 global_tensor.getLoc(), old_type, arg,
3100 /*Truncate=*/builder.getBoolAttr(false));
3101 arg.replaceAllUsesWith(arg_with_original_type);
3102 // The RAUW replaces the arg with itself, so we need to set it back.
3103 arg_with_original_type.setOperand(arg);
3104 } else {
3105 auto arg_with_original_type =
3106 builder.create<mlir::TF::ReadVariableOp>(global_tensor.getLoc(),
3107 old_type, arg);
3108 arg.replaceAllUsesWith(arg_with_original_type);
3109 // The RAUW replaces the arg with itself, so we need to set it back.
3110 arg_with_original_type.setOperand(arg);
3111 }
3112 }
3113 new_input_types.push_back(arg.getType());
3114 }
3115 func.setType(mlir::FunctionType::get(module.getContext(), new_input_types,
3116 func.getType().getResults()));
3117 }
3118 }
3119
3120 // Marks the visibility of functions in the saved model module.
3121 void MarkSavedModelFunctionVisibility(mlir::ModuleOp module) {
3122 for (auto func : module.getOps<mlir::FuncOp>()) {
3123 auto visibility = mlir::tf_saved_model::IsExported(func)
3124 ? mlir::FuncOp::Visibility::Public
3125 : mlir::FuncOp::Visibility::Private;
3126 func.setVisibility(visibility);
3127 }
3128 }
3129
3130 // Reorders the ops in the module to make testing easier and less dependent
3131 // on implementation details such as the order of functions in the
3132 // FunctionDefLibrary.
3133 //
3134 // The order this ensures is:
3135 // 1. GlobalTensorOps.
3136 // 2. FuncOps.
3137 //
3138 // Within each of 1. and 2., ops are sorted by exported name (if
3139 // available, and only the first exported name is considered), followed by
3140 // non-exported ops.
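// A sketch with hypothetical names: a module with global tensors exported as
// "v" and "w", funcs exported as "f" and "g", and one private func would be
// ordered roughly as:
//   session initializer and asset ops (if present), global tensor "v",
//   global tensor "w", non-exported global tensors, func "f", func "g",
//   private funcs (sorted by symbol name).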
3141 void SortSavedModelModule(mlir::ModuleOp module) {
3142 struct NamedGlobalTensor {
3143 llvm::StringRef name;
3144 GlobalTensorOp global_tensor;
3145 };
3146 llvm::SmallVector<NamedGlobalTensor, 8> named_global_tensors;
3147 for (auto global_tensor : module.getOps<GlobalTensorOp>()) {
3148 auto exported_names = mlir::tf_saved_model::GetExportedNames(global_tensor);
3149 // We use stable_sort, so duplicate empty names are fine here.
3150 named_global_tensors.push_back(
3151 {exported_names.empty() ? "" : exported_names.front(), global_tensor});
3152 }
3153 llvm::stable_sort(named_global_tensors,
3154 [](const NamedGlobalTensor& a, const NamedGlobalTensor& b) {
3155 return std::make_tuple(a.name.empty(), a.name) <
3156 std::make_tuple(b.name.empty(), b.name);
3157 });
3158
3159 struct NamedFunc {
3160 llvm::StringRef name;
3161 mlir::FuncOp func;
3162 };
3163 llvm::SmallVector<NamedFunc, 8> named_funcs;
3164 llvm::SmallVector<mlir::FuncOp, 8> private_funcs;
3165 for (auto func : module.getOps<mlir::FuncOp>()) {
3166 auto exported_names = mlir::tf_saved_model::GetExportedNames(func);
3167 if (!exported_names.empty())
3168 named_funcs.push_back({exported_names.front(), func});
3169 else
3170 private_funcs.push_back(func);
3171 }
3172 llvm::stable_sort(named_funcs, [](const NamedFunc& a, const NamedFunc& b) {
3173 return a.name < b.name;
3174 });
3175 llvm::stable_sort(private_funcs, [](mlir::FuncOp a, mlir::FuncOp b) {
3176 return a.getName() < b.getName();
3177 });
3178
3179 struct NamedAsset {
3180 llvm::StringRef name;
3181 AssetOp asset;
3182 };
3183 llvm::SmallVector<NamedAsset, 4> assets;
3184 for (auto asset : module.getOps<AssetOp>()) {
3185 assets.push_back({asset.getName(), asset});
3186 }
3187 llvm::stable_sort(assets, [](const NamedAsset& a, const NamedAsset& b) {
3188 return a.name < b.name;
3189 });
3190
3191   // Move ops onto the front of the module in reverse of the final desired order.
3192 for (auto func : llvm::reverse(private_funcs)) {
3193 func.getOperation()->moveBefore(&module.getBody()->front());
3194 }
3195 for (auto named_func : llvm::reverse(named_funcs)) {
3196 named_func.func.getOperation()->moveBefore(&module.getBody()->front());
3197 }
3198 for (auto named_global_tensor : llvm::reverse(named_global_tensors)) {
3199 named_global_tensor.global_tensor.getOperation()->moveBefore(
3200 &module.getBody()->front());
3201 }
3202
3203 for (auto asset : assets) {
3204 asset.asset.getOperation()->moveBefore(&module.getBody()->front());
3205 }
3206
3207 auto initializers = module.getOps<SessionInitializerOp>();
3208 if (!initializers.empty()) {
3209 (*initializers.begin())
3210 .getOperation()
3211 ->moveBefore(&module.getBody()->front());
3212 }
3213 }
3214
3215 Status CreateSavedModelIR(
3216 const ObjectNames& object_names, mlir::ModuleOp module,
3217 const SavedObjectGraph& object_graph,
3218 const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name,
3219 SavedModelV2Bundle* saved_model) {
3220 mlir::OpBuilder builder(module.getBodyRegion());
3221 mlir::SymbolTable symbol_table(module);
3222
3223 // Create a side data-structure, indexed by the object_graph node_id to
3224 // a TrackableObject that is restorable.
3225 absl::flat_hash_map<int, const TrackableObjectGraph::TrackableObject*>
3226 restored_objects;
3227 TF_RETURN_IF_ERROR(saved_model->VisitObjectsToRestore(
3228 [&](int saved_node_id,
3229 const TrackableObjectGraph::TrackableObject& trackable_object) {
3230 restored_objects.insert(
3231 std::make_pair(saved_node_id, &trackable_object));
3232 return Status::OK();
3233 }));
3234
3235 for (int node_id = 0; node_id < object_graph.nodes_size(); node_id++) {
3236 const SavedObject& object = object_graph.nodes(node_id);
3237 // For correctness, we cannot import functions that don't have exported
3238 // names, since they don't necessarily have a well-defined ABI (diagnosed
3239 // earlier).
3240 //
3241 // For variables/constants, pruning them is purely an optimization,
3242 // and more complicated since it requires use-def analysis of which
3243 // functions use which variables/constants, so we don't do anything
3244 // special for them here as part of our initial IR construction.
3245 if (object.kind_case() == SavedObject::kFunction) {
3246 if (object_names.GetExportedNames(node_id).empty()) {
3247 continue;
3248 }
3249 std::string error_context =
3250 "While importing SavedModel function '" +
3251 object_names.GetExportedNames(node_id)[0].str() + "': ";
3252 const SavedFunction& function = object.function();
3253 auto orig_func = symbol_table.lookup<mlir::FuncOp>(
3254 tf_name_to_mlir_name.find(function.concrete_functions(0))->second);
3255 mlir::FuncOp func = orig_func;
3256       // If there are potential references to this func from within the
3257 // module, create a wrapper around it and decorate the wrapper with the
3258 // tf_saved_model attributes instead.
3259 if (!mlir::SymbolTable::symbolKnownUseEmpty(orig_func.getName(),
3260 &module.getBodyRegion())) {
3261 func = orig_func.cloneWithoutRegions();
3262 module.insert(module.getBody()->begin(), func);
3263 func.addEntryBlock();
3264 func.setName("__sm_exported_" + orig_func.getName().str());
3265 llvm::SmallVector<mlir::Value, 4> args_as_values;
3266 for (auto block_argument : func.getArguments()) {
3267 args_as_values.push_back(block_argument);
3268 }
3269 mlir::OpBuilder body_builder(&func.getBody());
3270 auto call = body_builder.create<mlir::TF::StatefulPartitionedCallOp>(
3271 func.getLoc(), orig_func.getType().getResults(), args_as_values,
3272 builder.getSymbolRefAttr(orig_func.getName()),
3273 /*config=*/builder.getStringAttr(""),
3274 /*config_proto=*/builder.getStringAttr(""),
3275 /*executor_type=*/builder.getStringAttr(""));
3276 body_builder.create<mlir::ReturnOp>(func.getLoc(), call.getResults());
3277 }
3278 func->setAttr(
3279 "tf_saved_model.exported_names",
3280 builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
3281 const SavedConcreteFunction& concrete_function =
3282 object_graph.concrete_functions().at(function.concrete_functions(0));
3283
3284 // We do not handle the other element of this tuple, which corresponds to
3285 // Python kwonlyargs, since currently TensorFlow prohibits this in
3286 // combination with input_signature:
3287 // https://github.com/tensorflow/tensorflow/blob/8cb8627abb5ef83a6fba34f8fd0e4ee430562eb1/tensorflow/python/eager/function.py#L2027-L2030
3288 // Our SavedModel import requires input_signature on the tf.function, so
3289 // we never need to handle the kwonlyargs.
3290 auto positional_arg_structure =
3291 concrete_function.canonicalized_input_signature()
3292 .tuple_value()
3293 .values(0);
3294 StructuredValueLinearizer input_linearizer(positional_arg_structure,
3295 builder.getContext());
3296
3297 int bound_input_base =
3298 func.getNumArguments() - concrete_function.bound_inputs_size();
3299 TF_ASSIGN_OR_RETURN(auto input_index_paths,
3300 input_linearizer.GetLeafIndexPaths(
3301 error_context + "in input signature: "));
3302 const int input_index_paths_size = input_index_paths.size();
3303 if (bound_input_base != input_index_paths_size) {
3304 return errors::InvalidArgument(
3305 error_context,
3306 "Argument mismatch between concrete function input signature "
3307 "vs underlying FunctionDef for concrete function '",
3308 function.concrete_functions(0), "' (", input_index_paths.size(),
3309 " vs ", bound_input_base, ")");
3310 }
3311 for (auto index_path : llvm::enumerate(input_index_paths)) {
3312 func.setArgAttr(index_path.index(), "tf_saved_model.index_path",
3313 index_path.value());
3314 }
3315
3316 for (auto& bound_input :
3317 llvm::enumerate(concrete_function.bound_inputs())) {
3318 int arg_index = bound_input_base + bound_input.index();
3319 auto symbol_ref = builder.getSymbolRefAttr(
3320 object_names.GetSymbolTableName(bound_input.value()));
3321 func.setArgAttr(arg_index, "tf_saved_model.bound_input", symbol_ref);
3322 }
3323
3324 StructuredValueLinearizer output_linearizer(
3325 concrete_function.output_signature(), builder.getContext());
3326 TF_ASSIGN_OR_RETURN(auto output_index_paths,
3327 output_linearizer.GetLeafIndexPaths(
3328 error_context + "in output signature: "));
3329 if (func.getNumResults() != output_index_paths.size()) {
3330 return errors::InvalidArgument(
3331 error_context,
3332 "Result mismatch between concrete function output signature "
3333 "vs underlying FunctionDef for concrete function '",
3334 function.concrete_functions(0), "' (", output_index_paths.size(),
3335 " vs ", func.getNumResults(), ")");
3336 }
3337 for (auto index_path : llvm::enumerate(output_index_paths)) {
3338 func.setResultAttr(index_path.index(), "tf_saved_model.index_path",
3339 index_path.value());
3340 }
3341 } else if (object.kind_case() == SavedObject::kVariable) {
3342 const SavedVariable& variable = object.variable();
3343 // Find the trackable in the side data structure.
3344 auto variable_trackable_it = restored_objects.find(node_id);
3345 if (variable_trackable_it == restored_objects.end()) {
3346 return errors::FailedPrecondition("Could not restore saved variable: ",
3347 variable.name());
3348 }
3349 const auto* serialized_tensor_attr = FindSerializedTensorInTrackable(
3350 *variable_trackable_it->second, "VARIABLE_VALUE");
3351 if (!serialized_tensor_attr) {
3352 return errors::FailedPrecondition(
3353 "Could not find serialized tensor for saved variable: ",
3354 variable.name());
3355 }
3356 const auto& checkpoint_key = serialized_tensor_attr->checkpoint_key();
3357
3358 // Load it from the reader.
3359 Tensor value;
3360 TF_RETURN_WITH_CONTEXT_IF_ERROR(
3361 saved_model->variable_reader()->Lookup(checkpoint_key, &value),
3362 "Could not read checkpoint key from variables bundle: ",
3363 checkpoint_key);
3364 TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder));
3365 // A variable can have a partially known type, such as tensor<?x27x?xf32>,
3366       // even if the initializer has a specific static shape.
3367 TF_ASSIGN_OR_RETURN(
3368 auto type, ConvertToMlirTensorType(variable.shape(), variable.dtype(),
3369 &builder));
3370 auto op = builder.create<GlobalTensorOp>(
3371 builder.getUnknownLoc(),
3372 builder.getStringAttr(object_names.GetSymbolTableName(node_id)),
3373 value_attr,
3374 /*type=*/mlir::TypeAttr::get(type),
3375 /*is_mutable=*/builder.getUnitAttr());
3376 op->setAttr(
3377 "tf_saved_model.exported_names",
3378 builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
3379 } else if (object.kind_case() == SavedObject::kConstant) {
3380 const SavedConstant& constant = object.constant();
3381 const TensorProto* value = ExtractConstTensorFromGraph(
3382 saved_model->meta_graph_def().graph_def(), constant.operation());
3383 if (!value) {
3384 return errors::FailedPrecondition(
3385 "Unable to find const node referenced in object graph: ",
3386 constant.operation());
3387 }
3388 TF_ASSIGN_OR_RETURN(auto value_attr,
3389 ConvertTensorProto(*value, &builder));
3390 auto op = builder.create<GlobalTensorOp>(
3391 builder.getUnknownLoc(),
3392 builder.getStringAttr(object_names.GetSymbolTableName(node_id)),
3393 value_attr,
3394 /*type=*/mlir::TypeAttr::get(value_attr.Attribute::getType()),
3395 /*is_mutable=*/nullptr);
3396 op->setAttr(
3397 "tf_saved_model.exported_names",
3398 builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
3399 }
3400 }
3401 AdjustBoundInputArgTypes(module);
3402 module->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
3403 SortSavedModelModule(module);
3404 MarkSavedModelFunctionVisibility(module);
3405 return Status::OK();
3406 }
3407
3408 StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
3409 SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
3410 mlir::MLIRContext* context, bool add_default_attributes) {
3411 LoadImporterDialects(*context);
3412 GraphDebugInfo dummy_debug_info;
3413 const GraphDebugInfo& debug_info =
3414 saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info;
3415
3416 GraphImportConfig specs;
3417 specs.prune_unused_nodes = true;
3418 mlir::OwningModuleRef module =
3419 mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
3420 std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
3421
3422 const auto& graphdef = saved_model->meta_graph_def().graph_def();
3423 PopulateTfVersions(module.get(), graphdef.versions());
3424
3425 GraphConstructorOptions options;
3426 options.allow_internal_ops = true;
3427 options.add_default_attributes = add_default_attributes;
3428 Graph graph(OpRegistry::Global());
3429
3430 GraphDef preprocessed_graphdef(graphdef);
3431 if (add_default_attributes) {
3432 TF_RETURN_IF_ERROR(PreprocessGraphDef(nullptr, &preprocessed_graphdef));
3433 }
3434
3435 TF_RETURN_IF_ERROR(
3436 ConvertGraphDefToGraph(options, preprocessed_graphdef, &graph));
3437
3438 NameUniquifier function_name_uniquifier(graph.flib_def());
3439 SavedModelObjectGraphImporter importer(graph.flib_def(), debug_info, specs,
3440 module.get(), &tf_name_to_mlir_name,
3441 &function_name_uniquifier);
3442
3443 TF_RETURN_IF_ERROR(importer.PrepareConvert(graph));
3444
3445 auto fn_names = graph.flib_def().ListFunctionNames();
3446 for (const auto& fn_name : fn_names) {
3447 TF_RETURN_IF_ERROR(importer.ConvertLibFunction(fn_name));
3448 }
3449 TF_RETURN_IF_ERROR(importer.ConvertDeferredFunctions());
3450
3451 if (!saved_model->meta_graph_def().has_object_graph_def()) {
3452 return errors::InvalidArgument(
3453 "SavedModel does not have an object graph. Please use TF2.");
3454 }
3455 auto& object_graph = saved_model->meta_graph_def().object_graph_def();
3456 ObjectNames object_names(object_graph, exported_names);
3457
3458   // Clean up a couple of funcs that always seem to be present when importing a
3459 // SavedModel. This is not strictly needed, as there is a separate pass that
3460 // will clean them up, but this makes staring at the raw IR of minimal
3461 // examples quite a bit nicer.
3462 for (auto func : llvm::make_early_inc_range(module->getOps<mlir::FuncOp>())) {
3463 if (func.getName().startswith("__inference__traced_save_") ||
3464 func.getName().startswith("__inference__traced_restore_") ||
3465 func.getName().startswith("__inference_signature_wrapper_")) {
3466 func.erase();
3467 }
3468 }
3469
3470 // Diagnose SavedFunction's with multiple input signatures.
3471 TF_RETURN_IF_ERROR(
3472 DiagnoseMultipleConcreteFunctions(object_graph, object_names));
3473
3474 // Construct the SavedModel IR.
3475 TF_RETURN_IF_ERROR(CreateSavedModelIR(object_names, module.get(),
3476 object_graph, tf_name_to_mlir_name,
3477 saved_model));
3478 assert(mlir::succeeded(mlir::verify(module.get())));
3479
3480 return module;
3481 }
3482
3483 class SimpleSavedModelMLIRImportInput : public SavedModelMLIRImportInput {
3484 public:
3485   static StatusOr<SimpleSavedModelMLIRImportInput> Create(
3486 const MLIRImportOptions& import_options,
3487 const MetaGraphDef* meta_graph_def, const GraphDebugInfo& debug_info) {
3488 DCHECK(meta_graph_def);
3489 GraphDef graph_def;
3490 if (import_options.enable_grappler) {
3491 // Grappler is best-effort.
3492 auto statusor = RunGrappler(*meta_graph_def);
3493 if (statusor.ok()) {
3494 graph_def = std::move(statusor).ValueOrDie();
3495 } else {
3496 // If the grappler fails, use the original graph def.
3497 LOG(WARNING) << "SimpleSavedModelMLIRImportInput: grappler failed: "
3498 << statusor.status();
3499 graph_def = meta_graph_def->graph_def();
3500 }
3501 } else {
3502 graph_def = meta_graph_def->graph_def();
3503 }
3504
3505 auto graph = std::make_unique<Graph>(OpRegistry::Global());
3506
3507 if (import_options.upgrade_legacy) {
3508 TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(
3509 graph_def, graph->flib_def().default_registry()));
3510 }
3511
3512 GraphConstructorOptions graph_ctor_options;
3513 graph_ctor_options.allow_internal_ops = true;
3514 graph_ctor_options.add_default_attributes = true;
3515 TF_RETURN_IF_ERROR(
3516 ConvertGraphDefToGraph(graph_ctor_options, graph_def, graph.get()));
3517
3518 if (import_options.upgrade_legacy) {
3519 // TODO(jpienaar): Remove need to const_cast.
3520 TF_RETURN_IF_ERROR(UpgradeLegacyGraph(
3521 graph.get(),
3522 const_cast<FunctionLibraryDefinition*>(&graph->flib_def()),
3523 /*restrict_functionalization_to_tpu_nodes=*/false));
3524 }
3525
3526 return SimpleSavedModelMLIRImportInput(meta_graph_def, debug_info,
3527 std::move(graph));
3528 }
3529
3530   SimpleSavedModelMLIRImportInput(const MetaGraphDef* meta_graph_def,
3531 const GraphDebugInfo& debug_info,
3532 std::unique_ptr<Graph> graph)
3533 : SavedModelMLIRImportInput(meta_graph_def, debug_info),
3534 graph_(std::move(graph)) {}
3535
3536   StatusOr<const Graph*> GetSubGraph(absl::string_view name,
3537 const GraphImportConfig& specs) override {
3538 DCHECK(CheckGraphNameValidity(name));
3539 DCHECK(CheckGraphContainsFeedsAndFetches(specs));
3540 return graph_.get();
3541 }
3542
3543 private:
3544   bool CheckGraphContainsFeedsAndFetches(const GraphImportConfig& specs) const {
3545 absl::flat_hash_set<std::string> feed_fetch_nodes;
3546 for (const auto& iter : specs.inputs) {
3547 TensorId tensor_id = ParseTensorName(iter.first);
3548 feed_fetch_nodes.insert(std::string(tensor_id.node()));
3549 }
3550 for (const auto& output : llvm::concat<const std::string>(
3551 specs.outputs, specs.control_outputs)) {
3552 TensorId tensor_id = ParseTensorName(output);
3553 feed_fetch_nodes.insert(std::string(tensor_id.node()));
3554 }
3555
3556 for (Node* node : graph_->op_nodes()) {
3557 feed_fetch_nodes.erase(node->name());
3558 }
3559
3560 return feed_fetch_nodes.empty();
3561 }
3562
3563   bool CheckGraphNameValidity(absl::string_view name) const {
3564     // If it is one of the signature names, it is valid.
3565 const auto& signature_defs = meta_graph_def().signature_def();
3566 if (signature_defs.contains(std::string(name))) return true;
3567
3568 // If it is the restore graph name, it is valid.
3569 if (meta_graph_def().has_saver_def() &&
3570 meta_graph_def().saver_def().restore_op_name() == name)
3571 return true;
3572
3573 // If it is the init graph name, it is valid.
3574 std::string init_op_name;
3575 if (internal::GetInitOp("", meta_graph_def(), &init_op_name).ok()) {
3576 if (init_op_name == name) return true;
3577 }
3578
3579 return false;
3580 }
3581
3582 // `graph_` contains the entire graph in the original MetaGraphDef.
3583 std::unique_ptr<Graph> graph_;
3584 };
3585
3586 static absl::flat_hash_set<std::string> GetOriginalTfFuncNamesFromGraphDef(
3587 const GraphDef& graph_def) {
3588 absl::flat_hash_set<std::string> original_func_tf_names;
3589 for (const auto& function : graph_def.library().function()) {
3590 original_func_tf_names.insert(function.signature().name());
3591 }
3592 return original_func_tf_names;
3593 }
3594
3595 // A helper class to import a TensorFlow model expressed in SavedModel V1 into
3596 // an MLIR Module in SavedModel dialect.
3597 //
3598 // TODO(b/179683149): Rename this class to avoid confusion with TFLite.
3599 class SavedModelSignatureDefImporterLite {
3600 public:
3601 // Main entry point: converts all functions (specified by SignatureDefs) in
3602 // the given meta graph to an MLIR Module.
3603 //
3604   // `import_restore` controls whether the restore graph is imported, e.g. in
3605   // SavedModelSignatureDefImporter. Ideally, we wouldn't need this option, as
3606   // the restore graph should always be imported. However, right now,
3607   // SavedModelSignatureDefImporter cannot handle the restore graph
3608   // correctly.
3609 //
3610 // TODO(chky): Remove import_restore once the restore graph is correctly
3611 // handled in SavedModelSignatureDefImporter.
3612   static StatusOr<mlir::OwningModuleRef> Convert(
3613 SavedModelMLIRImportInput& input,
3614 absl::optional<absl::Span<const std::string>> exported_names,
3615 mlir::MLIRContext* context, bool import_restore = true) {
3616 SavedModelSignatureDefImporterLite importer(input, exported_names, context,
3617 import_restore);
3618 return importer.ConvertSignatures();
3619 }
3620
3621 private:
3622   SavedModelSignatureDefImporterLite(
3623 SavedModelMLIRImportInput& input,
3624 absl::optional<absl::Span<const std::string>> exported_names,
3625 mlir::MLIRContext* context, bool import_restore)
3626 : input_(input),
3627 original_func_tf_names_(GetOriginalTfFuncNamesFromGraphDef(
3628 input.meta_graph_def().graph_def())),
3629 exported_names_(exported_names),
3630 module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))),
3631 symbol_table_(module_.get()),
3632 import_restore_(import_restore) {}
3633
3634 // Converts the SavedModel to the SavedModel dialect. Creates an MLIR function
3635 // for each signature.
3636 StatusOr<mlir::OwningModuleRef> ConvertSignatures();
3637 Status ConvertSignature(const std::string& sig_def_key,
3638 const SignatureDef& signature_def);
3639
3640 struct AssetInfo {
3641 std::string tensor_name;
3642 mlir::tf_saved_model::AssetOp op;
3643 };
3644 StatusOr<std::vector<AssetInfo>> ConvertAssets();
3645 // Converts the initialization graph in the SavedModel to an MLIR function.
3646 Status ConvertInitializer(const std::string& target_node_name,
3647 const std::vector<AssetInfo>& assets);
3648
3649 // Converts a graph with feeds and fetches to an MLIR function.
3650 StatusOr<mlir::OwningModuleRef> ConvertGraph(
3651 const std::string& name,
3652 const std::vector<std::pair<std::string, TensorInfo>>& inputs,
3653 const std::vector<std::pair<std::string, TensorInfo>>& outputs,
3654 const std::vector<std::string> control_outputs,
3655 std::unordered_map<std::string, std::string>& tf_name_to_mlir_name);
3656
3657 // Moves the functions in `sub_module` to `module_` and skips the duplicate
3658 // functions.
3659 Status MoveConvertedFunctionsToModule(
3660 absl::string_view name, mlir::ModuleOp sub_module,
3661 const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name);
3662
3663 StatusOr<GraphImportConfig::InputArrays> ParseInputArrays(
3664 llvm::ArrayRef<std::pair<std::string, TensorInfo>> inputs);
3665
3666 private:
3667 SavedModelMLIRImportInput& input_;
3668 absl::flat_hash_set<std::string> original_func_tf_names_;
3669 absl::optional<absl::Span<const std::string>> exported_names_;
3670 mlir::OwningModuleRef module_;
3671 mlir::SymbolTable symbol_table_;
3672 bool import_restore_ = true;
3673 };
3674
3675 StatusOr<std::vector<SavedModelSignatureDefImporterLite::AssetInfo>>
3676 SavedModelSignatureDefImporterLite::ConvertAssets() {
3677 std::vector<AssetFileDef> asset_file_defs;
3678 TF_RETURN_IF_ERROR(
3679 internal::GetAssetFileDefs(input_.meta_graph_def(), &asset_file_defs));
3680
3681 std::vector<AssetInfo> results;
3682 results.reserve(asset_file_defs.size());
3683
3684 mlir::OpBuilder builder(module_->getBodyRegion());
3685   unsigned i = 0;  // Used to generate unique sym_name(s) for duplicate assets.
3686 for (const auto& asset : asset_file_defs) {
3687 auto asset_op = builder.create<mlir::tf_saved_model::AssetOp>(
3688 module_->getLoc(),
3689 /*sym_name=*/
3690 builder.getStringAttr(
3691 absl::StrCat("__tf_saved_model_asset", i++, "_", asset.filename())),
3692 /*filename=*/
3693 builder.getStringAttr(
3694 io::JoinPath(kSavedModelAssetsDirectory, asset.filename())));
3695
3696 results.push_back({asset.tensor_info().name(), asset_op});
3697 }
3698
3699 return results;
3700 }
3701
3702 Status SavedModelSignatureDefImporterLite::MoveConvertedFunctionsToModule(
3703 absl::string_view name, mlir::ModuleOp sub_module,
3704 const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name) {
3705 mlir::Builder builder(sub_module.getContext());
3706 mlir::SymbolTable sub_module_symbol_table(sub_module);
3707
3708   // Functions originally from the graphdef library might have a different name
3709   // after conversion, so we build the set of the converted names here.
3710 absl::flat_hash_set<std::string> original_func_mlir_names;
3711 for (const auto& kv : tf_name_to_mlir_name) {
3712 if (original_func_tf_names_.contains(kv.first))
3713 original_func_mlir_names.insert(kv.second);
3714 }
3715
3716   // Prefix private functions with the unique signature name, so that they
3717   // cannot collide with private functions used in the other signatures.
3718 for (auto func : sub_module.getOps<mlir::FuncOp>()) {
3719 if (mlir::tf_saved_model::IsExported(func)) continue;
3720
3721 // Skip the original functions from graphdef library
3722 if (original_func_mlir_names.count(func.sym_name().str())) continue;
3723
3724 std::string new_sym_name = absl::StrCat(name, "/", func.sym_name().str());
3725 if (mlir::failed(sub_module_symbol_table.replaceAllSymbolUses(
3726 func, new_sym_name, sub_module)))
3727 return tensorflow::errors::InvalidArgument(absl::StrCat(
3728 "SavedModelSignatureDefImporterLite: failed to assign a unique "
3729 "name to the private function used in a signature: ",
3730 func.sym_name().str()));
3731
3732 mlir::SymbolTable::setSymbolName(func, new_sym_name);
3733 }
3734
3735 // Copy all functions used by this signature to the final MLIR module.
3736 for (auto func : sub_module.getOps<mlir::FuncOp>()) {
3737 // The insert here is a NO-OP if the function already exists.
3738 symbol_table_.insert(func.clone());
3739 }
3740
3741 return Status::OK();
3742 }
3743
3744 Status SavedModelSignatureDefImporterLite::ConvertInitializer(
3745 const std::string& target_node_name, const std::vector<AssetInfo>& assets) {
3746 std::vector<std::pair<std::string, TensorInfo>> inputs;
3747 inputs.reserve(assets.size());
3748 for (const auto& asset : assets) {
3749 TensorInfo tensor_info;
3750 tensor_info.set_name(asset.tensor_name);
3751 tensor_info.set_dtype(DT_STRING);
3752 tensor_info.mutable_tensor_shape();
3753 inputs.push_back({asset.tensor_name, tensor_info});
3754 }
3755
3756 std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
3757 TF_ASSIGN_OR_RETURN(auto sub_module,
3758 ConvertGraph(target_node_name, inputs, {},
3759 {target_node_name}, tf_name_to_mlir_name));
3760
3761 mlir::SymbolTable sub_symbol_table(*sub_module);
3762
3763 auto init_func_op = sub_symbol_table.lookup<mlir::FuncOp>(target_node_name);
3764 init_func_op->removeAttr("tf.entry_function");
3765
3766 mlir::OpBuilder builder(module_->getBodyRegion());
3767
3768 // Bind asset inputs to asset ops.
3769 DCHECK_EQ(init_func_op.getNumArguments(), assets.size());
3770 for (const auto& iter : llvm::enumerate(assets)) {
3771 auto asset_op = iter.value().op;
3772 init_func_op.setArgAttr(iter.index(), "tf_saved_model.bound_input",
3773 builder.getSymbolRefAttr(asset_op.getName()));
3774 }
3775
3776   // Set the exported name of the init function to a reserved name for
3777 // tf_saved_model.
3778 init_func_op->setAttr(
3779 "tf_saved_model.exported_names",
3780 builder.getStrArrayAttr({absl::StrCat(
3781 "__tf_saved_model_session_initializer_", target_node_name)}));
3782
3783 // Move the converted functions to top level MLIR module.
3784 return MoveConvertedFunctionsToModule(target_node_name, *sub_module,
3785 tf_name_to_mlir_name);
3786 }
3787
3788 StatusOr<mlir::OwningModuleRef>
3789 SavedModelSignatureDefImporterLite::ConvertGraph(
3790 const std::string& name,
3791 const std::vector<std::pair<std::string, TensorInfo>>& inputs,
3792 const std::vector<std::pair<std::string, TensorInfo>>& outputs,
3793 const std::vector<std::string> control_outputs,
3794 std::unordered_map<std::string, std::string>& tf_name_to_mlir_name) {
3795 VLOG(1) << "Importing Signature: " << name;
3796
3797 GraphImportConfig specs;
3798 specs.graph_func_name = name;
3799 specs.prune_unused_nodes = true;
3800 TF_ASSIGN_OR_RETURN(specs.inputs, ParseInputArrays(inputs));
3801 for (auto& output : outputs) specs.outputs.push_back(output.second.name());
3802 specs.control_outputs = control_outputs;
3803 specs.enable_shape_inference = false;
3804
3805 TF_ASSIGN_OR_RETURN(const auto* subgraph, input_.GetSubGraph(name, specs));
3806
3807 // Convert sub-graph to MLIR module.
3808 return GraphDefImporter::Convert(module_->getContext(), *subgraph,
3809 input_.debug_info(), subgraph->flib_def(),
3810 specs, tf_name_to_mlir_name);
3811 }
3812
3813 Status SavedModelSignatureDefImporterLite::ConvertSignature(
3814 const std::string& sig_def_key, const SignatureDef& signature_def) {
3815   // Create local vectors for the inputs and outputs and sort them to be
3816   // deterministic. We don't want anyone to really depend on the order; clients
3817   // should look up the argument/result mapping by attribute name.
3818   // To avoid accidental dependence on the order, we use an unintuitive sorting.
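  // For instance (hypothetical keys), inputs {"x", "y"} may come out as
  // ["y", "x"] whenever Fingerprint64("y") < Fingerprint64("x"); callers must
  // match arguments/results via the "tf_saved_model.index_path" attribute
  // rather than by position.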
3819 std::vector<std::pair<std::string, TensorInfo>> inputs(
3820 signature_def.inputs().begin(), signature_def.inputs().end());
3821 llvm::sort(inputs, [](const auto& lhs, const auto& rhs) {
3822 return tensorflow::Fingerprint64(lhs.first) <
3823 tensorflow::Fingerprint64(rhs.first);
3824 });
3825 std::vector<std::pair<std::string, TensorInfo>> outputs(
3826 signature_def.outputs().begin(), signature_def.outputs().end());
3827 llvm::sort(outputs, [](const auto& lhs, const auto& rhs) {
3828 return tensorflow::Fingerprint64(lhs.first) <
3829 tensorflow::Fingerprint64(rhs.first);
3830 });

  std::unordered_map<std::string, std::string> tf_name_to_mlir_name;

  // Convert the sub-graph to an MLIR module.
  TF_ASSIGN_OR_RETURN(
      auto sub_module,
      ConvertGraph(sig_def_key, inputs, outputs, {}, tf_name_to_mlir_name));
  mlir::OpBuilder builder(sub_module->getBodyRegion());

  // Find the FuncOp which corresponds to the current SignatureDef.
  mlir::SymbolTable sub_symbol_table(*sub_module);
  auto func_op = sub_symbol_table.lookup<mlir::FuncOp>(sig_def_key);
  TF_RET_CHECK(func_op)
      << "The GraphDef importer should have created a function named "
      << sig_def_key << ".";

  // Use the unique SignatureDef key as the exported name.
  func_op->setAttr("tf_saved_model.exported_names",
                   builder.getStrArrayAttr({sig_def_key}));

  // Transfer input and output parameter names to index_path attributes.
  for (auto input_and_idx : llvm::enumerate(inputs)) {
    func_op.setArgAttr(input_and_idx.index(), "tf_saved_model.index_path",
                       builder.getStrArrayAttr({input_and_idx.value().first}));
  }
  for (auto output_and_idx : llvm::enumerate(outputs)) {
    func_op.setResultAttr(
        output_and_idx.index(), "tf_saved_model.index_path",
        builder.getStrArrayAttr({output_and_idx.value().first}));
  }

  // Move the converted functions to the top-level MLIR module.
  return MoveConvertedFunctionsToModule(sig_def_key, *sub_module,
                                        tf_name_to_mlir_name);
}
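
// Illustrative result (assuming a signature key "serving_default" with input
// key "x" and output key "y"): the imported function ends up carrying
//   tf_saved_model.exported_names = ["serving_default"]
// plus per-argument and per-result attributes of the form
//   tf_saved_model.index_path = ["x"]   // on the argument
//   tf_saved_model.index_path = ["y"]   // on the result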

StatusOr<GraphImportConfig::InputArrays>
SavedModelSignatureDefImporterLite::ParseInputArrays(
    llvm::ArrayRef<std::pair<std::string, TensorInfo>> inputs) {
  GraphImportConfig::InputArrays results;
  for (const auto& iter : inputs) {
    const auto& tensor_info = iter.second;

    // TODO(b/184675681): Support other encoding cases.
    //
    // TODO(b/184679394): Add a unit test for this check.
    TF_RET_CHECK(tensor_info.encoding_case() == tensorflow::TensorInfo::kName)
        << "Only dense tensors are supported, but got encoding case "
        << tensor_info.encoding_case();

    VLOG(1) << "Importing Signature Input: input_name = " << iter.first
            << ", tensor_info = " << tensor_info.DebugString();

    ArrayInfo array_info;
    array_info.imported_dtype = tensor_info.dtype();

    if (tensor_info.has_tensor_shape()) {
      array_info.shape = tensor_info.tensor_shape();
    } else {
      // If there is no tensor shape in the tensor info, conservatively set
      // unknown_rank to true.
      array_info.shape.set_unknown_rank(true);
    }

    results.insert(std::pair<std::string, ArrayInfo>(tensor_info.name(),
                                                     std::move(array_info)));
  }
  return results;
}
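
// Illustrative mapping (field values are hypothetical): a TensorInfo of the
// form
//   name: "x:0"
//   dtype: DT_FLOAT
//   tensor_shape { dim { size: -1 } dim { size: 3 } }
// produces an entry keyed by "x:0" with imported_dtype = DT_FLOAT and a
// [?, 3] shape, while a TensorInfo without a tensor_shape field yields an
// entry with unknown rank.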

StatusOr<mlir::OwningModuleRef>
SavedModelSignatureDefImporterLite::ConvertSignatures() {
  LoadImporterDialects(*module_->getContext());

  const auto& signatures = input_.meta_graph_def().signature_def();
  PopulateTfVersions(module_.get(),
                     input_.meta_graph_def().graph_def().versions());

  llvm::DenseSet<llvm::StringRef> exported_name_set;
  bool import_all_signatures = !exported_names_.has_value();
  if (exported_names_.has_value()) {
    exported_name_set.insert(exported_names_->begin(), exported_names_->end());
  }

  for (const auto& key_and_signature_def : signatures) {
    const std::string& sig_def_key = key_and_signature_def.first;
    const SignatureDef& signature_def = key_and_signature_def.second;

    // It is safe to skip "__saved_model_init_op" since it is an internal
    // signature that is not user-accessible. This signature will be handled
    // in ConvertInitializer().
    if (sig_def_key == "__saved_model_init_op") {
      continue;
    }
    if (!import_all_signatures && exported_name_set.count(sig_def_key) == 0) {
      continue;
    }

    TF_RETURN_IF_ERROR(ConvertSignature(sig_def_key, signature_def));
  }

  TF_ASSIGN_OR_RETURN(auto assets, ConvertAssets());

  mlir::OpBuilder builder(module_->getBodyRegion());
  llvm::SmallVector<mlir::Attribute, 2> init_sym_refs;

  if (import_restore_ && input_.meta_graph_def().has_saver_def()) {
    std::vector<AssetInfo> variable_and_assets;

    // Create an AssetOp for the variable checkpoint files. The relative
    // filename is used here.
    auto variable_filename_op = builder.create<mlir::tf_saved_model::AssetOp>(
        module_->getLoc(),
        /*sym_name=*/
        builder.getStringAttr("__tf_saved_model_variables"),
        /*filename=*/
        builder.getStringAttr(io::JoinPath(kSavedModelVariablesDirectory,
                                           kSavedModelVariablesFilename)));
    variable_and_assets.push_back(
        {input_.meta_graph_def().saver_def().filename_tensor_name(),
         variable_filename_op});
    variable_and_assets.insert(variable_and_assets.end(), assets.begin(),
                               assets.end());

    const auto& restore_op_name =
        input_.meta_graph_def().saver_def().restore_op_name();
    TF_RETURN_IF_ERROR(
        ConvertInitializer(restore_op_name, variable_and_assets));
    init_sym_refs.push_back(builder.getSymbolRefAttr(restore_op_name));
  }

  std::string init_op_name;
  TF_RETURN_IF_ERROR(
      internal::GetInitOp("", input_.meta_graph_def(), &init_op_name));
  if (!init_op_name.empty()) {
    TF_RETURN_IF_ERROR(ConvertInitializer(init_op_name, assets));
    init_sym_refs.push_back(builder.getSymbolRefAttr(init_op_name));
  }

  builder.create<mlir::tf_saved_model::SessionInitializerOp>(
      module_->getLoc(), builder.getArrayAttr(init_sym_refs));

  (*module_)->setAttr("tf_saved_model.semantics", builder.getUnitAttr());

  SortSavedModelModule(*module_);
  MarkSavedModelFunctionVisibility(*module_);

  return std::move(module_);
}
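
// Illustrative shape of the resulting module (op syntax abbreviated and
// symbol names hypothetical):
//
//   module attributes {tf_saved_model.semantics} {
//     "tf_saved_model.asset"() {sym_name = "__tf_saved_model_variables",
//                               filename = "variables/variables"} : () -> ()
//     "tf_saved_model.session_initializer"() {
//         initializers = [@restore_op, @init_op]} : () -> ()
//     ...  // one exported FuncOp per imported SignatureDef
//   }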

// A helper class to import a TensorFlow model expressed in SavedModel V1 into
// an MLIR Module in the SavedModel dialect. In addition to importing the
// model, it performs a few graph transformations, including:
// 1) Converting read-only reference variables to resource variables.
// 2) Lifting resource variables to global_tensors using a TF session.
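//
// Example usage (illustrative; error handling omitted and the model path is
// hypothetical):
//
//   SavedModelBundle bundle;
//   TF_RETURN_IF_ERROR(LoadSavedModel(SessionOptions(), RunOptions(),
//                                     "/path/to/saved_model", {"serve"},
//                                     &bundle));
//   mlir::MLIRContext context;
//   TF_ASSIGN_OR_RETURN(auto module,
//                       SavedModelSignatureDefImporter::Convert(
//                           bundle, /*exported_names=*/absl::nullopt,
//                           &context, MLIRImportOptions()));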
class SavedModelSignatureDefImporter {
 public:
  // Main entry point: converts all functions (specified by SignatureDefs) in
  // the given meta graph to an MLIR Module.
  static StatusOr<mlir::OwningModuleRef> Convert(
      const SavedModelBundle& bundle,
      absl::optional<absl::Span<const std::string>> exported_names,
      mlir::MLIRContext* context, tensorflow::MLIRImportOptions options,
      bool lift_varhandle_ops_to_args = true) {
    // debug_info might not be loaded with loader_lite.
    GraphDebugInfo debug_info;
    if (bundle.debug_info != nullptr) debug_info = *bundle.debug_info;

    TF_ASSIGN_OR_RETURN(auto input,
                        SimpleSavedModelMLIRImportInput::Create(
                            options, &bundle.meta_graph_def, debug_info));

    TF_ASSIGN_OR_RETURN(auto module,
                        SavedModelSignatureDefImporterLite::Convert(
                            input, exported_names, context,
                            /*import_restore=*/false));

    mlir::OpBuilder builder(module->getContext());
    (*module)->setAttr("tf_saved_model.under_construction",
                       builder.getUnitAttr());
    TF_RETURN_IF_ERROR(
        LiftVariables(bundle, *module, lift_varhandle_ops_to_args));
    (*module)->removeAttr("tf_saved_model.under_construction");

    return module;
  }

 private:
  // Lifts the variables in `module`.
  static Status LiftVariables(const SavedModelBundle& bundle,
                              mlir::ModuleOp module,
                              bool lift_varhandle_ops_to_args);
};

Status SavedModelSignatureDefImporter::LiftVariables(
    const SavedModelBundle& bundle, mlir::ModuleOp module,
    bool lift_varhandle_ops_to_args) {
  mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());

  mlir::PassManager pm(module.getContext());
  SetCrashReproducer(pm);
  pm.addNestedPass<mlir::FuncOp>(
      mlir::tf_executor::CreateTFExecutorGraphPruningPass());
  pm.addNestedPass<mlir::FuncOp>(
      mlir::CreateExecutorDialectToFunctionalConversionPass());
  pm.addPass(
      mlir::tf_saved_model::CreateRemoveVariablesInSessionInitializerPass());
  pm.addNestedPass<mlir::FuncOp>(
      mlir::TF::
          CreateConvertReadonlyReferenceVariablesToResourceVariablesPass());
  if (lift_varhandle_ops_to_args) {
    pm.addNestedPass<mlir::FuncOp>(
        mlir::tf_saved_model::CreateMarkInitializedVariablesPass(
            bundle.GetSession()));
    pm.addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass());
    pm.addPass(
        mlir::tf_saved_model::CreateLiftVariablesPass(bundle.GetSession()));
  } else {
    pm.addPass(
        mlir::tf_saved_model::CreateInitializeVariablesInSessionInitializerPass(
            bundle.GetSession()));
  }
  pm.addNestedPass<mlir::FuncOp>(
      mlir::tf_saved_model::CreateDedupBoundInputBindingPass());
  if (mlir::failed(pm.run(module)))
    return diag_handler.Combine(errors::Internal("Failed to lift variables."));

  return Status::OK();
}

}  // namespace

SavedModelMLIRImportInput::~SavedModelMLIRImportInput() {}

StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(
    const GraphDef& graphdef, const GraphDebugInfo& debug_info,
    const GraphImportConfig& specs, mlir::MLIRContext* context,
    bool add_default_attributes) {
  GraphConstructorOptions options;
  options.allow_internal_ops = true;
  options.add_default_attributes = add_default_attributes;
  Graph graph(OpRegistry::Global());

  GraphDef preprocessed_graphdef(graphdef);
  if (add_default_attributes) {
    TF_RETURN_IF_ERROR(PreprocessGraphDef(&specs, &preprocessed_graphdef));
  }
  if (specs.upgrade_legacy) {
    TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(
        preprocessed_graphdef, graph.flib_def().default_registry()));
  }
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
      options, std::move(preprocessed_graphdef), &graph));
  return ConvertGraphToMlir(graph, debug_info, graph.flib_def(), specs,
                            context);
}
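
// Example usage (illustrative; error handling omitted):
//
//   GraphDef graph_def;  // e.g. parsed from a serialized .pb file
//   GraphDebugInfo debug_info;
//   GraphImportConfig specs;
//   mlir::MLIRContext context;
//   TF_ASSIGN_OR_RETURN(
//       mlir::OwningModuleRef module,
//       ConvertGraphdefToMlir(graph_def, debug_info, specs, &context));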

StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
    const Graph& graph, const GraphDebugInfo& debug_info,
    const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
    mlir::MLIRContext* context) {
  // TODO(jpienaar): Remove need to const_cast.
  if (specs.upgrade_legacy) {
    TF_RETURN_IF_ERROR(
        UpgradeLegacyGraph(const_cast<Graph*>(&graph),
                           const_cast<FunctionLibraryDefinition*>(&flib_def),
                           specs.restrict_functionalization_to_tpu_nodes));
  }
  std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
  return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs,
                                   tf_name_to_mlir_name);
}

stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
    const FunctionBody* fbody, const FunctionLibraryDefinition& flib_def,
    mlir::MLIRContext* context) {
  tensorflow::GraphDebugInfo dummy_debug_info;
  tensorflow::GraphImportConfig specs;
  specs.graph_func_name = fbody->fdef.signature().name();
  specs.enable_shape_inference = false;
  specs.graph_as_function = true;
  for (const auto* control_ret_node : fbody->control_ret_nodes)
    specs.control_outputs.push_back(control_ret_node->name());
  std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
  return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info,
                                   flib_def, specs, tf_name_to_mlir_name);
}
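
// Illustrative call (assumes `fdef` is a FunctionDef already registered in
// `flib_def`; FunctionDefToBodyHelper is one way to obtain a FunctionBody):
//
//   std::unique_ptr<FunctionBody> fbody;
//   TF_RETURN_IF_ERROR(
//       FunctionDefToBodyHelper(fdef, AttrSlice(), &flib_def, &fbody));
//   mlir::MLIRContext context;
//   TF_ASSIGN_OR_RETURN(
//       auto module, ConvertFunctionToMlir(fbody.get(), flib_def, &context));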

StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
    SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
    absl::Span<std::string> exported_names, bool add_default_attributes) {
  return SavedModelObjectGraphImporter::Convert(
      saved_model, exported_names, context, add_default_attributes);
}
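
// Example usage (illustrative; the model path is hypothetical):
//
//   SavedModelV2Bundle bundle;
//   TF_RETURN_IF_ERROR(
//       SavedModelV2Bundle::Load("/path/to/saved_model", &bundle));
//   mlir::MLIRContext context;
//   std::vector<std::string> names;  // empty => import all exported names
//   TF_ASSIGN_OR_RETURN(auto module,
//                       ConvertSavedModelToMlir(&bundle, &context,
//                                               absl::MakeSpan(names)));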

StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir(
    const SavedModelBundle& saved_model, absl::Span<std::string> exported_names,
    mlir::MLIRContext* context, MLIRImportOptions options,
    bool lift_variables) {
  absl::optional<absl::Span<const std::string>> optional_exported_names;
  // TODO(b/187062560): Change ConvertSavedModelV1ToMlir() to take an optional
  // `exported_names` so that it can be configured to import only restore/init
  // graphs.
  if (!exported_names.empty()) optional_exported_names = exported_names;
  return SavedModelSignatureDefImporter::Convert(
      saved_model, optional_exported_names, context, options, lift_variables);
}

StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlirLite(
    const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info,
    absl::optional<absl::Span<const std::string>> exported_names,
    mlir::MLIRContext* context, MLIRImportOptions options) {
  TF_ASSIGN_OR_RETURN(auto input, SimpleSavedModelMLIRImportInput::Create(
                                      options, &meta_graph_def, debug_info));
  return ConvertSavedModelV1ToMlirLite(input, exported_names, context);
}

StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlirLite(
    SavedModelMLIRImportInput& input,
    absl::optional<absl::Span<const std::string>> exported_names,
    mlir::MLIRContext* context) {
  return SavedModelSignatureDefImporterLite::Convert(input, exported_names,
                                                     context);
}

std::string MlirModuleToString(mlir::ModuleOp module,
                               mlir::OpPrintingFlags flags) {
  std::string txt_module;
  {
    llvm::raw_string_ostream os{txt_module};
    module.print(os, flags);
  }
  return txt_module;
}

std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) {
  mlir::OpPrintingFlags flags;
  if (show_debug_info) flags.enableDebugInfo();
  return MlirModuleToString(module, flags);
}
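
// Example usage (illustrative):
//
//   mlir::ModuleOp module = ...;
//   // Compact textual form:
//   std::string text = MlirModuleToString(module, /*show_debug_info=*/false);
//   // Same module with source locations attached to each op:
//   std::string text_with_locs =
//       MlirModuleToString(module, /*show_debug_info=*/true);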

}  // namespace tensorflow