• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
17 
18 #include <atomic>
19 #include <functional>
20 #include <iterator>
21 #include <memory>
22 #include <queue>
23 #include <string>
24 #include <tuple>
25 #include <type_traits>
26 #include <unordered_set>
27 #include <utility>
28 #include <vector>
29 
30 #include "absl/algorithm/container.h"
31 #include "absl/container/flat_hash_map.h"
32 #include "absl/container/flat_hash_set.h"
33 #include "absl/container/inlined_vector.h"
34 #include "absl/strings/escaping.h"
35 #include "absl/strings/numbers.h"
36 #include "absl/strings/str_cat.h"
37 #include "absl/strings/str_join.h"
38 #include "absl/strings/string_view.h"
39 #include "absl/strings/strip.h"
40 #include "llvm/ADT/ArrayRef.h"
41 #include "llvm/ADT/DenseMap.h"
42 #include "llvm/ADT/DenseSet.h"
43 #include "llvm/ADT/STLExtras.h"
44 #include "llvm/ADT/ScopeExit.h"
45 #include "llvm/ADT/SetVector.h"
46 #include "llvm/ADT/SmallVector.h"
47 #include "llvm/ADT/StringRef.h"
48 #include "llvm/ADT/StringSet.h"
49 #include "llvm/ADT/Twine.h"
50 #include "llvm/Support/FormatVariadic.h"
51 #include "llvm/Support/SourceMgr.h"
52 #include "llvm/Support/raw_ostream.h"
53 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
54 #include "mlir/IR/Attributes.h"  // from @llvm-project
55 #include "mlir/IR/Builders.h"  // from @llvm-project
56 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
57 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
58 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
59 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
60 #include "mlir/IR/Identifier.h"  // from @llvm-project
61 #include "mlir/IR/Location.h"  // from @llvm-project
62 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
63 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
64 #include "mlir/IR/Types.h"  // from @llvm-project
65 #include "mlir/IR/Verifier.h"  // from @llvm-project
66 #include "mlir/Pass/PassManager.h"  // from @llvm-project
67 #include "tensorflow/cc/saved_model/constants.h"
68 #include "tensorflow/cc/saved_model/loader_util.h"
69 #include "tensorflow/compiler/jit/shape_inference_helpers.h"
70 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
71 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
72 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
73 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
74 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
75 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
76 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
77 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
78 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
79 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
80 #include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"
81 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h"
82 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
83 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
84 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
85 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
86 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
87 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
88 #include "tensorflow/compiler/xla/status_macros.h"
89 #include "tensorflow/core/common_runtime/function.h"
90 #include "tensorflow/core/common_runtime/graph_constructor.h"
91 #include "tensorflow/core/common_runtime/shape_refiner.h"
92 #include "tensorflow/core/framework/attr_value.pb.h"
93 #include "tensorflow/core/framework/function.pb.h"
94 #include "tensorflow/core/framework/graph.pb.h"
95 #include "tensorflow/core/framework/node_def.pb.h"
96 #include "tensorflow/core/framework/node_def_util.h"
97 #include "tensorflow/core/framework/op.h"
98 #include "tensorflow/core/framework/resource_var.h"
99 #include "tensorflow/core/framework/shape_inference.h"
100 #include "tensorflow/core/framework/tensor.pb.h"
101 #include "tensorflow/core/framework/types.h"
102 #include "tensorflow/core/framework/types.pb.h"
103 #include "tensorflow/core/framework/versions.pb.h"
104 #include "tensorflow/core/graph/algorithm.h"
105 #include "tensorflow/core/graph/graph.h"
106 #include "tensorflow/core/graph/node_builder.h"
107 #include "tensorflow/core/graph/tensor_id.h"
108 #include "tensorflow/core/grappler/utils/transitive_fanin.h"
109 #include "tensorflow/core/lib/core/errors.h"
110 #include "tensorflow/core/lib/strings/str_util.h"
111 #include "tensorflow/core/platform/crash_analysis.h"
112 #include "tensorflow/core/platform/errors.h"
113 #include "tensorflow/core/platform/fingerprint.h"
114 #include "tensorflow/core/platform/logging.h"
115 #include "tensorflow/core/platform/path.h"
116 #include "tensorflow/core/platform/protobuf.h"
117 #include "tensorflow/core/platform/types.h"
118 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
119 #include "tensorflow/core/protobuf/meta_graph.pb.h"
120 #include "tensorflow/core/protobuf/saved_object_graph.pb.h"
121 #include "tensorflow/core/protobuf/saver.pb.h"
122 #include "tensorflow/core/protobuf/struct.pb.h"
123 #include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
124 #include "tensorflow/core/util/device_name_utils.h"
125 #include "tensorflow/core/util/dump_graph.h"
126 #include "tensorflow/stream_executor/lib/statusor.h"
127 
StringRefToView(llvm::StringRef ref)128 static inline absl::string_view StringRefToView(llvm::StringRef ref) {
129   return {ref.data(), ref.size()};
130 }
131 
132 namespace tensorflow {
133 
134 const char kImportModelDefaultGraphFuncName[] = "main";
135 
136 using mlir::NamedAttrList;
137 using mlir::TensorType;
138 using mlir::tf_saved_model::AssetOp;
139 using mlir::tf_saved_model::GlobalTensorOp;
140 using mlir::tf_saved_model::SessionInitializerOp;
141 using stream_executor::port::StatusOr;
142 
143 namespace {
144 
IsOutputShapesAttribute(const AttrValue & attr_value,llvm::StringRef attr_name)145 bool IsOutputShapesAttribute(const AttrValue& attr_value,
146                              llvm::StringRef attr_name) {
147   return attr_name.compare("_output_shapes") == 0 &&
148          attr_value.value_case() == AttrValue::kList;
149 }
150 
IsResourceOutputShapesAttribute(const AttrValue & attr_value,llvm::StringRef attr_name)151 bool IsResourceOutputShapesAttribute(const AttrValue& attr_value,
152                                      llvm::StringRef attr_name) {
153   if (attr_name == "_handle_dtypes" || attr_name == "_handle_shapes")
154     return attr_value.value_case() == AttrValue::kList;
155   return false;
156 }
157 
LoadImporterDialects(mlir::MLIRContext & context)158 void LoadImporterDialects(mlir::MLIRContext& context) {
159   // Load dialects involved in the conversion
160   mlir::DialectRegistry registry;
161   mlir::RegisterAllTensorFlowDialects(registry);
162   context.appendDialectRegistry(registry);
163   for (llvm::StringRef name : registry.getDialectNames())
164     context.getOrLoadDialect(name);
165 }
166 
// This class is used to generate new MLIR function name strings that are both
// unique in the TF function library `flib_` and unique among the name strings
// generated by the class object during its lifetime.
//
// In theory, this class is not necessary because we should simply take
// the TF function name and use it as MLIR function name. However, for some
// unknown reasons (callout for investigation in b/142268695), keeping the
// function names unchanged in an MLIR roundtrip causes test failures.
// TODO(b/142268695) Re-evaluate whether we need this class v.s. directly using
// and TF function name as MLIR function name after b/142268695 is root caused.
class NameUniquifier : public OpOrArgNameMapper {
 public:
  // `flib` must outlive this object; only a reference is retained.
  explicit NameUniquifier(const FunctionLibraryDefinition& flib)
      : flib_(flib) {}

 private:
  // A candidate name is unique iff no function with that name exists in the
  // library. (Names previously handed out are tracked by the base class.)
  bool IsUnique(llvm::StringRef name) override {
    return !flib_.Contains(std::string(name));
  }

  // Name derivation from an op/value is unsupported here; callers always
  // supply an explicit candidate name.
  std::string GetName(OpOrVal op_or_val) override {
    DCHECK(false) << "Unimplemented";
    return "";
  }

  const FunctionLibraryDefinition& flib_;
};
194 
// Stateful helper class to import a TensorFlow model into an MLIR Module.
//
// This is the base class that contains common utilities shared between the
// GraphDef importer and SavedModel importer.
//
// A subclass is expected to call `PrepareConvert` first to perform necessary
// preparation over the graph and also certain internal bookkeeping data.
// Afterwards the other protected methods can be called.
class ImporterBase {
 protected:
  // All reference and pointer arguments must outlive this importer; no
  // ownership is taken. `module` is where converted functions are inserted.
  explicit ImporterBase(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier,
      llvm::StringRef function_name_for_debug_info = "")
      : builder_(module.getContext()),
        module_(module),
        context_(module.getContext()),
        tf_name_to_mlir_name_(tf_name_to_mlir_name),
        graph_flib_(flib),
        specs_(specs),
        debug_info_(debug_info),
        function_name_for_debug_info_(function_name_for_debug_info),
        function_name_uniquifier_(function_name_uniquifier),
        error_handler_(module.getContext()) {
    // Log import config.
    if (VLOG_IS_ON(1)) {
      LOG(INFO) << "Importing with: " << specs.str();
      for (auto& it : *tf_name_to_mlir_name) {
        LOG(INFO) << "\t" << it.first << " -> " << it.second;
      }
    }
  }

  // Returns the inferred function signature of the given function body. Input
  // types are unranked tensor of the respective datatype in the function and
  // result types are inferred by the shape_refiner_. Result types need not be
  // unranked tensors and could be ranked tensors in cases where result type
  // depends on an op with static output shape like tf.Const.
  StatusOr<mlir::FunctionType> InferLibFunctionType(const FunctionBody& fbody);

  // Extracts arg and ret nodes from FunctionBody.
  void GetArgsAndRetsFromFunctionBody(
      const FunctionBody& fbody,
      absl::InlinedVector<OutputTensor, 4>* arg_nodes,
      absl::InlinedVector<OutputTensor, 4>* ret_nodes,
      absl::InlinedVector<Node*, 4>* control_ret_nodes);

  // Prepares converting the graph to an MLIR module. This step removes the
  // backedges of the graph, orders the nodes and infers the shapes.
  // PrepareConvert needs to ensure that the original `graph` is cloned prior
  // execution. The cloning procedure relies on the roundtrip through the
  // GraphDef. Graph to GraphDef def conversion is heavy, in case, `graph_def`
  // was obtained previously provide it to the PrepareConvert to reuse.
  Status PrepareConvert(const Graph& graph,
                        std::unique_ptr<GraphDef> graph_def = nullptr);

  // Converts the prepared graph to a Function and adds it to the module. A set
  // of nodes from the graph are given to converted to the arguments and returns
  // of the function.
  Status Convert(llvm::StringRef func_name, mlir::FunctionType func_type,
                 const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
                 const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
                 const absl::InlinedVector<Node*, 4>& control_ret_nodes,
                 llvm::ArrayRef<mlir::NamedAttribute> attrs);

  // Finds out the function definition for the given function name from the
  // graph and converts it to a function of the module. This method is called
  // on demand because the graph flib_def does not provide an iterator
  // interface.
  Status ConvertLibFunction(llvm::StringRef func_name);

  // Returns the list of nodes in the graph. Nodes are presented in the reverse
  // order of a post-order depth-first visit starting from the graph's source
  // nodes.
  llvm::ArrayRef<Node*> GetOrderedNodes() const { return ordered_nodes_; }

  // Returns the inferred input type at index `idx` of the `node` in the
  // context.
  StatusOr<mlir::Type> InferInputType(const Node& node, int idx,
                                      mlir::Builder builder);

  // Returns the inferred output type at index `idx` of the `node` in the
  // context.
  StatusOr<mlir::Type> InferOutputType(const Node& node, int idx,
                                       mlir::Builder builder);

  // Convert deferred TF functions to the MLIR representation.
  // Conversion is deferred for efficiency reasons, e.g., to limit depth
  // of recursion and reduce stack size pressure.
  Status ConvertDeferredFunctions();

 private:
  // Most types with subtypes have only one subtype.
  using ElementSubtypes = llvm::SmallVector<TensorType, 1>;

  // Metadata used for deferred function conversion.
  struct DeferredConversionMetaData {
    DeferredConversionMetaData(
        const std::string& function_name,
        const std::vector<mlir::NamedAttribute>& attributes)
        : function_name(function_name), attributes(attributes) {}

    // TF function library name of the function awaiting conversion.
    std::string function_name;
    // Attributes to attach to the resulting MLIR function.
    std::vector<mlir::NamedAttribute> attributes;
  };

  // Adds all the ordered_nodes to the shape refiner shape_refiner_. Then all
  // data type and shape information is maintained by the shape_refiner_.
  // TODO(jpienaar): Remove once shape inference on import is removed.
  Status AddNodesToShapeRefiner(
      std::unordered_map<string, Node*>* node_name_map);

  // Prune nodes that do not feed into fetch nodes.
  Status PruneUnreachableNodes(
      std::unordered_map<string, Node*>* node_name_map);

  // Converts feeds to Placeholder nodes.
  Status ConvertFeedsToPlaceholders(
      std::unordered_map<string, Node*>* node_name_map);

  // Converts the inferred shape referred to by 'handle' in 'context', with
  // given element type, and returns an MLIR tensor type.
  StatusOr<TensorType> ConvertDataTypeAndShape(
      DataType dtype, const shape_inference::ShapeHandle& handle,
      const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
      shape_inference::InferenceContext* context, mlir::Builder builder);

  // Converts the inferred shape referred to by 'handle' in 'context', with
  // given element type, and returns an MLIR tensor type.
  StatusOr<TensorType> ConvertElementTypeAndShape(
      mlir::Type element_type, const shape_inference::ShapeHandle& handle,
      shape_inference::InferenceContext* context, mlir::Builder builder);

  // Converts the inferred subtypes for an element type to corresponding MLIR
  // types in 'context'.
  StatusOr<ElementSubtypes> ConvertSubtypes(
      const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
      shape_inference::InferenceContext* context, mlir::Builder builder);

  // Converts the tensor proto into an MLIR elements attribute.
  StatusOr<mlir::ElementsAttr> ConvertTensorProto(const TensorProto& value) {
    return ::tensorflow::ConvertTensorProto(value, &builder_);
  }

  // Converts func name in graphdef to mlir::SymbolRefAttribute.
  StatusOr<mlir::FlatSymbolRefAttr> ConvertFunctionCallName(
      const std::string& func_name);

  // Converts the given non-function-call AttrValue to an MLIR Attribute.
  StatusOr<mlir::Attribute> ConvertAttributeValue(const AttrValue& value);

  // Converts the given function-call AttrValue to MLIR Attributes and pushes
  // them to the given attributes list. For example, if there is a kFunc
  // AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it to
  // a list of MLIR Attributes: [{base_name : foo}, {base_name.k1 : bar},
  // {base_name.k2 : rfc}}.
  Status ConvertFunctionCallAttribute(const std::string& base_name,
                                      const AttrValue& value,
                                      NamedAttrList* attributes);

  // Helper to create either a tf_executor operation or a TF operation wrapped
  // in an island.
  mlir::Operation* CreateOperation(
      const Node& node, llvm::StringRef node_type_name,
      const mlir::OperationState& result,
      const llvm::SmallVectorImpl<mlir::Value>& control_operands);

  // Converts one NodeDef from the input GraphDef into an Operation and
  // inserts it into the MLIR module using builder_.
  Status ConvertNode(const Node& node);

  // If the input graph represents a while-loop, the edges pointing from a
  // "NextIteration" node to a "Merge" node add cyclic dependencies and make the
  // topological sorting impossible. We need to remove these edges from the
  // input graph to infer shapes and construct a Function. For each
  // "NextIteration" node, there are two operations, "NextIteration.source"
  // and "NextIteration.sink" are added to the MLIR module.
  using BackEdge = BackEdgeHelper::BackEdge;

  // Removes backedges from the input graph. The removed edges are added back to
  // to OpBuilder after the remaining graph is converted to the Function.
  Status RemoveBackedges();

  // Restores backedges removed during shape inference to the final Function.
  Status AddBackedges();

  // Restores a single backedge in the Function by adding a replicated
  // operation before the dst operation.
  Status AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
                     int dst_input);

  // Adds the input arguments and return operation to the function. The
  // arguments are added as basic block argument. Also the argument types and
  // the id of the nodes from the input graph needs to be specified.
  Status ConvertFunctionArgAndRets(
      mlir::FuncOp func, mlir::tf_executor::GraphOp graph_op,
      llvm::ArrayRef<mlir::Type> arg_types,
      const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
      const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
      const absl::InlinedVector<Node*, 4>& control_ret_nodes);

  // Gets the location information of the given node. It uses the
  // "original_node_name" in the NodeDef to get the corresponding file location
  // (FileLineColLoc) from the input DebugInfo and returns an CallSiteLoc. If
  // there are multiple "original_node_names", a FusedLoc is returned. If the
  // node name couldn't be found in the input DebugInfo, a NameLoc is used as
  // the location.
  mlir::Location GetLocation(const Node& node);

  // Appends the location string for the node to the error message and returns
  // the combined error status.
  Status EmitErrorWithLocationStr(const Node& node, const Status& error_status);

  // Inserts a placeholder node in the graph to replace a feed output tensor,
  // and returns the new placeholder node and a boolean indicating if the
  // original input node was removed from the graph. Uses of the feed output
  // tensor are replaced with this placeholder node. If the feed output tensor
  // is of a single output node, the control dependencies are forwarded to the
  // the placeholder node, and the original node will be removed.
  // Note: This modifies the graph, and so any list of ordered nodes needs to be
  // reconstructed.
  StatusOr<std::pair<Node*, bool>> CreatePlaceholderNodeForFeed(
      const TensorShapeProto& shape, DataType dtype, Node* node, int index,
      const std::unordered_map<string, Node*>& node_name_map);

  // Gets the input and output nodes corresponding to the specified input and
  // output nodes in specs_. If there are no input or output nodes specified,
  // nodes will be empty.
  Status GetInputOutputNodes(
      const std::unordered_map<string, Node*>& node_name_map,
      std::unordered_set<const Node*>* nodes);

  // Returns a process-wide set shared across all importer instances; entries
  // persist for the lifetime of the process.
  llvm::StringSet<>& GetUnmodelledOpTypes() {
    // All the TF ops encountered that aren't modelled in dialect.
    static auto* unmodelled_op_types = new llvm::StringSet<>();
    return *unmodelled_op_types;
  }

  // The input graph with backedges removed. The removed backedges are stored
  // in the back_edge_helper.
  BackEdgeHelper back_edge_helper_;
  // A map between node and output index, for each backedge.
  absl::flat_hash_map<const Node*, int> back_edge_node_output_;
  absl::flat_hash_map<const Node*, BackEdge> back_edge_dst_inputs_;
  // A map between sink and source operation of NextIteration
  absl::flat_hash_map<mlir::Operation*, mlir::Operation*>
      next_iteration_sink_source_;

  // All nodes and version information about the (copied) imported graph.
  std::unique_ptr<Graph> graph_;
  std::vector<Node*> ordered_nodes_;

  // Maps from a Node ID to a MLIR value.
  using NodeValueMap = absl::flat_hash_map<int, mlir::Operation*>;

  // Builder/module/context used to emit MLIR; all refer to the module passed
  // to the constructor.
  mlir::OpBuilder builder_;
  mlir::ModuleOp module_;
  mlir::MLIRContext* context_;
  // Caller-owned map from TF function names to their MLIR function names.
  std::unordered_map<std::string, std::string>* tf_name_to_mlir_name_;
  const FunctionLibraryDefinition& graph_flib_;
  const GraphImportConfig& specs_;
  const GraphDebugInfo& debug_info_;
  llvm::StringRef function_name_for_debug_info_;
  NodeValueMap node_values_;
  // TODO(jpienaar): Remove once shape inference on import is removed.
  // The shape_refinner_ will be nullptr if shape inference on import is
  // not enabled.
  std::unique_ptr<ShapeRefiner> shape_refiner_ = nullptr;
  NameUniquifier* function_name_uniquifier_;
  mlir::StatusScopedDiagnosticHandler error_handler_;

 protected:
  // Maps feed as TensorId to new Placeholder node name.
  absl::flat_hash_map<TensorId, absl::string_view> remapped_feeds_;
  // Keep track of functions required deferred conversion.
  std::queue<DeferredConversionMetaData> deferred_functions_;
};
474 
475 // Returns true if the node with given name has a non primary output that is
476 // used by some other node as an input. Returns false if no outputs are in use
477 // or only the first output is in use.
HasNonPrimaryOutputInUse(const GraphDef & graph_def,const std::string & node)478 bool HasNonPrimaryOutputInUse(const GraphDef& graph_def,
479                               const std::string& node) {
480   for (const auto& node_def : graph_def.node()) {
481     for (const auto& input : node_def.input()) {
482       if (absl::StartsWith(input, node + ":") && input != node + ":0") {
483         return true;
484       }
485     }
486   }
487   return false;
488 }
489 
490 // Updates the given LegacyFedInput node with Placeholder node if it is one of
491 // the inputs. Returns an error if non primary output of the LegacyFedInput node
492 // is in use and therefore can not be replaced by the Placeholder node that only
493 // has a single output.
UpdateLegacyFedInputNode(const GraphDef & graph_def,const GraphImportConfig::InputArrays & inputs,NodeDef * node)494 Status UpdateLegacyFedInputNode(const GraphDef& graph_def,
495                                 const GraphImportConfig::InputArrays& inputs,
496                                 NodeDef* node) {
497   const std::string& node_name = node->name();
498   auto it = inputs.find(node_name);
499 
500   // Node is not an input.
501   if (it == inputs.end()) return Status::OK();
502 
503   if (HasNonPrimaryOutputInUse(graph_def, node_name)) {
504     return errors::InvalidArgument(
505         "LegacyFedInput node ", node->name(),
506         " has non primary output in use and can not be replaced with "
507         "Placeholder node");
508   }
509 
510   DataType dtype = it->second.imported_dtype;
511   // Uses the existing output type if it isn't specified by the user.
512   if (dtype == DT_INVALID) {
513     dtype = node->attr().at("output_types").list().type(0);
514   }
515   // Update op name, drop inputs and set attributes required by the Placeholder
516   // op.
517   *node->mutable_op() = "Placeholder";
518   node->clear_attr();
519   node->clear_input();
520   AddNodeAttr("dtype", dtype, node);
521   AddNodeAttr("shape", it->second.shape, node);
522   return Status::OK();
523 }
524 
525 // Preprocesses GraphDef before it can be converted to Graph by,
526 // - Adding the default attributes to each node def if they are missing from
527 //   the GraphDef.
528 // - Replacing LegacyFedInput nodes with Placeholder nodes if
529 //   convert_legacy_fed_inputs option is enabled.
PreprocessGraphDef(const GraphImportConfig * specs,GraphDef * graph_def)530 Status PreprocessGraphDef(const GraphImportConfig* specs, GraphDef* graph_def) {
531   for (auto& node_def : *graph_def->mutable_node()) {
532     // TODO(hinsu): Completely deprecate support for LegacyFedInput ops. One
533     // solution could be have a tool to let users upgrade old serialized graphs.
534     if (specs && specs->convert_legacy_fed_inputs &&
535         node_def.op() == "LegacyFedInput") {
536       TF_RETURN_IF_ERROR(
537           UpdateLegacyFedInputNode(*graph_def, specs->inputs, &node_def));
538     }
539 
540     const tensorflow::OpRegistrationData* op_reg_data =
541         tensorflow::OpRegistry::Global()->LookUp(node_def.op());
542     if (!op_reg_data) {
543       // This is likely a function call node, so we should continue.
544       continue;
545     }
546     ::tensorflow::AddDefaultsToNodeDef(op_reg_data->op_def, &node_def);
547   }
548   return Status::OK();
549 }
550 
// Mapping from node name to feed (index and ArrayInfo). Node name must outlive
// this map.
// Outer key: node name (a view into caller-owned storage). Inner key: output
// index of that node. Inner value: pointer back into the originating
// `GraphImportConfig::InputArrays` entry, which must therefore also outlive
// this map.
using FeedsByNode = absl::flat_hash_map<
    absl::string_view,
    absl::flat_hash_map<int, const std::pair<std::string, ArrayInfo>*>>;
556 
557 // Creates from a `GraphImportConfig::InputArrays` a mapping from a feeds output
558 // tensor name to index and ArrayInfo. Keys and values are backed by
559 // `GraphImportConfig::InputArrays`.
GetFeedsByNode(const GraphImportConfig::InputArrays & inputs)560 StatusOr<FeedsByNode> GetFeedsByNode(
561     const GraphImportConfig::InputArrays& inputs) {
562   FeedsByNode feeds_by_node;
563   feeds_by_node.reserve(inputs.size());
564 
565   for (const auto& input : inputs) {
566     TensorId tensor = ParseTensorName(input.first);
567     if (tensor.index() < 0)
568       return errors::FailedPrecondition(
569           "Feed output tensor must be a data output '", tensor.ToString(), "'");
570 
571     auto& node = feeds_by_node[tensor.node()];
572     if (!node.insert({tensor.index(), &input}).second)
573       return errors::FailedPrecondition(
574           "Multiple feeds for the same output tensor '", tensor.ToString(),
575           "'");
576   }
577 
578   return feeds_by_node;
579 }
580 
581 // Creates a unique name for a node that will be replacing a feed output tensor.
GetUniqueNodeName(absl::string_view node_name,int index,const std::unordered_map<string,Node * > & node_name_map)582 std::string GetUniqueNodeName(
583     absl::string_view node_name, int index,
584     const std::unordered_map<string, Node*>& node_name_map) {
585   std::string new_node_name_base = absl::StrCat(node_name, "_", index);
586   int count = 0;
587   std::string new_node_name = new_node_name_base;
588   while (node_name_map.find(new_node_name) != node_name_map.end()) {
589     new_node_name = absl::StrCat(new_node_name_base, "_", count++);
590   }
591   return new_node_name;
592 }
593 
ConvertDeferredFunctions()594 Status ImporterBase::ConvertDeferredFunctions() {
595   while (!deferred_functions_.empty()) {
596     auto conversion_metadata = deferred_functions_.front();
597     deferred_functions_.pop();
598 
599     const FunctionDef* func_def =
600         graph_flib_.Find(conversion_metadata.function_name);
601     // Converts the graph to an MLIR function and adds it to the module.
602     // We populate the NodeSpec so that all the _Arg ops get their shape
603     // added correctly.
604     GraphImportConfig specs;
605     specs.enable_shape_inference = specs_.enable_shape_inference;
606     for (const auto& name_and_value : func_def->attr()) {
607       if (name_and_value.first == "_input_shapes") {
608         auto& list = name_and_value.second.list();
609         auto& signature = func_def->signature();
610         // Some models have "_input_shapes" attribute, but with its value empty
611         if (list.shape_size() > 0 &&
612             list.shape_size() != signature.input_arg_size()) {
613           return errors::FailedPrecondition(
614               "Number of input arguments must be equal to the length of "
615               "_input_shapes attribute in function '",
616               StringRefToView(conversion_metadata.function_name), "'.");
617         }
618         for (int i = 0, e = signature.input_arg_size(); i < e; i++) {
619           auto& input_arg = signature.input_arg(i);
620           auto& array_info = specs.inputs[input_arg.name()];
621           array_info.imported_dtype = input_arg.type();
622           // set to unranked for empty "_input_shapes" attribute
623           if (list.shape_size() > 0)
624             array_info.shape = list.shape(i);
625           else
626             array_info.shape.set_unknown_rank(true);
627         }
628       }
629     }
630 
631     ImporterBase importer(graph_flib_, debug_info_, specs, module_,
632                           tf_name_to_mlir_name_, function_name_uniquifier_,
633                           conversion_metadata.function_name);
634 
635     std::unique_ptr<FunctionBody> fbody;
636     TF_RETURN_IF_ERROR(
637         FunctionDefToBodyHelper(*func_def, AttrSlice(), &graph_flib_, &fbody));
638     TF_RETURN_IF_ERROR(importer.PrepareConvert(*fbody->graph));
639 
640     TF_ASSIGN_OR_RETURN(auto func_type, importer.InferLibFunctionType(*fbody));
641 
642     absl::InlinedVector<OutputTensor, 4> arg_nodes;
643     absl::InlinedVector<OutputTensor, 4> ret_nodes;
644     absl::InlinedVector<Node*, 4> control_ret_nodes;
645     importer.GetArgsAndRetsFromFunctionBody(*fbody, &arg_nodes, &ret_nodes,
646                                             &control_ret_nodes);
647     const std::string& mlir_func_name =
648         (*tf_name_to_mlir_name_)[conversion_metadata.function_name];
649 
650     TF_RETURN_IF_ERROR(importer.Convert(mlir_func_name, func_type, arg_nodes,
651                                         ret_nodes, control_ret_nodes,
652                                         conversion_metadata.attributes));
653 
654     // Additional function bodies could be discovered during the deferred
655     // loading of the current function. Add them to the working queue.
656     while (!importer.deferred_functions_.empty()) {
657       deferred_functions_.push(importer.deferred_functions_.front());
658       importer.deferred_functions_.pop();
659     }
660   }
661 
662   return Status::OK();
663 }
664 
RemoveBackedges()665 Status ImporterBase::RemoveBackedges() {
666   // Remove all the backedges. So the nodes can be added to the shape refiner.
667   TF_RETURN_IF_ERROR(back_edge_helper_.Remove(graph_.get()));
668   VLOG(1) << "Found " << (back_edge_helper_.RemovedEdges().size())
669           << " backedges.";
670 
671   // Creates a map for quickly identifying whether a node output is a backedge.
672   for (const auto& edge : back_edge_helper_.RemovedEdges()) {
673     if (back_edge_node_output_.find(edge.src) != back_edge_node_output_.end() &&
674         back_edge_node_output_[edge.src] != edge.src_output) {
675       return errors::FailedPrecondition(
676           "More than one of the src node outputs are backedges!");
677     }
678     back_edge_node_output_[edge.src] = edge.src_output;
679     // We expect a merge to receive a single backedge (multiple NextIteration
680     // nodes feeding into the same merge is unexpected here).
681     DCHECK(!back_edge_dst_inputs_.contains(edge.dst));
682     back_edge_dst_inputs_[edge.dst] = edge;
683   }
684 
685   // Obtains a RPO ordering, using node names as a tiebreak for stable sorting.
686   GetReversePostOrder(
687       *graph_, &ordered_nodes_,
688       [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
689   return Status::OK();
690 }
691 
CopyStackTraces(const Graph & from,Graph * to)692 Status CopyStackTraces(const Graph& from, Graph* to) {
693   // Copy over the stack traces.
694   // TODO(jpienaar): This really shouldn't be needed, copying the Graph above
695   // and then needing these traversals is unfortunate.
696   std::unordered_map<string, Node*> node_map = from.BuildNodeNameIndex();
697   for (Node* node : to->nodes()) {
698     if (const Node* old_node = node_map[node->name()]) {
699       if (const std::shared_ptr<AbstractStackTrace>& stack =
700               old_node->GetStackTrace()) {
701         DVLOG(2) << "Stack for " << node->name() << " "
702                  << old_node->GetStackTrace()->ToString(
703                         AbstractStackTrace::TracePrintingOptions());
704         node->SetStackTrace(stack);
705       } else {
706         DVLOG(1) << "No stack for " << node->name() << " (" << node
707                  << ") in Graph " << &from;
708       }
709     } else {
710       DVLOG(1) << "No stack for " << node->name() << " (" << node
711                << ") in Graph " << &from;
712     }
713   }
714 
715   return Status::OK();
716 }
717 
CreatePlaceholderNodeForFeed(const TensorShapeProto & shape,DataType dtype,Node * node,int index,const std::unordered_map<string,Node * > & node_name_map)718 StatusOr<std::pair<Node*, bool>> ImporterBase::CreatePlaceholderNodeForFeed(
719     const TensorShapeProto& shape, DataType dtype, Node* node, int index,
720     const std::unordered_map<string, Node*>& node_name_map) {
721   DCHECK_LT(index, node->num_outputs());
722   const bool update_inplace = node->num_outputs() == 1 && index == 0;
723   std::string new_node_name =
724       update_inplace ? node->name()
725                      : GetUniqueNodeName(node->name(), index, node_name_map);
726 
727   Node* placeholder_node;
728   NodeBuilder builder(new_node_name, "Placeholder");
729   builder.Attr("shape", shape);
730   builder.Attr("dtype", dtype);
731   TF_RETURN_IF_ERROR(builder.Finalize(graph_.get(), &placeholder_node));
732 
733   // Update edges from original feed with Placeholder node.
734   std::vector<const Edge*> data_edges;
735   std::vector<const Edge*> control_edges;
736   for (const tensorflow::Edge* edge : node->out_edges()) {
737     if (edge->src_output() == index) {
738       data_edges.push_back(edge);
739     } else if (update_inplace && edge->IsControlEdge()) {
740       control_edges.push_back(edge);
741     }
742   }
743 
744   for (const auto* edge : data_edges) {
745     TF_RETURN_IF_ERROR(graph_->UpdateEdge(placeholder_node, 0, edge->dst(),
746                                           edge->dst_input()));
747   }
748 
749   // TODO(lyandy): Preserve control dependencies properly by not forwarding
750   // control dependencies to data outputs and not removing single output nodes.
751   // When a data output is replaced as a feed, unless there is another non feed
752   // data output or an explicit control output used by the same node, transitive
753   // control dependencies are not to be executed. For single output nodes,
754   // Placeholders can be converted to a NoOp if there are no uses, and
755   // PlaceholderWithDefault can be converted to an Identity.
756   for (const auto* edge : control_edges) {
757     graph_->AddControlEdge(placeholder_node, edge->dst());
758     graph_->RemoveControlEdge(edge);
759   }
760 
761   if (update_inplace) {
762     graph_->RemoveNode(node);
763   }
764 
765   return std::pair<Node*, bool>(placeholder_node, update_inplace);
766 }
767 
GetInputOutputNodes(const std::unordered_map<string,Node * > & node_name_map,std::unordered_set<const Node * > * nodes)768 Status ImporterBase::GetInputOutputNodes(
769     const std::unordered_map<string, Node*>& node_name_map,
770     std::unordered_set<const Node*>* nodes) {
771   auto add_node = [&](absl::string_view name) {
772     auto it = node_name_map.find(std::string(name));
773     if (it == node_name_map.end()) {
774       return errors::FailedPrecondition(
775           absl::StrCat("Graph does not contain node: ", name));
776     }
777     nodes->insert(it->second);
778     return Status::OK();
779   };
780 
781   // Remap feeds and fetches to newly created Placeholder nodes.
782   for (const auto& input : specs_.inputs) {
783     TensorId tensor = ParseTensorName(input.first);
784     auto remapped_it = remapped_feeds_.find(tensor);
785     if (remapped_it != remapped_feeds_.end()) {
786       TF_RETURN_IF_ERROR(add_node(remapped_it->second));
787     } else {
788       TF_RETURN_IF_ERROR(add_node(tensor.node()));
789     }
790   }
791 
792   for (const auto& output : specs_.outputs) {
793     TensorId tensor = ParseTensorName(output);
794     auto remapped_it = remapped_feeds_.find(tensor);
795     if (remapped_it != remapped_feeds_.end()) {
796       TF_RETURN_IF_ERROR(add_node(remapped_it->second));
797     } else {
798       TF_RETURN_IF_ERROR(add_node(tensor.node()));
799     }
800   }
801 
802   for (const auto& control_output : specs_.control_outputs)
803     TF_RETURN_IF_ERROR(add_node(control_output));
804 
805   return Status::OK();
806 }
807 
808 // TODO(jpienaar): Remove this post shape inference on import flag is removed.
AddNodesToShapeRefiner(std::unordered_map<string,Node * > * node_name_map)809 Status ImporterBase::AddNodesToShapeRefiner(
810     std::unordered_map<string, Node*>* node_name_map) {
811   shape_refiner_ = absl::make_unique<ShapeRefiner>(graph_->versions(),
812                                                    graph_->op_registry());
813   // Some operations (for example "TPUExecute") don't have shape inference
814   // function defined, so we should set this to false for adding nodes with
815   // these types of operations.
816   shape_refiner_->set_require_shape_inference_fns(false);
817   shape_refiner_->set_function_library_for_shape_inference(&graph_flib_);
818 
819   TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs));
820 
821   // First add all nodes to the refiner.
822   for (Node* node : ordered_nodes_) {
823     // We need to use a TensorFlow node to teach the shape refiner that user
824     // specifies certain data type and shape for the inputs in the `specs_`.
825     // This node shouldn't have any inputs, only have one output and its
826     // output type/shape is only determined by its "named" attributes. (The
827     // attributes should have fixed names so we can use the info from `specs_`
828     // to set the value of them.) `Placeholder` satisfies these constraints.
829     //
830     // Therefore, if the input node isn't a `Placeholder`, we create one and use
831     // it to replace the original input node, so the shape refiner can
832     // successfully propagate the user's input type and shape to the rest of the
833     // graph.
834     bool node_added_to_shape_refiner = false;
835     auto it = feeds_by_node.find(node->name());
836     if (it != feeds_by_node.end()) {
837       auto op_name = node->op_def().name();
838       if (op_name != "Placeholder" && op_name != "LegacyFedInput" &&
839           op_name != FunctionLibraryDefinition::kArgOp) {
840         for (const auto& output_tensor : it->second) {
841           const int index = output_tensor.first;
842           const ArrayInfo& array_info = output_tensor.second->second;
843 
844           DataType dtype = array_info.imported_dtype;
845           // Uses the existing output type if it isn't specified by the user.
846           if (dtype == DT_INVALID) {
847             dtype = node->output_type(index);
848           }
849 
850           TF_ASSIGN_OR_RETURN(
851               auto placeholder_node_and_removed,
852               CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index,
853                                            *node_name_map));
854 
855           Node* placeholder_node = placeholder_node_and_removed.first;
856           if (placeholder_node_and_removed.second) {
857             // Original node has been removed from the graph.
858             node = placeholder_node;
859             node_added_to_shape_refiner = true;
860           }
861           remapped_feeds_[{it->first, index}] = placeholder_node->name();
862           (*node_name_map)[placeholder_node->name()] = placeholder_node;
863           // Add the new placeholder node to the shape refiner.
864           Status status = shape_refiner_->AddNode(placeholder_node);
865           if (!status.ok()) {
866             return EmitErrorWithLocationStr(*placeholder_node, status);
867           }
868         }
869       } else {
870         auto index_it = it->second.find(0);
871         if (index_it == it->second.end()) {
872           return errors::FailedPrecondition(
873               "Missing feed output tensor at index 0 for node '", node->name(),
874               "'");
875         }
876         node->AddAttr("shape", index_it->second->second.shape);
877         DataType dtype = index_it->second->second.imported_dtype;
878         // Uses the existing output type if it isn't specified by the user.
879         if (dtype == DT_INVALID) {
880           dtype = node->output_type(0);
881         }
882         node->AddAttr("dtype", dtype);
883       }
884     }
885     if (!node_added_to_shape_refiner) {
886       // Add the node to the shape refiner if the node hasn't been removed.
887       Status status = shape_refiner_->AddNode(node);
888       if (!status.ok()) {
889         return EmitErrorWithLocationStr(*node, status);
890       }
891     }
892 
893     auto set_shape_from_list_attr = [&](const AttrValue* attr) {
894       auto& list = attr->list();
895       for (auto shape : llvm::enumerate(list.shape())) {
896         auto* node_context = shape_refiner_->GetContext(node);
897         shape_inference::ShapeHandle handle;
898         Status status =
899             node_context->MakeShapeFromShapeProto(shape.value(), &handle);
900         if (!status.ok()) {
901           return EmitErrorWithLocationStr(*node, status);
902         }
903         node_context->set_output(shape.index(), handle);
904       }
905       return Status::OK();
906     };
907 
908     // We currently have no other way to get shapes from ReadVariableOp's.
909     // Some graphs seem to have _output_shapes attributes on them, so use that
910     // if possible.
911     // Note: _output_shapes are optionally set when the user exports the graph
912     // and it is not guaranteed (nor an error if missing). There is not a
913     // promised contract, so effectively a heuristic.
914     if (node->op_def().name() == "ReadVariableOp") {
915       if (const AttrValue* attr = node->attrs().Find("_output_shapes"))
916         TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr));
917     }
918 
919     // If it is the argument node, the shape handle is set explicitly, so it
920     // can be propagated to the body nodes of the function.
921     if (StringPiece(node->type_string()) == FunctionLibraryDefinition::kArgOp) {
922       auto* node_context = shape_refiner_->GetContext(node);
923       DCHECK(node_context != nullptr);
924       if (const AttrValue* attr = node->attrs().Find("shape")) {
925         shape_inference::ShapeHandle handle;
926         Status status =
927             node_context->MakeShapeFromShapeProto(attr->shape(), &handle);
928         if (!status.ok()) {
929           return EmitErrorWithLocationStr(*node, status);
930         }
931         node_context->set_output(0, handle);
932       } else if (const AttrValue* attr = node->attrs().Find("_output_shapes")) {
933         TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr));
934       } else {
935         node_context->set_output(0, node_context->UnknownShape());
936       }
937     }
938   }
939 
940   // Since we might have inserted and removed nodes from the graph, fix
941   // source/sink edges and reconstruct the RPO ordering of nodes
942   FixupSourceAndSinkEdges(graph_.get());
943 
944   // Prune nodes in the graph that are not reachable from the output.
945   if (specs_.prune_unused_nodes) {
946     std::unordered_set<const Node*> prune_start;
947     TF_RETURN_IF_ERROR(GetInputOutputNodes(*node_name_map, &prune_start));
948     if (!prune_start.empty()) {
949       if (PruneForReverseReachability(graph_.get(), prune_start)) {
950         VLOG(1) << "Pruned unused nodes in graphdef";
951       } else {
952         VLOG(1) << "No unused nodes in graphdef to prune";
953       }
954     } else {
955       VLOG(1) << "No output nodes specified, skipping pruning";
956     }
957   } else {
958     VLOG(1) << "Pruning unused nodes in graphdef is disabled";
959   }
960 
961   // Re-initialize ordered_nodes_ since we might have modified the graph.
962   GetReversePostOrder(
963       *graph_, &ordered_nodes_,
964       [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
965 
966   VLOG(1) << "Inferring graph shapes to fixpoint";
967 
968   // The "changed" information from UpdateNode can give false positives, so we
969   // create a dedicated method to verify the shapes are not changed before and
970   // after the shape refine.
971   auto same_inferred_shape = [](shape_inference::InferenceContext* c,
972                                 shape_inference::ShapeHandle s0,
973                                 shape_inference::ShapeHandle s1) -> bool {
974     if (s0.SameHandle(s1) || (!c->RankKnown(s0) && !c->RankKnown(s1))) {
975       return true;
976     }
977     if (c->Rank(s0) != c->Rank(s1)) {
978       return false;
979     }
980     for (int i = 0; i < c->Rank(s0); ++i) {
981       if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
982         int64_t val0 = c->Value(c->Dim(s0, i));
983         int64_t val1 = c->Value(c->Dim(s1, i));
984         // Negative value is treated as unknown so all negative values indicate
985         // the same dimension.
986         if (val0 >= 0 && val1 >= 0 && val0 != val1) return false;
987       }
988     }
989     return true;
990   };
991 
992   bool changed = true;
993   int i = 0;
994   const int kMaxIterationCount = 2;
995   while (changed && i != kMaxIterationCount) {
996     changed = false;
997     for (const Node* node : ordered_nodes_) {
998       auto* shape_context = shape_refiner_->GetContext(node);
999       DCHECK(shape_context != nullptr);
1000       absl::InlinedVector<shape_inference::ShapeHandle, 4> existing;
1001       existing.reserve(shape_context->num_outputs());
1002       for (int o = 0; o < shape_context->num_outputs(); ++o) {
1003         existing.push_back(shape_context->output(o));
1004       }
1005       bool inferred = false;
1006       shape_inference::ShapeHandle handle;
1007       Status status =
1008           shape_refiner_->UpdateNode(node, /*relax=*/false, &inferred);
1009       if (!status.ok()) {
1010         return EmitErrorWithLocationStr(*node, status);
1011       }
1012       for (int o = 0; o < shape_context->num_outputs(); ++o) {
1013         if (!same_inferred_shape(shape_context, shape_context->output(o),
1014                                  existing[o])) {
1015           changed = true;
1016           break;
1017         }
1018       }
1019     }
1020     ++i;
1021   }
1022   if (i >= kMaxIterationCount) {
1023     LOG(WARNING) << "Graph shapes did not converge to a fixpoint within "
1024                  << kMaxIterationCount
1025                  << " iterations. Graph shapes may be conservative.";
1026   }
1027   VLOG(1) << "Graph shapes were inferred with " << (i - 1)
1028           << " extra rounds of analysis to reach a fixpoint.";
1029   return Status::OK();
1030 }
1031 
InferInputType(const Node & node,int idx,mlir::Builder builder)1032 StatusOr<mlir::Type> ImporterBase::InferInputType(const Node& node, int idx,
1033                                                   mlir::Builder builder) {
1034   if (specs_.enable_shape_inference) {
1035     // TODO(jpienaar): Remove this if shape inference on import flag is removed.
1036     ExtendedInferenceContext* shape_context =
1037         shape_refiner_->GetExtendedContext(&node);
1038     DataType dtype = shape_context->input_type(idx);
1039     auto* context = shape_context->get_context();
1040     return ConvertDataTypeAndShape(dtype, context->input(idx),
1041                                    context->input_handle_shapes_and_types(idx),
1042                                    context, builder);
1043   }
1044   DataType dtype = node.properties()->input_types[idx];
1045   mlir::Type element_type;
1046   TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &element_type));
1047   return mlir::UnrankedTensorType::get(element_type);
1048 }
1049 
InferOutputType(const Node & node,int idx,mlir::Builder builder)1050 StatusOr<mlir::Type> ImporterBase::InferOutputType(const Node& node, int idx,
1051                                                    mlir::Builder builder) {
1052   DataType dtype = node.properties()->output_types[idx];
1053 
1054   // Returns output type given inference context.
1055   auto shape_ic = [&](shape_inference::InferenceContext* c) {
1056     return ConvertDataTypeAndShape(dtype, c->output(idx),
1057                                    c->output_handle_shapes_and_types(idx), c,
1058                                    builder);
1059   };
1060 
1061   if (specs_.enable_shape_inference) {
1062     // TODO(jpienaar): Remove this if shape inference on import flag is removed.
1063     ExtendedInferenceContext* shape_context =
1064         shape_refiner_->GetExtendedContext(&node);
1065     return shape_ic(shape_context->get_context());
1066   }
1067 
1068   // Treat TensorList init ops specially here as the op requires knowing its
1069   // element dtype.
1070   // TODO(jpienaar): Reconsider post refactoring shape functions.
1071   if (node.type_string() == "TensorListReserve" ||
1072       node.type_string() == "EmptyTensorList") {
1073     mlir::Type etype;
1074     if (auto element_dtype = node.attrs().Find("element_dtype")) {
1075       TF_RETURN_IF_ERROR(
1076           ConvertDataType(element_dtype->type(), builder, &etype));
1077     }
1078     return mlir::RankedTensorType::get(
1079         {}, mlir::TF::VariantType::get({mlir::UnrankedTensorType::get(etype)},
1080                                        etype.getContext()));
1081   }
1082 
1083   if (node.IsWhileNode()) {
1084     auto* output_shapes = node.attrs().Find("output_shapes");
1085     auto* element_types = node.attrs().Find("T");
1086     if (output_shapes && !output_shapes->list().shape().empty()) {
1087       const auto& output_shape = output_shapes->list().shape(idx);
1088       const auto& element_type = element_types->list().type(idx);
1089       return ConvertToMlirTensorType(output_shape, element_type, &builder);
1090     }
1091   }
1092 
1093   auto type_from_array_attr = [&node, &idx, &builder](
1094                                   absl::string_view output_shape_attr,
1095                                   absl::string_view element_type_attr) {
1096     auto* output_shapes = node.attrs().Find(output_shape_attr);
1097     auto* element_types = node.attrs().Find(element_type_attr);
1098     const auto& output_shape = output_shapes->list().shape(idx);
1099     const auto& element_type = element_types->list().type(idx);
1100     return ConvertToMlirTensorType(output_shape, element_type, &builder);
1101   };
1102 
1103   if (node.type_string() == "IteratorGetNext" ||
1104       node.type_string() == "IteratorGetNextSync" ||
1105       node.type_string() == "MultiDeviceIteratorGetNextFromShard")
1106     return type_from_array_attr("output_shapes", "output_types");
1107 
1108   if (node.type_string() == "InfeedDequeueTuple")
1109     return type_from_array_attr("shapes", "dtypes");
1110 
1111   if (node.type_string() == "InfeedDequeue") {
1112     assert(idx == 0);
1113     const auto& output_shape = node.attrs().Find("shape")->shape();
1114     const auto& element_type = node.attrs().Find("dtype")->type();
1115     return ConvertToMlirTensorType(output_shape, element_type, &builder);
1116   }
1117 
1118   // Returns a simple, more conservative unranked tensor type.
1119   auto default_type = [&]() -> StatusOr<mlir::Type> {
1120     mlir::Type element_type;
1121     TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &element_type));
1122     return mlir::UnrankedTensorType::get(element_type);
1123   };
1124 
1125   // Below we only try and do some shape inference for "source" ops which have
1126   // no inputs.
1127   if (node.num_inputs() > 0) return default_type();
1128 
1129   // Do some simply inference here to get the function arguments correct for
1130   // this common case.
1131   // TODO(jpienaar): Reconsider post refactoring shape functions.
1132   if (node.IsArg()) {
1133     if (dtype == DT_RESOURCE) {
1134       const AttrValue* dtype_attr = node.attrs().Find("_handle_dtypes");
1135       const AttrValue* shape_attr = node.attrs().Find("_handle_shapes");
1136       if (dtype_attr && shape_attr) {
1137         if (dtype_attr->list().type().empty()) {
1138           return errors::InvalidArgument(
1139               "Invalid \"_handle_dtypes\" attribute value for _Arg node: ",
1140               shape_attr->DebugString());
1141         }
1142         if (shape_attr->list().shape().empty()) {
1143           return errors::InvalidArgument(
1144               "Invalid \"_handle_shapes\" attribute value for _Arg node: ",
1145               shape_attr->DebugString());
1146         }
1147         DataType dtype = dtype_attr->list().type(0);
1148         const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
1149         TF_ASSIGN_OR_RETURN(
1150             auto etype, ConvertToMlirTensorType(shape_proto, dtype, &builder));
1151         return mlir::UnrankedTensorType::get(mlir::TF::ResourceType::get(
1152             {etype.cast<TensorType>()}, builder.getContext()));
1153       } else {
1154         return mlir::UnrankedTensorType::get(
1155             mlir::TF::ResourceType::get(builder.getContext()));
1156       }
1157     } else if (auto shape = node.attrs().Find("_output_shapes")) {
1158       if (shape->has_list() && shape->list().shape_size() == 1) {
1159         return ConvertToMlirTensorType(shape->list().shape().at(0), dtype,
1160                                        &builder);
1161       }
1162     }
1163   }
1164 
1165   const tensorflow::OpRegistrationData* op_reg_data;
1166   TF_RETURN_IF_ERROR(
1167       graph_->op_registry()->LookUp(node.type_string(), &op_reg_data));
1168   if (!op_reg_data) {
1169     DVLOG(1) << "Skipping inference for unregistered op " << node.type_string();
1170     return default_type();
1171   }
1172   if (op_reg_data->shape_inference_fn == nullptr) {
1173     DVLOG(1) << "Skipping inference for op without shape function "
1174              << node.type_string();
1175     return default_type();
1176   }
1177   shape_inference::InferenceContext c(graph_->versions().producer(),
1178                                       node.attrs(), op_reg_data->op_def,
1179                                       std::vector<PartialTensorShape>{}, {},
1180                                       /*input_tensors_as_shapes=*/{}, {});
1181   TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn));
1182   return shape_ic(&c);
1183 }
1184 
ConvertDataTypeAndShape(DataType dtype,const shape_inference::ShapeHandle & handle,const std::vector<shape_inference::ShapeAndType> * handle_subtypes,shape_inference::InferenceContext * context,mlir::Builder builder)1185 StatusOr<TensorType> ImporterBase::ConvertDataTypeAndShape(
1186     DataType dtype, const shape_inference::ShapeHandle& handle,
1187     const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
1188     shape_inference::InferenceContext* context, mlir::Builder builder) {
1189   TF_ASSIGN_OR_RETURN(auto subtypes,
1190                       ConvertSubtypes(handle_subtypes, context, builder));
1191 
1192   mlir::Type element_type;
1193   if (dtype == DT_VARIANT)
1194     element_type = mlir::TF::VariantType::get(subtypes, context_);
1195   else if (dtype == DT_RESOURCE)
1196     element_type = mlir::TF::ResourceType::get(subtypes, context_);
1197   else
1198     TF_RETURN_IF_ERROR(
1199         ::tensorflow::ConvertDataType(dtype, builder, &element_type));
1200 
1201   return ConvertElementTypeAndShape(element_type, handle, context, builder);
1202 }
1203 
ConvertElementTypeAndShape(mlir::Type element_type,const shape_inference::ShapeHandle & handle,shape_inference::InferenceContext * context,mlir::Builder builder)1204 StatusOr<TensorType> ImporterBase::ConvertElementTypeAndShape(
1205     mlir::Type element_type, const shape_inference::ShapeHandle& handle,
1206     shape_inference::InferenceContext* context, mlir::Builder builder) {
1207   if (!context->RankKnown(handle)) {
1208     return mlir::UnrankedTensorType::get(element_type);
1209   }
1210 
1211   // Sentinel for an unknown dimension size. getTensorType interprets any
1212   // negative value as an unknown dimension.
1213   // TODO(jmolloy): Ideally this shouldn't be a local sentinel.
1214   const int64_t kUnknownDim = -1;
1215 
1216   absl::InlinedVector<int64_t, 4> dimensions;
1217   int32_t rank = context->Rank(handle);
1218   dimensions.reserve(rank);
1219   for (int i = 0; i < rank; ++i) {
1220     auto dim_handle = context->Dim(handle, i);
1221     if (!context->ValueKnown(dim_handle))
1222       dimensions.push_back(kUnknownDim);
1223     else
1224       dimensions.push_back(context->Value(dim_handle));
1225   }
1226 
1227   return mlir::RankedTensorType::get(
1228       llvm::makeArrayRef(dimensions.begin(), dimensions.end()), element_type);
1229 }
1230 
ConvertSubtypes(const std::vector<shape_inference::ShapeAndType> * handle_subtypes,shape_inference::InferenceContext * context,mlir::Builder builder)1231 StatusOr<ImporterBase::ElementSubtypes> ImporterBase::ConvertSubtypes(
1232     const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
1233     shape_inference::InferenceContext* context, mlir::Builder builder) {
1234   ElementSubtypes subtypes;
1235   if (!handle_subtypes) return subtypes;
1236 
1237   subtypes.reserve(handle_subtypes->size());
1238   for (const auto& subtype : *handle_subtypes) {
1239     mlir::Type element_type;
1240     TF_RETURN_IF_ERROR(
1241         ::tensorflow::ConvertDataType(subtype.dtype, builder, &element_type));
1242     TF_ASSIGN_OR_RETURN(TensorType type,
1243                         ConvertElementTypeAndShape(element_type, subtype.shape,
1244                                                    context, builder));
1245     subtypes.push_back(type);
1246   }
1247   return subtypes;
1248 }
1249 
ConvertFunctionCallAttribute(const std::string & base_name,const AttrValue & value,NamedAttrList * attributes)1250 Status ImporterBase::ConvertFunctionCallAttribute(const std::string& base_name,
1251                                                   const AttrValue& value,
1252                                                   NamedAttrList* attributes) {
1253   TF_ASSIGN_OR_RETURN(auto func_attr,
1254                       ConvertFunctionCallName(value.func().name()));
1255   if (!func_attr) return Status::OK();
1256   attributes->push_back(builder_.getNamedAttr(base_name, func_attr));
1257 
1258   for (const auto& it : value.func().attr()) {
1259     auto name = absl::StrCat(base_name, ".", it.first);
1260     TF_ASSIGN_OR_RETURN(auto value, ConvertAttributeValue(it.second));
1261     attributes->push_back(builder_.getNamedAttr(name, value));
1262   }
1263   return Status::OK();
1264 }
1265 
ConvertFunctionCallName(const std::string & func_name)1266 StatusOr<mlir::FlatSymbolRefAttr> ImporterBase::ConvertFunctionCallName(
1267     const std::string& func_name) {
1268   // Some ops like XlaHostCompute op uses empty value to represent missing
1269   // functions. Such attribute values should be defined optional in MLIR
1270   // definition.
1271   if (func_name.empty()) return mlir::FlatSymbolRefAttr();
1272 
1273   TF_RETURN_IF_ERROR(ConvertLibFunction(func_name));
1274   auto mlir_func_name = (*tf_name_to_mlir_name_)[func_name];
1275   return builder_.getSymbolRefAttr(mlir_func_name);
1276 }
1277 
ConvertAttributeValue(const AttrValue & value)1278 StatusOr<mlir::Attribute> ImporterBase::ConvertAttributeValue(
1279     const AttrValue& value) {
1280   switch (value.value_case()) {
1281     case AttrValue::kFunc: {
1282       // TODO(b/156546237): Unify kFunc/NameAttrList attribute representation.
1283       // Currently kFunc/NameAttrList attributes in a kList/repeated AttrValue
1284       // will not use this representation. This also doesn't handle empty
1285       // function values like ConvertFunctionCallName method.
1286       NamedAttrList attrs;
1287       for (const auto& func_attr : value.func().attr()) {
1288         TF_ASSIGN_OR_RETURN(
1289             auto attr, ImporterBase::ConvertAttributeValue(func_attr.second));
1290         attrs.push_back(builder_.getNamedAttr(func_attr.first, attr));
1291       }
1292       auto func_attrs = builder_.getDictionaryAttr(attrs);
1293       return mlir::TF::FuncAttr::get(context_, value.func().name(), func_attrs);
1294     }
1295     case AttrValue::kList: {
1296       if (!value.list().func().empty()) {
1297         absl::InlinedVector<mlir::Attribute, 8> attrs;
1298         for (const auto& item : value.list().func()) {
1299           TF_ASSIGN_OR_RETURN(auto attr, ConvertFunctionCallName(item.name()));
1300           if (item.attr_size() != 0)
1301             return errors::Unimplemented(
1302                 "func attributes with non-zero attr.size()");
1303           if (attr) attrs.push_back(attr);
1304         }
1305         return builder_.getArrayAttr(
1306             llvm::makeArrayRef(attrs.begin(), attrs.end()));
1307       }
1308       return ConvertNonFuncAttributeValue(value, &builder_);
1309     }
1310     default:
1311       return ConvertNonFuncAttributeValue(value, &builder_);
1312   }
1313 }
1314 
GetArgsAndRetsFromFunctionBody(const FunctionBody & fbody,absl::InlinedVector<OutputTensor,4> * arg_nodes,absl::InlinedVector<OutputTensor,4> * ret_nodes,absl::InlinedVector<Node *,4> * control_ret_nodes)1315 void ImporterBase::GetArgsAndRetsFromFunctionBody(
1316     const FunctionBody& fbody, absl::InlinedVector<OutputTensor, 4>* arg_nodes,
1317     absl::InlinedVector<OutputTensor, 4>* ret_nodes,
1318     absl::InlinedVector<Node*, 4>* control_ret_nodes) {
1319   arg_nodes->reserve(fbody.arg_nodes.size());
1320   ret_nodes->reserve(fbody.ret_nodes.size());
1321   for (auto arg : fbody.arg_nodes) {
1322     arg_nodes->emplace_back(arg, 0);
1323   }
1324   for (auto ret : fbody.ret_nodes) {
1325     ret_nodes->emplace_back(ret, 0);
1326   }
1327   *control_ret_nodes = fbody.control_ret_nodes;
1328 }
1329 
// Imports the library function `func_name` (and, transitively, its registered
// gradient function if any). The TF-to-MLIR name mapping is recorded up front
// and the body conversion itself is queued on `deferred_functions_`.
Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
  // If the library function has been converted already, nothing needs to be
  // done.
  if (tf_name_to_mlir_name_->find(std::string(func_name)) !=
      tf_name_to_mlir_name_->end())
    return Status::OK();

  // Register the uniquified MLIR name before any further conversion so that
  // (mutually) recursive references — e.g. via the gradient lookup below —
  // hit the early-return above instead of recursing forever.
  std::string mlir_func_name(
      function_name_uniquifier_->GetUniqueName(func_name));
  (*tf_name_to_mlir_name_)[std::string(func_name)] = mlir_func_name;

  const auto& func_lib = graph_flib_;
  const auto* func_def = func_lib.Find(std::string(func_name));
  if (func_def == nullptr) {
    return errors::FailedPrecondition(
        absl::StrCat("Failed to find function '", StringRefToView(func_name),
                     "'. The imported TensorFlow GraphDef is ill-formed."));
  }

  // Converts the FunctionDef's attributes to MLIR attributes; names are
  // mangled so they can round-trip back to their TF spelling.
  std::vector<mlir::NamedAttribute> attributes;
  attributes.reserve(func_def->attr_size());
  for (const auto& name_and_value : func_def->attr()) {
    // This is a function definition attribute, so it shouldn't contain
    // kFunc attribute and it is treated as normal one.
    TF_ASSIGN_OR_RETURN(auto attr,
                        ConvertAttributeValue(name_and_value.second));
    std::string attr_name =
        mangling_util::MangleAttributeName(name_and_value.first);
    attributes.push_back(builder_.getNamedAttr(attr_name, attr));
  }

  // Checks opdef stateful attribute and import that as Function Attribute
  if (func_def->signature().is_stateful()) {
    auto stateful_str = mlir::TF::TensorFlowDialect::GetStatefulAttrName();
    attributes.push_back(
        builder_.getNamedAttr(stateful_str, builder_.getUnitAttr()));
  }

  // Checks for an associated custom gradient function. Adds it to the attribute
  // list of this function.
  auto grad_func_name = func_lib.FindGradient(std::string(func_name));
  if (!grad_func_name.empty()) {
    // Convert the gradient first so its MLIR name exists for the symbol
    // reference created below.
    TF_RETURN_IF_ERROR(ConvertLibFunction(grad_func_name));
    auto mlir_grad_func_name = (*tf_name_to_mlir_name_)[grad_func_name];
    auto gradient_attr = builder_.getSymbolRefAttr(mlir_grad_func_name);
    auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
    attributes.push_back(builder_.getNamedAttr(grad_string, gradient_attr));
  }

  // Only the name and attribute list are recorded now; the function body is
  // converted later when the deferred queue is drained.
  deferred_functions_.emplace(func_name.str(), attributes);
  return Status::OK();
}
1383 
PruneUnreachableNodes(std::unordered_map<string,Node * > * node_name_map)1384 Status ImporterBase::PruneUnreachableNodes(
1385     std::unordered_map<string, Node*>* node_name_map) {
1386   std::unordered_set<const Node*> prune_start;
1387   TF_RETURN_IF_ERROR(GetInputOutputNodes(*node_name_map, &prune_start));
1388 
1389   if (!prune_start.empty()) {
1390     if (PruneForReverseReachability(graph_.get(), prune_start)) {
1391       VLOG(1) << "Pruned unused nodes in graphdef";
1392     } else {
1393       VLOG(1) << "No unused nodes in graphdef to prune";
1394     }
1395   } else {
1396     VLOG(1) << "No output nodes specified, skipping pruning";
1397   }
1398   return Status::OK();
1399 }
1400 
// Rewrites fed tensors into dedicated single-output Placeholder nodes so later
// stages can treat each feed uniformly. Newly created nodes are recorded in
// both `node_name_map` and `remapped_feeds_`.
Status ImporterBase::ConvertFeedsToPlaceholders(
    std::unordered_map<string, Node*>* node_name_map) {
  // Feeds (edges) are converted into single-output placeholder nodes to
  // simplify the conversion process.
  TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs));
  for (const auto& it : feeds_by_node) {
    TensorId tensor = ParseTensorName(it.first);
    auto jt = node_name_map->find(std::string(tensor.node()));
    if (jt == node_name_map->end()) {
      return errors::FailedPrecondition(
          absl::StrCat("Graph does not contain node: ", tensor.node()));
    }

    Node* node = jt->second;
    auto op_name = node->op_def().name();
    // Nodes that already act as inputs (placeholders, legacy feeds, function
    // args) are fed directly and need no rewrite.
    if (op_name != "Placeholder" && op_name != "LegacyFedInput" &&
        op_name != FunctionLibraryDefinition::kArgOp) {
      for (const auto& output_tensor : it.second) {
        const int index = output_tensor.first;
        const ArrayInfo& array_info = output_tensor.second->second;

        DataType dtype = array_info.imported_dtype;
        // Uses the existing output type if it isn't specified by the user.
        if (dtype == DT_INVALID) {
          dtype = node->output_type(index);
        }

        TF_ASSIGN_OR_RETURN(
            auto placeholder_node_and_removed,
            CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index,
                                         *node_name_map));

        Node* placeholder_node = placeholder_node_and_removed.first;
        // Keep the new node wired between _SOURCE and _SINK so the graph's
        // connectivity invariants still hold.
        if (placeholder_node->in_edges().empty()) {
          graph_->AddControlEdge(graph_->source_node(), placeholder_node,
                                 true /* skip test for duplicates */);
        }
        if (placeholder_node->out_edges().empty()) {
          graph_->AddControlEdge(placeholder_node, graph_->sink_node(),
                                 true /* skip test for duplicates */);
        }
        // Record the rewrite so later feed lookups resolve to the new node.
        remapped_feeds_[{it.first, index}] = placeholder_node->name();
        (*node_name_map)[placeholder_node->name()] = placeholder_node;
      }
    }
  }
  return Status::OK();
}
1449 
// Clones `graph` into `graph_` (via a GraphDef round trip) and normalizes it
// for import: removes backedges, copies stack traces, converts feeds to
// placeholders (or runs the legacy shape refiner), optionally prunes
// unreachable nodes, and computes the node ordering used by Convert().
Status ImporterBase::PrepareConvert(const Graph& graph,
                                    std::unique_ptr<GraphDef> graph_def) {
  // TODO(fengliuai): Converting to GraphDef and back is the easiest way to
  // clone a graph.
  // TODO(fengliuai): clone the graph without going to graph_def first.
  if (graph_def == nullptr) {
    graph_def = std::make_unique<GraphDef>();
    graph.ToGraphDef(graph_def.get());
  }
  graph_ = absl::make_unique<Graph>(graph.flib_def());
  GraphConstructorOptions opts;
  opts.allow_internal_ops = true;
  opts.add_default_attributes = true;
  TF_RETURN_IF_ERROR(::tensorflow::ConvertGraphDefToGraph(
      opts, std::move(*graph_def), graph_.get()));

  // Backedges are removed here and restored later by AddBackedges() during
  // Convert().
  TF_RETURN_IF_ERROR(RemoveBackedges());

  TF_RETURN_IF_ERROR(CopyStackTraces(graph, graph_.get()));

  auto node_name_map = graph_->BuildNodeNameIndex();

  if (specs_.enable_shape_inference) {
    // TODO(jpienaar): Remove once infer shapes on import flag is removed.
    TF_RETURN_IF_ERROR(AddNodesToShapeRefiner(&node_name_map));
  } else {
    TF_RETURN_IF_ERROR(ConvertFeedsToPlaceholders(&node_name_map));
  }

  // Prune nodes in the graph that are not reachable from the output.
  if (specs_.prune_unused_nodes) {
    TF_RETURN_IF_ERROR(PruneUnreachableNodes(&node_name_map));
  }

  if (!specs_.enable_shape_inference) {
    // Re-initialize ordered_nodes_ since we might have modified the graph.
    // Ties in the reverse post-order are broken by node name so the ordering
    // is deterministic.
    GetReversePostOrder(
        *graph_, &ordered_nodes_,
        [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
  }

  return Status::OK();
}
1493 
// Creates an MLIR FuncOp named `func_name` with signature `func_type`, fills
// it with a single tf_executor.graph region, converts every node in
// `ordered_nodes_` into it, and hooks up arguments/results.
Status ImporterBase::Convert(
    llvm::StringRef func_name, mlir::FunctionType func_type,
    const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
    const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
    const absl::InlinedVector<Node*, 4>& control_ret_nodes,
    llvm::ArrayRef<mlir::NamedAttribute> attrs) {
  // TODO(b/122040776): Uses debug info for FunctionDef.
  auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
                                       func_name, func_type, attrs);

  module_.push_back(function);
  // Seeds the builder with an initial block.
  function.addEntryBlock();
  builder_ = mlir::OpBuilder(function.getBody());

  // Create the graph operation in which we will convert the individual nodes.
  auto graph = builder_.create<mlir::tf_executor::GraphOp>(
      function.getLoc(), func_type.getResults());
  builder_.createBlock(&graph.body());

  // Emit nodes in the order computed by PrepareConvert.
  for (const Node* node : ordered_nodes_) {
    TF_RETURN_IF_ERROR(ConvertNode(*node));
  }

  // Adds the backedges back to the function by creating the source and sink
  // pairs.
  TF_RETURN_IF_ERROR(AddBackedges());

  TF_RETURN_IF_ERROR(ConvertFunctionArgAndRets(function, graph,
                                               func_type.getInputs(), arg_nodes,
                                               ret_nodes, control_ret_nodes));

  // TODO(jpienaar): Update post removing shape_refinier_.
  if (!specs_.enable_shape_inference) {
    // Refine graph's type given more precise fetch.
    auto fetch = graph.GetFetch();
    bool all_equal = true;
    for (auto it :
         llvm::zip_first(graph.getResults(), fetch.getOperandTypes())) {
      auto rt = std::get<1>(it);
      if (rt == std::get<0>(it).getType()) continue;
      // Propagate the fetch operand's type onto the graph result.
      std::get<0>(it).setType(rt);
      all_equal = false;
    }
    if (!all_equal) {
      // At least one result type changed, so rebuild the function signature
      // to match the refined graph result types.
      function.setType(mlir::FunctionType::get(function.getContext(),
                                               func_type.getInputs(),
                                               graph.getResultTypes()));
    }
  }

  return Status::OK();
}
1547 
// Wires function arguments and results into the converted graph: replaces the
// islands created for arg nodes with block arguments, collects the values to
// fetch for ret/control-ret nodes, emits the tf_executor.fetch and return
// terminators, and attaches per-arg/per-result attributes onto `func`.
Status ImporterBase::ConvertFunctionArgAndRets(
    mlir::FuncOp func, mlir::tf_executor::GraphOp graph_op,
    llvm::ArrayRef<mlir::Type> arg_types,
    const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
    const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
    const absl::InlinedVector<Node*, 4>& control_ret_nodes) {
  // Store the arg/return attributes as a list rather than uniqueuing during
  // construction.
  llvm::SmallVector<mlir::NamedAttrList, 4> arg_attrs;
  arg_attrs.resize(func.getNumArguments());
  llvm::SmallVector<mlir::NamedAttrList, 4> ret_attrs;
  ret_attrs.resize(func.getNumResults());

  // Copies a node's "_"-prefixed (optional) attributes onto the function's
  // argument (is_arg=true) or result (is_arg=false) at `index`.
  auto set_attributes_on_func = [&](Node* node, int64_t index, bool is_arg) {
    for (const auto& node_attr : node->attrs()) {
      const auto& key = node_attr.first;
      // Only import optional attributes (e.g., those starting with an
      // underscore).
      if (key.empty() || key[0] != '_') continue;
      // Ignore shape inference attributes as shape information is already
      // populated in the result type.
      if (IsOutputShapesAttribute(node_attr.second, key) ||
          IsResourceOutputShapesAttribute(node_attr.second, key))
        continue;
      TF_ASSIGN_OR_RETURN(auto converted_attr,
                          ConvertAttributeValue(node_attr.second));
      std::string dialect_attribute = "tf." + key;
      if (is_arg) {
        arg_attrs[index].set(dialect_attribute, converted_attr);
      } else {
        func.setResultAttr(index, dialect_attribute, converted_attr);
        ret_attrs[index].set(dialect_attribute, converted_attr);
      }
    }
    return Status::OK();
  };

  auto* bb = &func.front();
  llvm::SmallDenseMap<std::pair<Node*, int>, mlir::Value, 4>
      arg_nodes_to_values;
  for (int i = 0, e = arg_types.size(); i < e; ++i) {
    auto& arg_node = arg_nodes[i];
    // The lookup can't fail here: otherwise some nodes in the function haven't
    // be converted to mlir operations and don't have a mapping.
    mlir::Operation* island = node_values_.find(arg_node.node->id())->second;

    auto bb_arg = bb->getArgument(i);
    mlir::Value arg_def = bb_arg;

    // Feed islands must carry exactly one data result plus the control result.
    if (island->getNumResults() != 2)
      return errors::InvalidArgument(
          "Only feed output tensors of single output nodes are supported");

    // Collect mapping of OutputTensor to associated block arg.
    arg_nodes_to_values.try_emplace({arg_node.node, arg_node.index}, arg_def);
    // Redirect all users of the island's data result to the block argument.
    island->getResult(0).replaceAllUsesWith(arg_def);
    // Erase control outputs from feed.
    auto control_uses = island->getResult(1).getUses();
    for (auto& control_use : llvm::make_early_inc_range(control_uses))
      control_use.getOwner()->eraseOperand(control_use.getOperandNumber());

    if (!arg_node.node->requested_device().empty())
      arg_attrs[i].set("tf.device", builder_.getStringAttr(
                                        arg_node.node->requested_device()));

    if (arg_node.node->IsArg()) {
      TF_RETURN_IF_ERROR(
          set_attributes_on_func(arg_node.node, i, /*is_arg=*/true));
    }

    // The island is now fully replaced by the block argument; remove it.
    island->dropAllReferences();
    island->erase();
  }

  llvm::SmallVector<mlir::Value, 8> inst_to_return;
  for (auto ret_and_idx : llvm::enumerate(ret_nodes)) {
    const auto& ret = ret_and_idx.value();
    auto* inst = node_values_[ret.node->id()];
    if (ret.node->IsRetval()) {
      if (!ret.node->requested_device().empty())
        ret_attrs[ret_and_idx.index()].set(
            "tf.device", builder_.getStringAttr(ret.node->requested_device()));
      TF_RETURN_IF_ERROR(set_attributes_on_func(ret.node, ret_and_idx.index(),
                                                /*is_arg=*/false));
      // Lookup the instruction inside the island
      auto island_op = llvm::cast<mlir::tf_executor::IslandOp>(inst);
      mlir::Operation* inner_op = &island_op.GetBody().front();
      // Remove kRetOp or kDeviceRetOp operation and return its operand.
      // kRetOp and kDeviceRetOp should have just one operand unless they have
      // control dependencies.
      if (inner_op->getNumOperands() != 1)
        return errors::Unimplemented("Return node with multiple inputs.");
      inst_to_return.push_back(inner_op->getOperand(0));
      inst->dropAllReferences();
      inst->erase();
    } else {
      // Lookup and use block arg if fetch is a feed.
      auto it = arg_nodes_to_values.find({ret.node, ret.index});
      if (it != arg_nodes_to_values.end())
        inst_to_return.push_back(it->second);
      else
        inst_to_return.push_back(inst->getResult(ret.index));
    }
  }

  // Control rets contribute their last result (the control token) to the
  // fetch operands.
  for (Node* control_ret : control_ret_nodes) {
    auto* inst = node_values_[control_ret->id()];
    inst_to_return.push_back(*std::prev(inst->result_end()));
  }

  // Terminate the function by adding a Fetch operation to terminate the graph
  // and a return operation to return the Graph results.
  builder_.setInsertionPointToEnd(&graph_op.body().front());
  builder_.create<mlir::tf_executor::FetchOp>(graph_op.getLoc(),
                                              inst_to_return);
  builder_.setInsertionPointToEnd(bb);
  builder_.create<mlir::ReturnOp>(mlir::UnknownLoc::get(context_),
                                  graph_op.getResults());

  // Materialize the accumulated attribute lists as dictionaries on the func.
  func.setAllArgAttrs(
      llvm::to_vector<4>(llvm::map_range(arg_attrs, [&](NamedAttrList& list) {
        return list.getDictionary(context_);
      })));
  func.setAllResultAttrs(
      llvm::to_vector<4>(llvm::map_range(ret_attrs, [&](NamedAttrList& list) {
        return list.getDictionary(context_);
      })));

  return Status::OK();
}
1678 
// Builds an mlir::Location for `node`, preferring (in order) the node's stack
// trace, the serialized GraphDebugInfo, and finally just the node name. Nodes
// carrying experimental_debug_info original-node names produce a FusedLoc of
// all the original call sites plus the node's own location.
mlir::Location ImporterBase::GetLocation(const Node& node) {
  DVLOG(1) << "Getting location for " << node.name() << " " << &node;
  // TODO(b/142400497): What is the semantic contract for locations?
  const auto& debug_info = debug_info_.traces();

  // Create a location for node `name` in function `function_name`.
  auto create_location = [&](llvm::StringRef name,
                             llvm::StringRef function_name) -> mlir::Location {
    // Use the catenation of function and node names as the lookup key into the
    // debug info. This matches the way that the key is formed on the python
    // side.
    //
    // We also use this as the name for the NameLoc for ops in function, since
    // otherwise our names could collide across functions.
    // For ops in the main graph, we omit the "@function_name" (which, would be
    // just "@" since function_name would be empty) because some code seems to
    // depend on the name being this way for correctness.
    std::string debug_info_key = (name + "@" + function_name).str();
    std::string name_for_name_loc =
        function_name.empty() ? name.str() : (name + "@" + function_name).str();
    auto name_loc_id = mlir::Identifier::get(name_for_name_loc, context_);

    llvm::SmallVector<mlir::Location, 4> locations;
    // Prefer stack traces if available, fallback to debug info if not, and then
    // finally to just name.
    if (auto stack_trace = node.GetStackTrace()) {
      DVLOG(1) << "Stack available for " << node.name();
      absl::Span<const StackFrame> frames = stack_trace->ToFrames();
      locations.reserve(frames.size());
      // Reversed so the innermost frame ends up first in `locations`.
      for (const StackFrame& frame : llvm::reverse(frames)) {
        auto file_name = mlir::Identifier::get(frame.file_name, context_);
        // Use col 1 as there is no column info in StackTrace.
        auto file_line_loc =
            mlir::FileLineColLoc::get(file_name, frame.line_number, 1);
        locations.push_back(file_line_loc);
      }
    } else {
      DVLOG(1) << "No stack trace for " << node.name();
      const auto location_it = debug_info.find(debug_info_key);
      if (location_it != debug_info.end()) {
        DVLOG(1) << "Available serialized debug info for " << node.name();
        // Convert the stack trace to a chain of mlir::CallSiteLocs.
        const auto& trace = location_it->second;
        locations.reserve(trace.file_line_cols_size());
        for (const auto& location : trace.file_line_cols()) {
          const auto& file = debug_info_.files(location.file_index());
          auto file_name = mlir::Identifier::get(file, context_);
          auto file_line_loc = mlir::FileLineColLoc::get(
              file_name, location.line(), location.col());
          locations.push_back(file_line_loc);
        }
      }
    }

    // If there are no locations in the stack trace, fall back to just a
    // NameLoc with no child.
    if (locations.empty()) return mlir::NameLoc::get(name_loc_id);

    // Use the front FileLineColLoc to generate a NameLoc.
    mlir::Location node_name_loc =
        mlir::NameLoc::get(name_loc_id, locations.front());

    // If there are more locations then generate a stack trace, otherwise just
    // return the name loc.
    auto callsite_locs = llvm::makeArrayRef(locations).drop_front();
    return callsite_locs.empty()
               ? node_name_loc
               : mlir::CallSiteLoc::get(node_name_loc, callsite_locs);
  };

  // For NextIteration nodes, location is used to pair source and sink nodes.
  // Hence, we use node name as location to keep it unique.
  // TODO(prakalps): In future the plan is to use tokens to pair source/sink
  // nodes. Then NextIteration nodes would not need to be handled separately.
  if (node.type_string() == "NextIteration")
    return create_location(node.name(), function_name_for_debug_info_);

  // A live stack trace takes precedence over experimental_debug_info.
  if (node.GetStackTrace())
    return create_location(node.name(), function_name_for_debug_info_);

  const auto& node_def = node.def();
  auto original_nodes =
      node_def.experimental_debug_info().original_node_names();
  auto original_funcs =
      node_def.experimental_debug_info().original_func_names();

  if (original_nodes.empty()) {
    return create_location(node.name(), function_name_for_debug_info_);
  } else {
    // If the original nodes are defined, then we use them to get a list of
    // call sites, and then fuse them to a single fused location, with the name
    // of the node_def.
    llvm::SmallVector<mlir::Location, 4> node_locations;
    node_locations.reserve(original_nodes.size() + 1);

    // store the names in the experimental_debug_info
    for (int i = 0, e = original_nodes.size(); i != e; ++i) {
      auto node_name = original_nodes[i];
      // original_funcs may be shorter than original_nodes; missing entries
      // default to the empty (main-graph) function name.
      auto func_name = (i < original_funcs.size()) ? original_funcs[i] : "";
      node_locations.push_back(create_location(node_name, func_name));
    }
    // store the name of the node_def
    node_locations.push_back(
        create_location(node.name(), function_name_for_debug_info_));
    return mlir::FusedLoc::get(context_, node_locations);
  }
}
1786 
Status ImporterBase::EmitErrorWithLocationStr(const Node& node,
                                              const Status& error_status) {
  // Emit a diagnostic at the node's source location, then fold the original
  // status into whatever the error handler collected.
  mlir::emitError(GetLocation(node));
  return error_handler_.Combine(error_status);
}
1793 
// Creates the tf_executor-dialect operation for `node`. Control-flow nodes
// (Switch/_SwitchN, Merge, NextIteration, LoopCond, Enter, Exit,
// ControlTrigger) map to dedicated tf_executor ops; every other node is
// wrapped in a tf_executor.island containing the TF op and a yield terminator.
// Returns the outer operation (the island for regular TF ops).
mlir::Operation* ImporterBase::CreateOperation(
    const Node& node, llvm::StringRef node_type_name,
    const mlir::OperationState& result,
    const llvm::SmallVectorImpl<mlir::Value>& control_operands) {
  // For the tf.executor specific operations (not wrapped in an island), we
  // have an extra returned value for the control result, and we concatenate
  // control and non-control operands.
  mlir::SmallVector<mlir::Type, 4> types(result.types);
  types.push_back(mlir::tf_executor::ControlType::get(builder_.getContext()));
  mlir::SmallVector<mlir::Value, 4> operands(result.operands);
  operands.append(control_operands.begin(), control_operands.end());

  auto loc = result.location;
  // Dispatch based on the name and create the appropriate operation.
  if (node.IsSwitch()) {
    // Switch and _SwitchN both are in switch class, differentiate based on
    // op name.
    if (node.op_def().name() == "_SwitchN") {
      return builder_.create<mlir::tf_executor::SwitchNOp>(loc, types, operands,
                                                           result.attributes);
    }
    return builder_.create<mlir::tf_executor::SwitchOp>(loc, types, operands,
                                                        result.attributes);
  }
  if (node.IsMerge()) {
    return builder_.create<mlir::tf_executor::MergeOp>(loc, types, operands,
                                                       result.attributes);
  }
  if (node.IsNextIteration()) {
    // NextIteration is a bit special, we create a pair of operations that are
    // linked together through a token returned by the source.
    // We make use of a separate builder to insert the source at the top of
    // the block.
    mlir::OpBuilder builder_at_begin(builder_.getBlock(),
                                     builder_.getBlock()->begin());
    auto source_op =
        builder_at_begin.create<mlir::tf_executor::NextIterationSourceOp>(
            loc, operands[0].getType(), result.attributes);
    return builder_.create<mlir::tf_executor::NextIterationSinkOp>(
        loc, source_op.token(), operands, result.attributes);
  }
  if (node.IsLoopCond()) {
    return builder_.create<mlir::tf_executor::LoopCondOp>(loc, types, operands,
                                                          result.attributes);
  }
  if (node.IsEnter()) {
    return builder_.create<mlir::tf_executor::EnterOp>(loc, types, operands,
                                                       result.attributes);
  }
  if (node.IsExit()) {
    return builder_.create<mlir::tf_executor::ExitOp>(loc, types, operands,
                                                      result.attributes);
  }
  if (node.IsControlTrigger()) {
    return builder_.create<mlir::tf_executor::ControlTriggerOp>(
        loc, operands, result.attributes);
  }
  // Regular TensorFlow operation are wrapped in a tf_executor.island.
  auto island = builder_.create<mlir::tf_executor::IslandOp>(
      result.location, types, control_operands,
      mlir::ArrayRef<mlir::NamedAttribute>{});
  island.body().push_back(new mlir::Block);
  mlir::OpBuilder island_builder =
      mlir::OpBuilder::atBlockEnd(&island.GetBody());

  // Create the operation inside the island now.
  mlir::Operation* inner_op = island_builder.createOperation(result);

  // Sets operand_segment_sizes or result_segment_sizes attribute to the op.
  // Each element of the vector attribute holds the length of the named
  // arg/result range computed by NameRangesForNode.
  const auto set_segment_sizes_attr =
      [&](const NameRangeMap& arg_ranges,
          const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
          llvm::StringRef attr_name) {
        std::vector<mlir::Attribute> values;
        values.reserve(args.size());
        for (const auto& arg : args) {
          auto range = arg_ranges.at(arg.name());
          values.push_back(
              island_builder.getI32IntegerAttr(range.second - range.first));
        }
        auto attr_type =
            mlir::VectorType::get(args.size(), builder_.getIntegerType(32));
        auto attr_value = mlir::DenseElementsAttr::get(attr_type, values);
        inner_op->setAttr(attr_name, attr_value);
      };

  if (inner_op->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>() ||
      inner_op->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
    // The op has multiple variadic operands or results.
    // Calculate operand and result segment sizes using the OpDef.
    NameRangeMap input_ranges, output_ranges;
    // This will fail only if the OpDef is syntactically invalid.
    // TODO(jpienaar): Convert this CHECK into a properly propagated error.
    TF_CHECK_OK(
        NameRangesForNode(node, node.op_def(), &input_ranges, &output_ranges));
    if (inner_op->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>()) {
      // Add derived "operand_segment_sizes" attr to the created operation.
      // TODO(b/146937733): Don't use <void> here.
      set_segment_sizes_attr(input_ranges, node.op_def().input_arg(),
                             mlir::OpTrait::AttrSizedOperandSegments<
                                 void>::getOperandSegmentSizeAttr());
    }

    if (inner_op->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
      // Add derived "result_segment_sizes" attr to the created operation.
      // TODO(b/146937733): Don't use <void> here.
      set_segment_sizes_attr(output_ranges, node.op_def().output_arg(),
                             mlir::OpTrait::AttrSizedResultSegments<
                                 void>::getResultSegmentSizeAttr());
    }
  }

  // Warn (once per op type) about ops with no registered MLIR definition,
  // since their side effects cannot be modelled precisely.
  mlir::OperationName name = inner_op->getName();
  if (!name.getAbstractOperation() &&
      // Skip unmodelled ops that are handled differently.
      (node_type_name != "_Arg" && node_type_name != "_Retval") &&
      // Skip if warning already reported.
      (GetUnmodelledOpTypes().insert(name.getStringRef()).second)) {
    if (node.op_def().is_stateful()) {
      LOG(INFO) << "[potentially conservative] Op type `" << node.type_string()
                << "` is stateful but effects not modelled";
    } else {
      // See if any resource type is used.
      bool resource = false;
      std::function<bool(mlir::Type)> record_resource;
      record_resource = [&](mlir::Type type) {
        if (resource) return true;
        if (type.isa<mlir::TF::ResourceType>()) {
          resource = true;
          return true;
        }
        // Recurse into composite types (e.g. tensors of resources).
        if (auto with_subtype =
                type.dyn_cast<mlir::SubElementTypeInterface>()) {
          with_subtype.walkSubTypes([&](mlir::Type t) { record_resource(t); });
        }
        return resource;
      };

      for (mlir::Type t : inner_op->getResultTypes())
        if (record_resource(t)) break;
      for (mlir::Type t : inner_op->getOperandTypes())
        if (record_resource(t)) break;
      if (resource)
        LOG(INFO) << "[potentially conservative] Op type `"
                  << node.type_string()
                  << "` has resource operands/results but effects not modelled";
    }
  }

  // Add the terminator for the island
  island_builder.create<mlir::tf_executor::YieldOp>(result.location,
                                                    inner_op->getResults());
  return island.getOperation();
}
1948 
// Converts a single TF graph `node` into an MLIR operation (wrapped in a
// tf_executor island by CreateOperation) and records the result in
// `node_values_` keyed by node id. Requires that the graph is traversed in
// reverse post order so every data/control input was converted already.
Status ImporterBase::ConvertNode(const Node& node) {
  if (!node.IsOp()) {
    // Don't import the pseudo-nodes _SOURCE or _SINK. These are added by
    // Graph and don't exist in GraphDef.
    return Status::OK();
  }

  // If it is a custom OP, its definition should be found in the library. We
  // create the MLIR function and insert it to the module if it doesn't exist.
  std::string node_type_name = node.type_string();
  const auto* func_def = graph_flib_.Find(node_type_name);
  bool convert_to_legacy_call = false;
  if (func_def) {
    TF_RETURN_IF_ERROR(ConvertLibFunction(node_type_name));
    node_type_name = (*tf_name_to_mlir_name_)[node_type_name];
    convert_to_legacy_call = true;
  }

  // All imported ops are named within the "tf." dialect namespace.
  auto get_full_op_name = [&](const std::string& op_name) {
    const char* kTfPrefix = "tf.";
    return kTfPrefix + op_name;
  };

  std::string op_name = get_full_op_name(node_type_name);
  if (back_edge_node_output_.contains(&node)) {
    // This node is the source of a removed backedge; import it as the ".sink"
    // half of the NextIteration source/sink pair.
    op_name = op_name + ".sink";
  }

  mlir::OperationState result(GetLocation(node), op_name);
  for (int i = 0; i < node.num_outputs(); ++i) {
    // The backedge has been removed, so we shouldn't count the corresponding
    // output from the src node when converting to an operation.
    if (back_edge_node_output_.contains(&node) &&
        back_edge_node_output_[&node] == i) {
      continue;
    }
    TF_ASSIGN_OR_RETURN(auto type, InferOutputType(node, i, builder_));
    result.types.push_back(type);
  }

  // Surprisingly input edges can be nondeterministically ordered. This
  // particularly seems to be the case for the control edges between _SOURCE
  // and _SINK that the Graph constructor inserts. Copy the input edges and
  // sort the edges, but only the control edges, not data edges!
  // TODO(jmolloy): We should probably just ignore _SOURCE and _SINK nodes.
  // They'll break roundtripping anyway unless we strip them when converting
  // back to graphdef.
  absl::InlinedVector<const Edge*, 8> in_edges(node.in_edges().size());
  absl::c_copy(node.in_edges(), in_edges.begin());
  // Order: data edges (by destination input index) before control edges (by
  // source node id). Stable sort preserves relative order of equal elements.
  absl::c_stable_sort(in_edges, [](const Edge* e1, const Edge* e2) {
    if (e1->IsControlEdge() && !e2->IsControlEdge()) return false;
    if (!e1->IsControlEdge() && e2->IsControlEdge()) return true;
    if (e1->IsControlEdge() && e2->IsControlEdge())
      return e1->src()->id() < e2->src()->id();
    return e1->dst_input() < e2->dst_input();
  });

  result.operands.reserve(in_edges.size());

  // Collect the control operands separately, they will be held by the island.
  mlir::SmallVector<mlir::Value, 8> control_operands;

  for (const auto* input_edge : in_edges) {
    const Node& input_node = *input_edge->src();
    if (input_node.IsSource()) {
      if (in_edges.size() != 1) {
        return errors::FailedPrecondition(
            "The node has other inputs besides the _Source node");
      }
      // We don't import the _SOURCE node.
      continue;
    }
    if (input_node.IsArg() && input_edge->IsControlEdge()) {
      // Currently we have not reached consensus as to what TF function
      // semantics are (b/133509504). Here we assume that all arguments to a
      // function should be available before we start execution of any internal
      // node. This makes the control dependencies between function arguments
      // and internal nodes redundant, and so we do not import them. The TF
      // inliner however assumes no such dependency between function args and
      // internal nodes exists, unless explicitly stated. Since we drop control
      // dependencies here, it leads to loss of information. If the function is
      // inlined later, the inliner would not know of these explicit control
      // dependencies present in the original graph.
      continue;
    }
    if (node_values_.find(input_node.id()) == node_values_.end())
      return errors::FailedPrecondition(
          "Graph not traversed in reverse post order; use seen before def!");
    mlir::Operation* inst = node_values_[input_node.id()];
    if (input_edge->IsControlEdge())
      // The island's control token is its last result.
      control_operands.push_back(inst->getResult(inst->getNumResults() - 1));
    else
      result.operands.push_back(inst->getResult(input_edge->src_output()));
  }

  using FuncPairType = std::pair<const std::string*, const AttrValue*>;
  std::vector<FuncPairType> funcs;
  // +2 leaves headroom for the "device" attribute added below and for the
  // LegacyCall attributes.
  result.attributes.reserve(node.attrs().size() + 2);
  auto abstract_op = result.name.getAbstractOperation();
  auto derived_op =
      abstract_op
          ? abstract_op->getInterface<mlir::DerivedAttributeOpInterface>()
          : nullptr;
  for (const auto& name_and_value : node.attrs()) {
    const auto& attr_name = name_and_value.first;
    // Skip adding derived attributes to the generated op.
    if (derived_op && derived_op->isDerivedAttribute(attr_name)) continue;
    const AttrValue& attr_value = name_and_value.second;

    // Remove _output_shapes attribute that will be added by the exporter.
    if (IsOutputShapesAttribute(attr_value, attr_name)) continue;

    if (attr_value.value_case() == AttrValue::kFunc) {
      // Attribute iteration order is not defined for protocol buffer Map.
      // Process function attributes separately in the lexicographical order to
      // have deterministic order of functions in the constructed IR.
      funcs.emplace_back(&attr_name, &attr_value);
    } else {
      TF_ASSIGN_OR_RETURN(auto attr, ConvertAttributeValue(attr_value));
      result.attributes.push_back(builder_.getNamedAttr(attr_name, attr));
    }
  }

  // Sort the collected function attributes by name for deterministic IR.
  auto comparator = [](const FuncPairType& a, const FuncPairType& b) {
    return *a.first < *b.first;
  };
  std::sort(funcs.begin(), funcs.end(), comparator);
  for (const auto& func : funcs) {
    TF_RETURN_IF_ERROR(ConvertFunctionCallAttribute(*func.first, *func.second,
                                                    &result.attributes));
  }

  const auto& node_def = node.def();
  // NodeDef can contain partial TF device names. In such cases, canonicalize
  // it. Note that in current TF, placer will place full device name to each
  // node.
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(node_def.device(), &parsed_name)) {
    return errors::InvalidArgument(
        "Op ", op_name, " has invalid device name: ", node_def.device());
  }
  // Keep the parsed name untouched if the device name is empty.
  if (!node_def.device().empty()) {
    if (!parsed_name.has_type) {
      parsed_name.type = "CPU";
      parsed_name.has_type = true;
    }
    if (!parsed_name.has_id) {
      parsed_name.id = 0;
      parsed_name.has_id = true;
    }
  }
  result.attributes.push_back(builder_.getNamedAttr(
      "device", builder_.getStringAttr(
                    DeviceNameUtils::ParsedNameToString(parsed_name))));

  // Map user function calls to LegacyCall ops and add the user function name
  // as an attribute.
  if (convert_to_legacy_call) {
    result.name = mlir::OperationName(get_full_op_name("LegacyCall"), context_);
    mlir::SymbolRefAttr val = builder_.getSymbolRefAttr(node_type_name);
    result.addAttribute("f", val);

    if (!result.attributes.get("_disable_call_shape_inference")) {
      result.addAttribute("_disable_call_shape_inference",
                          builder_.getBoolAttr(false));
    }
  }

  // "Stateless*" control-flow ops map to the same MLIR op as their stateful
  // counterparts, distinguished only by the `is_stateless` attribute.
  auto composite_control_flow_op = [&](const std::string& name) {
    result.name = mlir::OperationName(get_full_op_name(name), context_);
    bool stateless = absl::StartsWith(node_type_name, "Stateless");
    mlir::BoolAttr val = builder_.getBoolAttr(stateless);
    result.attributes.push_back(builder_.getNamedAttr("is_stateless", val));
  };

  // Map Case/If/While and StatelessCase/If/While op in TensorFlow to the common
  // Case/If/While op in MLIR and add the differentiating attribute.
  if (node.IsCaseNode()) composite_control_flow_op("Case");
  if (node.IsIfNode()) composite_control_flow_op("If");
  if (node.IsWhileNode()) {
    composite_control_flow_op("While");
    auto* output_shapes = node.attrs().Find("output_shapes");
    if (output_shapes && !output_shapes->list().shape().empty())
      result.attributes.push_back(
          builder_.getNamedAttr("shape_invariant", builder_.getUnitAttr()));
  }

  // Register the mapping between the TF node and the newly created operation.
  node_values_[node.id()] =
      CreateOperation(node, node_type_name, result, control_operands);
  return Status::OK();
}
2142 
2143 // Add the backedges to the CFG. Given a backedge, we replace the original
2144 // source and destination operations by two new operations. Most of the
2145 // fields of the replacements are copied from the original operations.
2146 // However,
2147 // - for the src operation, one output is inserted to the front of the output
2148 //   list. The type of the output is set to the type of the non-control result
2149 //   of the dst operation, and
2150 // - for the dst operation, one operand is inserted to the front of the
2151 //   operand list. This operand is using the first result of the src
2152 //   operation.
2153 // TODO(fengliuai): Preserve the order of the results and operands if
2154 // necessary.
AddBackedges()2155 Status ImporterBase::AddBackedges() {
2156   for (auto it : back_edge_dst_inputs_) {
2157     BackEdge& edge = it.second;
2158     if (!edge.src->IsNextIteration() || !edge.dst->IsMerge()) {
2159       return errors::FailedPrecondition(
2160           "Invalid backedge; should be from NextIteration to Merge!");
2161     }
2162     auto* sink = node_values_[edge.src->id()];
2163     auto* dst = node_values_[edge.dst->id()];
2164     TF_RETURN_IF_ERROR(AddBackedge(sink, dst, edge.dst_input));
2165   }
2166   return Status::OK();
2167 }
2168 
AddBackedge(mlir::Operation * sink,mlir::Operation * dst,int dst_input)2169 Status ImporterBase::AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
2170                                  int dst_input) {
2171   // Get the NextIteration.Source operation from the token operand of the sink.
2172   mlir::Operation* source = sink->getOperand(0).getDefiningOp();
2173 
2174   // Adds the "source" to the operands of the dst by creating a new dst
2175   // operation.
2176   mlir::OperationState state(dst->getLoc(), dst->getName());
2177   auto num_operands = dst->getNumOperands();
2178   state.operands.reserve(num_operands + 1);
2179   for (int input = 0, e = num_operands + 1; input != e; ++input) {
2180     if (input < dst_input) {
2181       state.operands.push_back(dst->getOperand(input));
2182     } else if (input == dst_input) {
2183       state.operands.push_back(source->getResult(0));
2184     } else {
2185       state.operands.push_back(dst->getOperand(input - 1));
2186     }
2187   }
2188   state.attributes.assign(dst->getAttrs().begin(), dst->getAttrs().end());
2189   state.types.assign(dst->getResultTypes().begin(),
2190                      dst->getResultTypes().end());
2191   builder_.setInsertionPoint(dst);
2192   auto* new_dst = builder_.createOperation(state);
2193 
2194   // Replaces the output uses of the old operation by the corresponding
2195   // result of the new operation, and deletes the old operation.
2196   for (unsigned i = 0, e = dst->getNumResults(); i != e; ++i) {
2197     auto new_output = new_dst->getResult(i);
2198     dst->getResult(i).replaceAllUsesWith(new_output);
2199   }
2200   dst->dropAllReferences();
2201   dst->erase();
2202   return Status::OK();
2203 }
2204 
InferLibFunctionType(const FunctionBody & fbody)2205 StatusOr<mlir::FunctionType> ImporterBase::InferLibFunctionType(
2206     const FunctionBody& fbody) {
2207   mlir::Builder builder(context_);
2208 
2209   // The FunctionBody contains a graph with a single-output _Arg node for each
2210   // function argument and a single-input _Retval node for each function return
2211   // value.
2212   //
2213   // We already populated the ShapeRefiner with all the information about the
2214   // shapes of these graph edges, so we just query it to build the corresponding
2215   // MLIR function type signature.
2216 
2217   llvm::SmallVector<mlir::Type, 4> arg_types;
2218   if (specs_.inputs.empty()) {
2219     arg_types.reserve(fbody.arg_types.size());
2220     for (auto arg : fbody.arg_nodes) {
2221       // Find node in the graph using the node id instead of using `arg`
2222       // directly because the graph has been cloned.
2223       auto* node = graph_->FindNodeId(arg->id());
2224       TF_ASSIGN_OR_RETURN(auto type,
2225                           InferOutputType(*node, /*idx=*/0, builder));
2226       arg_types.push_back(type);
2227     }
2228   } else {
2229     arg_types.reserve(fbody.arg_types.size());
2230     for (const auto& it : llvm::enumerate(specs_.inputs)) {
2231       mlir::Type element_type;
2232       const auto& node_info = it.value().second;
2233       DataType dtype = node_info.imported_dtype;
2234       // Uses the existing output type of the arg node if the data type of the
2235       // the node isn't specified through the import configuration.
2236       if (dtype == DT_INVALID) {
2237         auto arg = fbody.arg_nodes[it.index()];
2238         auto* node = graph_->FindNodeId(arg->id());
2239         dtype = node->output_type(0);
2240         if (dtype == DT_INVALID) {
2241           return errors::InvalidArgument("Input ", it.index(),
2242                                          "has invalid data type");
2243         }
2244       }
2245       TF_RETURN_IF_ERROR(
2246           ::tensorflow::ConvertDataType(dtype, builder, &element_type));
2247       if (node_info.shape.unknown_rank()) {
2248         arg_types.push_back(mlir::UnrankedTensorType::get(element_type));
2249       } else {
2250         llvm::SmallVector<int64_t, 4> shape;
2251         TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape));
2252         arg_types.push_back(mlir::RankedTensorType::get(shape, element_type));
2253       }
2254     }
2255   }
2256 
2257   llvm::SmallVector<mlir::Type, 4> ret_types;
2258   ret_types.reserve(fbody.ret_types.size());
2259   for (auto ret : fbody.ret_nodes) {
2260     // Find node in the graph using the node id instead of using `ret` directly
2261     // because the graph has been cloned.
2262     auto* node = graph_->FindNodeId(ret->id());
2263     TF_ASSIGN_OR_RETURN(auto type, InferInputType(*node, /*idx=*/0, builder));
2264     ret_types.push_back(type);
2265   }
2266 
2267   return builder.getFunctionType(arg_types, ret_types);
2268 }
2269 
// Stateful helper class to import a TensorFlow model expressed in GraphDef into
// an MLIR Module.
//
// The nodes defined in the graph are converted to a function called
// 'func_name'. All library function definitions are converted to MLIR functions
// in the module.
class GraphDefImporter : public ImporterBase {
 public:
  // Main entry point: converts the given graph to an MLIR Module.
  // `tf_name_to_mlir_name` is filled with the mapping from TF function names
  // to the names of the MLIR functions created for them.
  static StatusOr<mlir::OwningModuleRef> Convert(
      mlir::MLIRContext* context, const Graph& graph,
      const GraphDebugInfo& debug_info,
      const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
      std::unordered_map<std::string, std::string>& tf_name_to_mlir_name);

 private:
  // Thin constructor: all state handling is delegated to ImporterBase.
  explicit GraphDefImporter(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier)
      : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name,
                     function_name_uniquifier) {}

  // Returns the function signature of the main function of converted MLIR
  // module, the input nodes and output nodes. The type and shape information
  // for the function arguments are read from `specs`, but the type and shape
  // information for the function returns are inferred by the shape refiner in
  // ImporterBase.
  StatusOr<mlir::FunctionType> InferMainFunctionType(
      const GraphImportConfig& specs, mlir::MLIRContext* context,
      absl::InlinedVector<OutputTensor, 4>* arg_nodes,
      absl::InlinedVector<OutputTensor, 4>* ret_nodes);

  // Returns the function signature of the main function, alongside input and
  // output nodes, for function graphs. Arguments and return values are
  // determined by node op type (_Arg/_Retval). Type and shape information of
  // the function are inferred by the shape refiner in ImporterBase.
  StatusOr<mlir::FunctionType> GetArgsRetsAndTypesFromFunctionGraph(
      mlir::MLIRContext* context,
      absl::InlinedVector<OutputTensor, 4>* arg_nodes,
      absl::InlinedVector<OutputTensor, 4>* ret_nodes);

  // Finds the graph's target nodes/function's control ret nodes based on
  // supplied node names in `control_outputs`. If `control_outputs` are not
  // unique or a control ret node is missing, an error will be returned.
  Status GetControlRetsFromGraph(
      llvm::ArrayRef<std::string> control_outputs,
      absl::InlinedVector<Node*, 4>* control_ret_nodes);
};
2320 
// Converts `graph` (plus its reachable function library) into a fresh MLIR
// module. The main graph becomes a function named `specs.graph_func_name`
// (or the default name); feed/fetch/control-output metadata is recorded in a
// `tf.entry_function` attribute when applicable.
StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
    mlir::MLIRContext* context, const Graph& graph,
    const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
    const GraphImportConfig& specs,
    std::unordered_map<std::string, std::string>& tf_name_to_mlir_name) {
  LoadImporterDialects(*context);
  mlir::OwningModuleRef module =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
  NameUniquifier function_name_uniquifier(flib_def);

  // importer.PrepareConvert below will attempt to clone the original `graph`
  // via conversion to the graph def first. Convert graph to graph_def here
  // first and avoid extra copies later.
  auto graph_def = std::make_unique<GraphDef>();
  graph.ToGraphDef(graph_def.get());

  // Register the graph and its reachable function library with the crash
  // analysis system so they are dumped if import crashes. The atomic counter
  // gives each concurrent import a distinct report-file prefix.
  static std::atomic<uint32> counter(0);
  uint32 current_file_prefix = counter++;
  const auto* graph_crash_handle = crash_analysis::ReportProtoDataOnCrash(
      absl::StrCat(current_file_prefix, "_mlir_import_graph.pbtxt"),
      *graph_def);
  auto reachable_flib = flib_def.ReachableDefinitions(*graph_def);
  const auto* flib_crash_handle = crash_analysis::ReportProtoDataOnCrash(
      absl::StrCat(current_file_prefix, "_mlir_import_flib.pbtxt"),
      reachable_flib.ToProto());

  // Deregister the crash reports on every exit path (success or error).
  auto scope_exit = llvm::make_scope_exit([&]() {
    crash_analysis::RemoveReportData(graph_crash_handle);
    crash_analysis::RemoveReportData(flib_crash_handle);
  });

  VLOG(1) << "Importing: "
          << ::tensorflow::DumpGraphToFile("tf_mlir_importer_base", graph,
                                           &flib_def);

  GraphDefImporter importer(flib_def, debug_info, specs, module.get(),
                            &tf_name_to_mlir_name, &function_name_uniquifier);

  TF_RETURN_IF_ERROR(importer.PrepareConvert(graph, std::move(graph_def)));

  mlir::FunctionType func_type;
  absl::InlinedVector<OutputTensor, 4> arg_nodes;
  absl::InlinedVector<OutputTensor, 4> ret_nodes;
  absl::InlinedVector<Node*, 4> control_ret_nodes;
  llvm::SmallVector<mlir::NamedAttribute, 1> attrs;
  if (specs.graph_as_function) {
    if (specs.prune_unused_nodes || !specs.inputs.empty() ||
        !specs.outputs.empty())
      return errors::InvalidArgument(
          "Pruning of graph is currently unsupported when the main graph is "
          "converted to a function.");

    // Arguments and returns are derived from _Arg/_Retval nodes in the graph.
    TF_ASSIGN_OR_RETURN(func_type,
                        importer.GetArgsRetsAndTypesFromFunctionGraph(
                            context, &arg_nodes, &ret_nodes));

    TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
                                                        &control_ret_nodes));

    // Build comma-separated node-name lists for the tf.entry_function
    // attribute; the single stream `ss` is reused via s.clear() after each
    // ss.str() flush.
    mlir::Builder b(context);
    std::string s;
    llvm::raw_string_ostream ss(s);
    auto node_name = [&](const OutputTensor& tensor) {
      ss << tensor.node->name();
    };
    llvm::interleave(arg_nodes, ss, node_name, ",");
    auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
    s.clear();
    llvm::interleave(ret_nodes, ss, node_name, ",");
    auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
    s.clear();
    llvm::interleave(specs.control_outputs, ss, ",");
    auto control_outputs =
        b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));

    // Under `graph_as_function` mode, `tf.entry_function` is always set as it
    // is assumed feed, fetch, and target nodes are set correctly.
    attrs.push_back(b.getNamedAttr(
        "tf.entry_function",
        b.getDictionaryAttr({inputs, outputs, control_outputs})));
  } else {
    // Collects the argument and return nodes by looking up the node names
    // specified by the user.
    TF_ASSIGN_OR_RETURN(func_type, importer.InferMainFunctionType(
                                       specs, context, &arg_nodes, &ret_nodes));

    TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
                                                        &control_ret_nodes));

    // TODO(prakalps): Refactor to keep tf.entry_function attribute encoding and
    // decoding in a centralized place.
    // Record the input and output mapping.
    if (!specs.inputs.empty() || !specs.outputs.empty() ||
        !specs.control_outputs.empty()) {
      mlir::Builder b(context);
      std::string s;
      llvm::raw_string_ostream ss(s);
      llvm::interleave(
          specs.inputs, ss,
          [&](const std::pair<std::string, ArrayInfo>& v) { ss << v.first; },
          ",");
      auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
      s.clear();
      llvm::interleave(specs.outputs, ss, ",");
      auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
      s.clear();
      llvm::interleave(specs.control_outputs, ss, ",");
      auto control_outputs =
          b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));

      attrs.push_back(b.getNamedAttr(
          "tf.entry_function",
          b.getDictionaryAttr({inputs, outputs, control_outputs})));
    }
  }

  // Record version info.
  PopulateTfVersions(module.get(), graph.versions());

  const auto& graph_func_name = specs.graph_func_name.empty()
                                    ? kImportModelDefaultGraphFuncName
                                    : specs.graph_func_name;
  TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(graph_func_name, func_type,
                                                    arg_nodes, ret_nodes,
                                                    control_ret_nodes, attrs));
  TF_RETURN_IF_ERROR(importer.ImporterBase::ConvertDeferredFunctions());

  // Mark main function public, others private.
  for (auto function : module.get().getOps<mlir::FuncOp>()) {
    auto visibility = function.getName() == graph_func_name
                          ? mlir::FuncOp::Visibility::Public
                          : mlir::FuncOp::Visibility::Private;
    function.setVisibility(visibility);
  }
  return module;
}
2457 
InferMainFunctionType(const GraphImportConfig & specs,mlir::MLIRContext * context,absl::InlinedVector<OutputTensor,4> * arg_nodes,absl::InlinedVector<OutputTensor,4> * ret_nodes)2458 StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
2459     const GraphImportConfig& specs, mlir::MLIRContext* context,
2460     absl::InlinedVector<OutputTensor, 4>* arg_nodes,
2461     absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
2462   // Find all the input nodes and output nodes.
2463   // Feeds have been remapped to single output nodes (Placeholder), so an exact
2464   // name match is sufficient.
2465   absl::flat_hash_map<absl::string_view, int> inputs;
2466   for (auto input_and_idx : llvm::enumerate(specs.inputs)) {
2467     TensorId tensor = ParseTensorName(input_and_idx.value().first);
2468     auto remapped_it = remapped_feeds_.find(tensor);
2469     if (remapped_it != remapped_feeds_.end()) {
2470       inputs.insert({remapped_it->second, input_and_idx.index()});
2471     } else {
2472       inputs.insert({tensor.node(), input_and_idx.index()});
2473     }
2474   }
2475 
2476   absl::flat_hash_set<absl::string_view> output_node_names;
2477   std::vector<TensorId> outputs;
2478   output_node_names.reserve(specs.outputs.size());
2479   for (const auto& output : specs.outputs) {
2480     TensorId tensor = ParseTensorName(output);
2481     auto remapped_it = remapped_feeds_.find(tensor);
2482     if (remapped_it != remapped_feeds_.end()) {
2483       output_node_names.insert(remapped_it->second);
2484       outputs.push_back({remapped_it->second, 0});
2485     } else {
2486       output_node_names.insert(tensor.node());
2487       outputs.push_back(tensor);
2488     }
2489   }
2490 
2491   if (!inputs.empty() || !outputs.empty()) {
2492     arg_nodes->resize(inputs.size());
2493     ret_nodes->resize(outputs.size());
2494 
2495     for (Node* n : GetOrderedNodes()) {
2496       // Handle inputs/arguments.
2497       auto input_it = inputs.find(n->name());
2498       if (input_it != inputs.end()) {
2499         (*arg_nodes)[input_it->second] = {n, 0};
2500       }
2501 
2502       // Handle outputs/returns.
2503       if (output_node_names.contains(n->name())) {
2504         for (int i = 0, e = outputs.size(); i != e; ++i) {
2505           TensorId tensor = outputs[i];
2506           if (n->name() != tensor.node()) continue;
2507           (*ret_nodes)[i] = {n, tensor.index()};
2508         }
2509       }
2510     }
2511   }
2512 
2513   // Starts to construct the function type.
2514   mlir::Builder builder(context);
2515   llvm::SmallVector<mlir::Type, 4> arg_types;
2516   arg_types.reserve(specs.inputs.size());
2517   int i = 0;
2518   for (const auto& it : specs.inputs) {
2519     Node* arg_node = arg_nodes->at(i).node;
2520     if (arg_node == nullptr) {
2521       return errors::InvalidArgument("Input ", it.first,
2522                                      " was not found in graph");
2523     }
2524     mlir::Type element_type;
2525     const auto& node_info = it.second;
2526     DataType imported_dtype = node_info.imported_dtype;
2527     // Uses the existing output type of the arg node if the data type of the
2528     // the node isn't specified through the import configuration.
2529     if (imported_dtype == DT_INVALID) {
2530       imported_dtype = arg_node->output_type(0);
2531       if (imported_dtype == DT_INVALID) {
2532         return errors::InvalidArgument("Input ", i, "has invalid data type");
2533       }
2534     }
2535     TF_RETURN_IF_ERROR(
2536         ::tensorflow::ConvertDataType(imported_dtype, builder, &element_type));
2537     if (node_info.shape.unknown_rank()) {
2538       arg_types.push_back(mlir::UnrankedTensorType::get(element_type));
2539     } else {
2540       llvm::SmallVector<int64_t, 4> shape;
2541       TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape));
2542       arg_types.push_back(mlir::RankedTensorType::get(shape, element_type));
2543     }
2544     i++;
2545   }
2546 
2547   llvm::SmallVector<mlir::Type, 4> ret_types;
2548   ret_types.reserve(specs.outputs.size());
2549   for (int i = 0, e = specs.outputs.size(); i != e; ++i) {
2550     if (ret_nodes->at(i).node == nullptr) {
2551       return errors::InvalidArgument("Output ", specs.outputs[i],
2552                                      " was not found in graph");
2553     }
2554   }
2555   for (const auto& ret : *ret_nodes) {
2556     if (ret.node->num_outputs() <= ret.index) {
2557       return errors::InvalidArgument("Invalid output index ", ret.index,
2558                                      " specified for node: ", ret.node->name());
2559     }
2560     TF_ASSIGN_OR_RETURN(auto type,
2561                         InferOutputType(*ret.node, ret.index, builder));
2562     ret_types.push_back(type);
2563   }
2564 
2565   return builder.getFunctionType(arg_types, ret_types);
2566 }
2567 
2568 StatusOr<mlir::FunctionType>
GetArgsRetsAndTypesFromFunctionGraph(mlir::MLIRContext * context,absl::InlinedVector<OutputTensor,4> * arg_nodes,absl::InlinedVector<OutputTensor,4> * ret_nodes)2569 GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph(
2570     mlir::MLIRContext* context, absl::InlinedVector<OutputTensor, 4>* arg_nodes,
2571     absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
2572   auto add_node = [](Node* node, absl::InlinedVector<OutputTensor, 4>* nodes) {
2573     auto* attr = node->attrs().Find("index");
2574     if (!attr)
2575       return errors::InvalidArgument(node->type_string(), " node '",
2576                                      node->name(),
2577                                      "' is missing attribute 'index'");
2578 
2579     auto index = attr->i();
2580     const int num_nodes = nodes->size();
2581     if (num_nodes < index + 1) nodes->resize(index + 1);
2582 
2583     if ((*nodes)[index].node != nullptr)
2584       return errors::InvalidArgument(node->type_string(), " node '",
2585                                      node->name(), "' has attribute 'index' ",
2586                                      index, " that conflicts with node '",
2587                                      (*nodes)[index].node->name(), "'");
2588     (*nodes)[index] = {node, 0};
2589 
2590     return Status::OK();
2591   };
2592 
2593   // Collect arg and ret nodes from graph.
2594   for (auto* node : GetOrderedNodes())
2595     if (node->IsArg())
2596       TF_RETURN_IF_ERROR(add_node(node, arg_nodes));
2597     else if (node->IsRetval())
2598       TF_RETURN_IF_ERROR(add_node(node, ret_nodes));
2599 
2600   // Collect arg and ret types and create function type.
2601   mlir::Builder builder(context);
2602   llvm::SmallVector<mlir::Type, 4> arg_types;
2603   arg_types.reserve(arg_nodes->size());
2604   for (auto arg_node_and_idx : llvm::enumerate(*arg_nodes)) {
2605     auto& arg_node = arg_node_and_idx.value();
2606     if (arg_node.node == nullptr)
2607       return errors::InvalidArgument("Graph missing _Arg at index ",
2608                                      arg_node_and_idx.index());
2609 
2610     TF_ASSIGN_OR_RETURN(auto type,
2611                         InferOutputType(*arg_node.node, /*idx=*/0, builder));
2612     arg_types.push_back(type);
2613   }
2614 
2615   llvm::SmallVector<mlir::Type, 4> ret_types;
2616   ret_types.reserve(ret_nodes->size());
2617   for (auto ret_node_and_idx : llvm::enumerate(*ret_nodes)) {
2618     auto& ret_node = ret_node_and_idx.value();
2619     if (ret_node.node == nullptr)
2620       return errors::InvalidArgument("Graph missing _Retval at index ",
2621                                      ret_node_and_idx.index());
2622 
2623     TF_ASSIGN_OR_RETURN(auto type,
2624                         InferInputType(*ret_node.node, /*idx=*/0, builder));
2625     ret_types.push_back(type);
2626   }
2627 
2628   return builder.getFunctionType(arg_types, ret_types);
2629 }
2630 
GetControlRetsFromGraph(llvm::ArrayRef<std::string> control_outputs,absl::InlinedVector<Node *,4> * control_ret_nodes)2631 Status GraphDefImporter::GetControlRetsFromGraph(
2632     llvm::ArrayRef<std::string> control_outputs,
2633     absl::InlinedVector<Node*, 4>* control_ret_nodes) {
2634   if (control_outputs.empty()) return Status::OK();
2635 
2636   llvm::SmallDenseMap<llvm::StringRef, int32_t> controls_to_idx;
2637   for (auto control_and_idx : llvm::enumerate(control_outputs))
2638     controls_to_idx.insert({control_and_idx.value(), control_and_idx.index()});
2639 
2640   if (controls_to_idx.size() != control_outputs.size())
2641     return errors::InvalidArgument("Control outputs must be unique");
2642 
2643   control_ret_nodes->resize(controls_to_idx.size());
2644 
2645   for (auto* node : GetOrderedNodes()) {
2646     auto it = controls_to_idx.find(node->name());
2647     if (it != controls_to_idx.end()) (*control_ret_nodes)[it->second] = node;
2648   }
2649 
2650   for (auto node_and_name : llvm::zip(*control_ret_nodes, control_outputs))
2651     if (std::get<0>(node_and_name) == nullptr)
2652       return errors::InvalidArgument(
2653           "Control output '", std::get<1>(node_and_name), "' is missing");
2654 
2655   return Status::OK();
2656 }
2657 
// Stateful helper class to import a TensorFlow model expressed in SavedModel
// into an MLIR Module.
class SavedModelObjectGraphImporter : public ImporterBase {
 public:
  // Main entry point: converts all functions in the given meta graph to an MLIR
  // Module.
  static StatusOr<mlir::OwningModuleRef> Convert(
      SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
      mlir::MLIRContext* context, bool add_default_attributes);

 private:
  // Simply forwards all importer state to ImporterBase; this subclass adds no
  // state of its own and exists to scope the SavedModel entry point.
  explicit SavedModelObjectGraphImporter(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier)
      : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name,
                     function_name_uniquifier) {}
};
2677 
// Determines the names used to reference objects in the SavedObjectGraph.
class ObjectNames {
 public:
  explicit ObjectNames(const SavedObjectGraph& object_graph,
                       absl::Span<std::string> exported_names);

  // Gets the names that external users of the SavedModel can use to refer to
  // this node. Empty if the node is not exported.
  llvm::ArrayRef<llvm::StringRef> GetExportedNames(int node_id) const;

  // Gets the name in the module symbol table for this node.
  // This name is only used for internal IR references.
  llvm::StringRef GetSymbolTableName(int node_id) const;

 private:
  // In the absence of any other information, use this name as the symbol table
  // name for this node.
  std::string GetDefaultSymbolTableName(int node_id) const;
  // Determines if a name is exported.
  bool IsExported(const std::string& name);
  // Main object graph traversal function.
  void RecursivelyVisitObjectGraph(int node_id);
  // Gets a stable StringRef from a std::string.
  llvm::StringRef SaveString(const std::string& s) const;

  // The object graph we are traversing.
  const SavedObjectGraph& object_graph_;
  // The set of names to export. Empty means "export all".
  std::unordered_set<std::string> names_to_export_;

  // When we recursively follow the object graph tree structure from the root,
  // we track its path in the object graph by pushing and popping from here
  // during traversal.
  llvm::SmallVector<std::string, 8> path_segments_;
  // The set of node_id's that are on the current DFS stack.
  // For cyclic object graphs, this prevents infinite recursion.
  std::unordered_set<int> on_stack_nodes_;

  // Key: node_id.
  // Value: all object names that node_id appears as.
  // Each object name corresponds to a unique path from the root of the object
  // graph.
  // The common intuitive case is when there is only one name for a given
  // object, which corresponds to the object graph being a tree.
  //
  // But, there are cases where the object graph is a general graph. For
  // example, this happens commonly in Keras models, where `foo.bar` is
  // also reachable via the name `keras_api.foo.bar`.
  // Cycles are possible too.
  absl::flat_hash_map<int, std::vector<std::string>> object_names_;

  // Key: node_id
  // Value: all names that this object is exported as
  absl::flat_hash_map<int, llvm::SmallVector<llvm::StringRef, 1>>
      exported_names_;
  // Key: node_id
  // Value: pretty symbol table name to use for internal references to this
  // object.
  absl::flat_hash_map<int, llvm::StringRef> pretty_symbol_table_name_;

  // Stable strings we can take StringRef's into. Used only by the SaveString
  // method. Mutable so const accessors can intern fallback names on demand.
  mutable std::unordered_set<std::string> saved_strings_;
};
2742 
ObjectNames(const SavedObjectGraph & object_graph,absl::Span<std::string> exported_names)2743 ObjectNames::ObjectNames(const SavedObjectGraph& object_graph,
2744                          absl::Span<std::string> exported_names)
2745     : object_graph_(object_graph),
2746       names_to_export_(exported_names.begin(), exported_names.end()) {
2747   // Visit all reachable nodes from the root of the object graph.
2748   // This builds up object_names_ to contain all names like `foo.bar` that a
2749   // particular node in the graph can be reached from.
2750   RecursivelyVisitObjectGraph(/*node_id=*/0);
2751 
2752   // Populate the exported_names_ map.
2753   // TODO(silvasean): Diagnose typos in exported names?
2754   for (auto& kv : object_names_) {
2755     // Make object names map independent of our particular choice of object
2756     // graph traversal.
2757     std::sort(kv.second.begin(), kv.second.end(),
2758               [](absl::string_view a, absl::string_view b) {
2759                 // The sort order here influences the "pretty name" we assign
2760                 // below. We want the most debuggable name to be first.
2761                 //
2762                 // Debuggability heuristics:
2763                 // 1. Names that end in digits are likely to be internal aliases
2764                 // to the "real" names.
2765                 // 2. Longer names are more likely to be internal aliases.
2766                 //
2767                 // Example set of object names created by Keras for the weight
2768                 // matrix of a fully connected layer on a trivial FC mnist
2769                 // model:
2770                 // - `model.layer-1.kernel` (this is the "best" name)
2771                 // - `model.keras_api.layers.1.kernel`
2772                 // - `model.variables.0`
2773                 // - `model.keras_api.layers.1.keras_api.trainable_variables.0`
2774                 // - ... 10 more long aliases ending in digits ...
2775                 return std::make_tuple(isdigit(a.back()), a.size(), a) <
2776                        std::make_tuple(isdigit(b.back()), b.size(), b);
2777               });
2778     for (const std::string& name : kv.second) {
2779       if (IsExported(name)) {
2780         exported_names_[kv.first].push_back(SaveString(name));
2781       }
2782     }
2783   }
2784   // Create "pretty" symbol table names for nodes where that is applicable.
2785   // We could make all symbol table names use the default, which is basically
2786   // just the node id. But for debugging purposes, it's nicer if we can mix in
2787   // a recognizable object name if we have the information to do so.
2788   for (auto& kv : object_names_) {
2789     int node_id = kv.first;
2790     std::string internal_name =
2791         absl::StrCat(GetDefaultSymbolTableName(node_id), "__");
2792     // If the object has an exported name, we prefer that since it is probably
2793     // the most recognizable. Otherwise, we grab some non-exported name of the
2794     // object.
2795     if (exported_names_.find(node_id) != exported_names_.end()) {
2796       internal_name += exported_names_[node_id][0].str();
2797     } else {
2798       internal_name += object_names_[node_id][0];
2799     }
2800     pretty_symbol_table_name_[node_id] = SaveString(internal_name);
2801   }
2802 }
2803 
GetExportedNames(int node_id) const2804 llvm::ArrayRef<llvm::StringRef> ObjectNames::GetExportedNames(
2805     int node_id) const {
2806   auto it = exported_names_.find(node_id);
2807   if (it != exported_names_.end()) {
2808     return it->second;
2809   }
2810   return {};
2811 }
2812 
GetSymbolTableName(int node_id) const2813 llvm::StringRef ObjectNames::GetSymbolTableName(int node_id) const {
2814   auto it = pretty_symbol_table_name_.find(node_id);
2815   if (it != pretty_symbol_table_name_.end()) {
2816     return it->second;
2817   }
2818   return SaveString(GetDefaultSymbolTableName(node_id));
2819 }
2820 
// Default symbol table name: nodes are identified by their position in the
// SavedObjectGraph when no prettier name is available.
std::string ObjectNames::GetDefaultSymbolTableName(int node_id) const {
  return absl::StrCat("__sm_node", node_id);
}
2824 
IsExported(const std::string & name)2825 bool ObjectNames::IsExported(const std::string& name) {
2826   if (names_to_export_.empty()) {
2827     return true;
2828   }
2829   return names_to_export_.find(name) != names_to_export_.end();
2830 }
2831 
// DFS over the object graph. Records a dotted path name (built from
// path_segments_) for every constant/function/variable node reached, and
// uses on_stack_nodes_ to skip backedges so cyclic graphs terminate.
void ObjectNames::RecursivelyVisitObjectGraph(int node_id) {
  const SavedObject& object = object_graph_.nodes(node_id);

  switch (object.kind_case()) {
    case SavedObject::kConstant:
    case SavedObject::kFunction:
    case SavedObject::kVariable: {
      // Leaf-like objects get a name like `foo.bar` formed from the local
      // names on the path from the root. The root itself yields "".
      object_names_[node_id].push_back(absl::StrJoin(path_segments_, "."));
      break;
    }
    default:
      break;
  }

  for (const auto& child_ref : object.children()) {
    // insert().second is false when the child is already on the current DFS
    // stack, i.e. following this edge would close a cycle.
    bool on_stack = !on_stack_nodes_.insert(child_ref.node_id()).second;
    if (on_stack) {
      // This is a backedge. Don't traverse it.
      continue;
    }

    path_segments_.push_back(child_ref.local_name());
    RecursivelyVisitObjectGraph(child_ref.node_id());
    path_segments_.pop_back();

    on_stack_nodes_.erase(child_ref.node_id());
  }
}
2860 
SaveString(const std::string & s) const2861 llvm::StringRef ObjectNames::SaveString(const std::string& s) const {
2862   return llvm::StringRef(*saved_strings_.insert(s).first);
2863 }
2864 
2865 // Extracts a TensorProto for a Const op from a GraphDef, given an op_name.
2866 // Returns nullptr on not found or other mismatch.
2867 // This returns a pointer to the actual node within the graph_def so as to
2868 // avoid expensive copies.
ExtractConstTensorFromGraph(const GraphDef & graph_def,const std::string & op_name)2869 const TensorProto* ExtractConstTensorFromGraph(const GraphDef& graph_def,
2870                                                const std::string& op_name) {
2871   const NodeDef* match_node = nullptr;
2872   for (const auto& node : graph_def.node()) {
2873     if (node.name() == op_name) {
2874       match_node = &node;
2875     }
2876   }
2877 
2878   if (!match_node) {
2879     return nullptr;
2880   }
2881 
2882   auto value_it = match_node->attr().find("value");
2883   if (value_it == match_node->attr().end()) {
2884     return nullptr;
2885   }
2886 
2887   if (!value_it->second.has_tensor()) {
2888     return nullptr;
2889   }
2890 
2891   return &value_it->second.tensor();
2892 }
2893 
2894 const TrackableObjectGraph::TrackableObject::SerializedTensor*
FindSerializedTensorInTrackable(const TrackableObjectGraph::TrackableObject & trackable_object,StringPiece name)2895 FindSerializedTensorInTrackable(
2896     const TrackableObjectGraph::TrackableObject& trackable_object,
2897     StringPiece name) {
2898   for (const auto& maybe_serialized_tensor : trackable_object.attributes()) {
2899     if (maybe_serialized_tensor.name() == name) {
2900       return &maybe_serialized_tensor;
2901     }
2902   }
2903   return nullptr;
2904 }
2905 
DiagnoseMultipleConcreteFunctions(const SavedObjectGraph & object_graph,const ObjectNames & object_names)2906 Status DiagnoseMultipleConcreteFunctions(const SavedObjectGraph& object_graph,
2907                                          const ObjectNames& object_names) {
2908   for (int node_id = 0; node_id < object_graph.nodes_size(); node_id++) {
2909     const SavedObject& object = object_graph.nodes(node_id);
2910     if (object_names.GetExportedNames(node_id).empty()) {
2911       continue;
2912     }
2913     if (object.kind_case() == SavedObject::kFunction) {
2914       // We only allow a single input signature to each SavedFunction.
2915       // This assumption means we have a 1:1 correspondence between
2916       // tf.function <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef
2917       // This makes defining the ABI easier (or even well-defined at all).
2918       // TODO(silvasean): How to detect a function that doesn't have an
2919       // explicitly user-provided input signature, but happens to have been
2920       // traced exactly once?
2921       if (object.function().concrete_functions_size() != 1) {
2922         llvm::SmallVector<std::string, 4> names;
2923         for (llvm::StringRef s : object_names.GetExportedNames(node_id)) {
2924           names.push_back("'" + s.str() + "'");
2925         }
2926         return errors::InvalidArgument(
2927             "Exported function with exported name(s) ",
2928             absl::StrJoin(names, ", "),
2929             " with multiple concrete functions. Add "
2930             "@tf.function(input_signature=[...]) on this function, or use a "
2931             "narrower list of exported names that excludes this function.");
2932       }
2933     }
2934   }
2935   return Status::OK();
2936 }
2937 
// Recursively traverses a StructuredValue, linearizing all the leaves.
//
// This currently only handles the subset of StructuredValue that is needed for
// signatures.
//
// Given a StructuredValue with structure [{"x": leaf0}], the "index path"
// needed to reach leaf0 is `[0, "x"]`, as it would be if you were operating on
// a Python object (`obj[0]["x"] is leaf0`). Each leaf corresponds to a
// linearized function argument or return on a FunctionDef, and hence to an
// mlir::FuncOp argument / return.
//
// This must match the linearization that happens in `tf.nest.flatten`.
// In particular, dict values should be linearized in sorted key order.
//
// The linearized index paths can be returned back to a structured
// representation (e.g. to emit C structs matching a signature) with a simple
// algorithm that recurses on each run of index paths with identical first
// elements.
class StructuredValueLinearizer {
 public:
  // Traversal happens eagerly in the constructor; failures are recorded and
  // surfaced by GetLeafIndexPaths.
  StructuredValueLinearizer(const StructuredValue& value,
                            mlir::MLIRContext* context);

  // Returns the list of index paths to each leaf of the StructuredValue,
  // in a linearized order matching `tf.nest.flatten`.
  //
  // If an error occurred during the linearization process, an error message
  // with `error_context` prepended will be included in the returned status.
  StatusOr<llvm::ArrayRef<mlir::ArrayAttr>> GetLeafIndexPaths(
      llvm::StringRef error_context) const;

 private:
  // Main function that recursively traverses the StructuredValue.
  void RecursivelyFindLeaves(const StructuredValue& value);

  // Builder used to create the index-path attributes.
  mlir::Builder builder_;
  // The current index path. We push/pop this during recursive traversal of the
  // StructuredValue.
  llvm::SmallVector<mlir::Attribute, 4> current_index_path_;
  // The list of leaf index paths we have discovered so far.
  llvm::SmallVector<mlir::ArrayAttr, 4> leaf_index_paths_;
  // If non-empty, an error message to report.
  std::string error_message_;
};
2982 
// Eagerly linearizes `value`; any unsupported structure is recorded in
// error_message_ and reported later by GetLeafIndexPaths.
StructuredValueLinearizer::StructuredValueLinearizer(
    const StructuredValue& value, mlir::MLIRContext* context)
    : builder_(context) {
  RecursivelyFindLeaves(value);
}
2988 
2989 StatusOr<llvm::ArrayRef<mlir::ArrayAttr>>
GetLeafIndexPaths(llvm::StringRef error_context) const2990 StructuredValueLinearizer::GetLeafIndexPaths(
2991     llvm::StringRef error_context) const {
2992   if (error_message_.empty()) {
2993     return llvm::makeArrayRef(leaf_index_paths_);
2994   }
2995   return errors::InvalidArgument(
2996       error_context.str(), error_message_,
2997       "This likely means that you have @tf.function "
2998       "on an exported function instead of "
2999       "@tf.function(input_signature=[...]). Consider annotating an "
3000       "input_signature or narrowing your set of "
3001       "exported names to not include this function.");
3002 }
3003 
// Recursive worker: walks `value`, maintaining current_index_path_ as the
// path from the root, and appends one ArrayAttr per leaf (TensorSpec) to
// leaf_index_paths_. Unsupported kinds append to error_message_ instead of
// failing immediately, so all problems in one signature are collected.
void StructuredValueLinearizer::RecursivelyFindLeaves(
    const StructuredValue& value) {
  switch (value.kind_case()) {
    case StructuredValue::kDictValue: {
      // Dict values must be linearized in sorted order of keys.
      const DictValue& dict = value.dict_value();
      using FieldTy = protobuf::MapPair<std::string, StructuredValue>;
      llvm::SmallVector<const FieldTy*, 4> fields;
      for (auto& field : dict.fields()) {
        fields.push_back(&field);
      }
      // Proto map iteration order is unspecified; sort by key for the
      // deterministic order `tf.nest.flatten` uses.
      llvm::sort(fields, [](const FieldTy* a, const FieldTy* b) {
        return a->first < b->first;
      });
      for (auto& field : fields) {
        current_index_path_.push_back(builder_.getStringAttr(field->first));
        RecursivelyFindLeaves(field->second);
        current_index_path_.pop_back();
      }
      return;
    }
    case StructuredValue::kTupleValue: {
      const TupleValue& tuple = value.tuple_value();
      for (int i = 0, e = tuple.values_size(); i < e; i++) {
        current_index_path_.push_back(builder_.getI64IntegerAttr(i));
        RecursivelyFindLeaves(tuple.values(i));
        current_index_path_.pop_back();
      }
      return;
    }
    // We don't differentiate between tuples and lists.
    case StructuredValue::kListValue: {
      const ListValue& list = value.list_value();
      for (int i = 0, e = list.values_size(); i < e; i++) {
        current_index_path_.push_back(builder_.getI64IntegerAttr(i));
        RecursivelyFindLeaves(list.values(i));
        current_index_path_.pop_back();
      }
      return;
    }
    case StructuredValue::kTensorSpecValue: {
      // Base case: record the current path stack as the index path needed to
      // get to this leaf.
      leaf_index_paths_.push_back(builder_.getArrayAttr(current_index_path_));
      return;
    }
    case StructuredValue::kNoneValue: {
      // Base case: do nothing.
      // This arises, for example, as the top-level object of an output
      // signature when there are no return values.
      return;
    }
    default: {
      // Unsupported kind: describe where in the structure it occurred and
      // continue; the caller surfaces error_message_ later.
      llvm::raw_string_ostream os(error_message_);
      // TODO(silvasean): Use an enumerant name string instead of a number.
      os << "Unhandled structured value kind " << value.kind_case()
         << " at index path: <value>";
      for (auto path_element : current_index_path_) {
        os << ".";
        if (auto integer = path_element.dyn_cast<mlir::IntegerAttr>()) {
          os << integer.getValue();
        } else {
          auto str = path_element.cast<mlir::StringAttr>();
          os << str.getValue();
        }
      }
      os << "\n";
    }
  }
}
3074 
// For exported functions with bound inputs, rewrite the function
// signature to match the requirements of tf_saved_model bound input args.
//
// The raw imported functions have `tensor<*x!tf_type.resource>` as the type for
// mutable bound inputs and `tensor<...>` as the type for immutable
// bound inputs. Here we canonicalize both of them into
// `tensor<!tf_type.resource<tensor<...>>>`.
void AdjustBoundInputArgTypes(mlir::ModuleOp module) {
  mlir::SymbolTable symbol_table(module);
  for (auto func : module.getOps<mlir::FuncOp>()) {
    if (!mlir::tf_saved_model::IsExported(func)) continue;
    mlir::OpBuilder builder(func.getBody());
    llvm::SmallVector<mlir::Type, 4> new_input_types;
    for (int i = 0, e = func.getNumArguments(); i < e; i++) {
      auto arg = func.getArgument(i);
      // Only arguments bound to a global tensor need their type rewritten;
      // all others keep their current type.
      auto global_tensor = mlir::tf_saved_model::LookupBoundInputOfType<
          mlir::tf_saved_model::GlobalTensorOp>(func, i, symbol_table);
      if (global_tensor) {
        auto old_type = arg.getType();
        auto new_type =
            mlir::tf_saved_model::GetBoundInputArgTypeFor(global_tensor);
        arg.setType(new_type);
        if (global_tensor.is_mutable()) {
          // Mutable: bridge back to the original type with a Cast so existing
          // uses of the argument still type-check.
          auto arg_with_original_type = builder.create<mlir::TF::CastOp>(
              global_tensor.getLoc(), old_type, arg,
              /*Truncate=*/builder.getBoolAttr(false));
          arg.replaceAllUsesWith(arg_with_original_type);
          // The RAUW replaces the arg with itself, so we need to set it back.
          arg_with_original_type.setOperand(arg);
        } else {
          // Immutable: materialize the value by reading the resource.
          auto arg_with_original_type =
              builder.create<mlir::TF::ReadVariableOp>(global_tensor.getLoc(),
                                                       old_type, arg);
          arg.replaceAllUsesWith(arg_with_original_type);
          // The RAUW replaces the arg with itself, so we need to set it back.
          arg_with_original_type.setOperand(arg);
        }
      }
      new_input_types.push_back(arg.getType());
    }
    // Rebuild the function type with the updated argument types; result types
    // are unchanged.
    func.setType(mlir::FunctionType::get(module.getContext(), new_input_types,
                                         func.getType().getResults()));
  }
}
3119 
3120 // Marks the visibility of functions in the saved model module.
MarkSavedModelFunctionVisibility(mlir::ModuleOp module)3121 void MarkSavedModelFunctionVisibility(mlir::ModuleOp module) {
3122   for (auto func : module.getOps<mlir::FuncOp>()) {
3123     auto visibility = mlir::tf_saved_model::IsExported(func)
3124                           ? mlir::FuncOp::Visibility::Public
3125                           : mlir::FuncOp::Visibility::Private;
3126     func.setVisibility(visibility);
3127   }
3128 }
3129 
// Reorder the ops in the module to make testing easier and less dependent
// on implementation details such as the order of functions in the
// FunctionDefLibrary.
//
// The order this ensures is:
// 1. GlobalTensorOp's
// 2. FuncOps's.
//
// Within each of 1. and 2., ops are sorted by exported name (if
// available, and only the first exported name is considered), followed by
// non-exported ops.
void SortSavedModelModule(mlir::ModuleOp module) {
  struct NamedGlobalTensor {
    llvm::StringRef name;
    GlobalTensorOp global_tensor;
  };
  llvm::SmallVector<NamedGlobalTensor, 8> named_global_tensors;
  for (auto global_tensor : module.getOps<GlobalTensorOp>()) {
    auto exported_names = mlir::tf_saved_model::GetExportedNames(global_tensor);
    // We use stable_sort, so duplicate empty names are fine here.
    named_global_tensors.push_back(
        {exported_names.empty() ? "" : exported_names.front(), global_tensor});
  }
  // Sort exported (non-empty) names first, alphabetically within each group.
  llvm::stable_sort(named_global_tensors,
                    [](const NamedGlobalTensor& a, const NamedGlobalTensor& b) {
                      return std::make_tuple(a.name.empty(), a.name) <
                             std::make_tuple(b.name.empty(), b.name);
                    });

  struct NamedFunc {
    llvm::StringRef name;
    mlir::FuncOp func;
  };
  // Partition functions into exported (sorted by first exported name) and
  // non-exported (sorted by symbol name).
  llvm::SmallVector<NamedFunc, 8> named_funcs;
  llvm::SmallVector<mlir::FuncOp, 8> private_funcs;
  for (auto func : module.getOps<mlir::FuncOp>()) {
    auto exported_names = mlir::tf_saved_model::GetExportedNames(func);
    if (!exported_names.empty())
      named_funcs.push_back({exported_names.front(), func});
    else
      private_funcs.push_back(func);
  }
  llvm::stable_sort(named_funcs, [](const NamedFunc& a, const NamedFunc& b) {
    return a.name < b.name;
  });
  llvm::stable_sort(private_funcs, [](mlir::FuncOp a, mlir::FuncOp b) {
    return a.getName() < b.getName();
  });

  struct NamedAsset {
    llvm::StringRef name;
    AssetOp asset;
  };
  llvm::SmallVector<NamedAsset, 4> assets;
  for (auto asset : module.getOps<AssetOp>()) {
    assets.push_back({asset.getName(), asset});
  }
  llvm::stable_sort(assets, [](const NamedAsset& a, const NamedAsset& b) {
    return a.name < b.name;
  });

  // Move onto the front of the module in reverse of the final desired order.
  for (auto func : llvm::reverse(private_funcs)) {
    func.getOperation()->moveBefore(&module.getBody()->front());
  }
  for (auto named_func : llvm::reverse(named_funcs)) {
    named_func.func.getOperation()->moveBefore(&module.getBody()->front());
  }
  for (auto named_global_tensor : llvm::reverse(named_global_tensors)) {
    named_global_tensor.global_tensor.getOperation()->moveBefore(
        &module.getBody()->front());
  }

  // Note: assets are iterated forward here (each move prepends), which ends
  // up placing them in reverse-sorted order at the very front.
  for (auto asset : assets) {
    asset.asset.getOperation()->moveBefore(&module.getBody()->front());
  }

  // The session initializer (if any) goes first of all.
  auto initializers = module.getOps<SessionInitializerOp>();
  if (!initializers.empty()) {
    (*initializers.begin())
        .getOperation()
        ->moveBefore(&module.getBody()->front());
  }
}
3214 
CreateSavedModelIR(const ObjectNames & object_names,mlir::ModuleOp module,const SavedObjectGraph & object_graph,const std::unordered_map<std::string,std::string> & tf_name_to_mlir_name,SavedModelV2Bundle * saved_model)3215 Status CreateSavedModelIR(
3216     const ObjectNames& object_names, mlir::ModuleOp module,
3217     const SavedObjectGraph& object_graph,
3218     const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name,
3219     SavedModelV2Bundle* saved_model) {
3220   mlir::OpBuilder builder(module.getBodyRegion());
3221   mlir::SymbolTable symbol_table(module);
3222 
3223   // Create a side data-structure, indexed by the object_graph node_id to
3224   // a TrackableObject that is restorable.
3225   absl::flat_hash_map<int, const TrackableObjectGraph::TrackableObject*>
3226       restored_objects;
3227   TF_RETURN_IF_ERROR(saved_model->VisitObjectsToRestore(
3228       [&](int saved_node_id,
3229           const TrackableObjectGraph::TrackableObject& trackable_object) {
3230         restored_objects.insert(
3231             std::make_pair(saved_node_id, &trackable_object));
3232         return Status::OK();
3233       }));
3234 
3235   for (int node_id = 0; node_id < object_graph.nodes_size(); node_id++) {
3236     const SavedObject& object = object_graph.nodes(node_id);
3237     // For correctness, we cannot import functions that don't have exported
3238     // names, since they don't necessarily have a well-defined ABI (diagnosed
3239     // earlier).
3240     //
3241     // For variables/constants, pruning them is purely an optimization,
3242     // and more complicated since it requires use-def analysis of which
3243     // functions use which variables/constants, so we don't do anything
3244     // special for them here as part of our initial IR construction.
3245     if (object.kind_case() == SavedObject::kFunction) {
3246       if (object_names.GetExportedNames(node_id).empty()) {
3247         continue;
3248       }
3249       std::string error_context =
3250           "While importing SavedModel function '" +
3251           object_names.GetExportedNames(node_id)[0].str() + "': ";
3252       const SavedFunction& function = object.function();
3253       auto orig_func = symbol_table.lookup<mlir::FuncOp>(
3254           tf_name_to_mlir_name.find(function.concrete_functions(0))->second);
3255       mlir::FuncOp func = orig_func;
3256       // If there are potentially references to this func from within the
3257       // module, create a wrapper around it and decorate the wrapper with the
3258       // tf_saved_model attributes instead.
3259       if (!mlir::SymbolTable::symbolKnownUseEmpty(orig_func.getName(),
3260                                                   &module.getBodyRegion())) {
3261         func = orig_func.cloneWithoutRegions();
3262         module.insert(module.getBody()->begin(), func);
3263         func.addEntryBlock();
3264         func.setName("__sm_exported_" + orig_func.getName().str());
3265         llvm::SmallVector<mlir::Value, 4> args_as_values;
3266         for (auto block_argument : func.getArguments()) {
3267           args_as_values.push_back(block_argument);
3268         }
3269         mlir::OpBuilder body_builder(&func.getBody());
3270         auto call = body_builder.create<mlir::TF::StatefulPartitionedCallOp>(
3271             func.getLoc(), orig_func.getType().getResults(), args_as_values,
3272             builder.getSymbolRefAttr(orig_func.getName()),
3273             /*config=*/builder.getStringAttr(""),
3274             /*config_proto=*/builder.getStringAttr(""),
3275             /*executor_type=*/builder.getStringAttr(""));
3276         body_builder.create<mlir::ReturnOp>(func.getLoc(), call.getResults());
3277       }
3278       func->setAttr(
3279           "tf_saved_model.exported_names",
3280           builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
3281       const SavedConcreteFunction& concrete_function =
3282           object_graph.concrete_functions().at(function.concrete_functions(0));
3283 
3284       // We do not handle the other element of this tuple, which corresponds to
3285       // Python kwonlyargs, since currently TensorFlow prohibits this in
3286       // combination with input_signature:
3287       // https://github.com/tensorflow/tensorflow/blob/8cb8627abb5ef83a6fba34f8fd0e4ee430562eb1/tensorflow/python/eager/function.py#L2027-L2030
3288       // Our SavedModel import requires input_signature on the tf.function, so
3289       // we never need to handle the kwonlyargs.
3290       auto positional_arg_structure =
3291           concrete_function.canonicalized_input_signature()
3292               .tuple_value()
3293               .values(0);
3294       StructuredValueLinearizer input_linearizer(positional_arg_structure,
3295                                                  builder.getContext());
3296 
3297       int bound_input_base =
3298           func.getNumArguments() - concrete_function.bound_inputs_size();
3299       TF_ASSIGN_OR_RETURN(auto input_index_paths,
3300                           input_linearizer.GetLeafIndexPaths(
3301                               error_context + "in input signature: "));
3302       const int input_index_paths_size = input_index_paths.size();
3303       if (bound_input_base != input_index_paths_size) {
3304         return errors::InvalidArgument(
3305             error_context,
3306             "Argument mismatch between concrete function input signature "
3307             "vs underlying FunctionDef for concrete function '",
3308             function.concrete_functions(0), "' (", input_index_paths.size(),
3309             " vs ", bound_input_base, ")");
3310       }
3311       for (auto index_path : llvm::enumerate(input_index_paths)) {
3312         func.setArgAttr(index_path.index(), "tf_saved_model.index_path",
3313                         index_path.value());
3314       }
3315 
3316       for (auto& bound_input :
3317            llvm::enumerate(concrete_function.bound_inputs())) {
3318         int arg_index = bound_input_base + bound_input.index();
3319         auto symbol_ref = builder.getSymbolRefAttr(
3320             object_names.GetSymbolTableName(bound_input.value()));
3321         func.setArgAttr(arg_index, "tf_saved_model.bound_input", symbol_ref);
3322       }
3323 
3324       StructuredValueLinearizer output_linearizer(
3325           concrete_function.output_signature(), builder.getContext());
3326       TF_ASSIGN_OR_RETURN(auto output_index_paths,
3327                           output_linearizer.GetLeafIndexPaths(
3328                               error_context + "in output signature: "));
3329       if (func.getNumResults() != output_index_paths.size()) {
3330         return errors::InvalidArgument(
3331             error_context,
3332             "Result mismatch between concrete function output signature "
3333             "vs underlying FunctionDef for concrete function '",
3334             function.concrete_functions(0), "' (", output_index_paths.size(),
3335             " vs ", func.getNumResults(), ")");
3336       }
3337       for (auto index_path : llvm::enumerate(output_index_paths)) {
3338         func.setResultAttr(index_path.index(), "tf_saved_model.index_path",
3339                            index_path.value());
3340       }
3341     } else if (object.kind_case() == SavedObject::kVariable) {
3342       const SavedVariable& variable = object.variable();
3343       // Find the trackable in the side data structure.
3344       auto variable_trackable_it = restored_objects.find(node_id);
3345       if (variable_trackable_it == restored_objects.end()) {
3346         return errors::FailedPrecondition("Could not restore saved variable: ",
3347                                           variable.name());
3348       }
3349       const auto* serialized_tensor_attr = FindSerializedTensorInTrackable(
3350           *variable_trackable_it->second, "VARIABLE_VALUE");
3351       if (!serialized_tensor_attr) {
3352         return errors::FailedPrecondition(
3353             "Could not find serialized tensor for saved variable: ",
3354             variable.name());
3355       }
3356       const auto& checkpoint_key = serialized_tensor_attr->checkpoint_key();
3357 
3358       // Load it from the reader.
3359       Tensor value;
3360       TF_RETURN_WITH_CONTEXT_IF_ERROR(
3361           saved_model->variable_reader()->Lookup(checkpoint_key, &value),
3362           "Could not read checkpoint key from variables bundle: ",
3363           checkpoint_key);
3364       TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder));
3365       // A variable can have a partially known type, such as tensor<?x27x?xf32>,
3366       // even if the initializer is a specific static shape.
3367       TF_ASSIGN_OR_RETURN(
3368           auto type, ConvertToMlirTensorType(variable.shape(), variable.dtype(),
3369                                              &builder));
3370       auto op = builder.create<GlobalTensorOp>(
3371           builder.getUnknownLoc(),
3372           builder.getStringAttr(object_names.GetSymbolTableName(node_id)),
3373           value_attr,
3374           /*type=*/mlir::TypeAttr::get(type),
3375           /*is_mutable=*/builder.getUnitAttr());
3376       op->setAttr(
3377           "tf_saved_model.exported_names",
3378           builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
3379     } else if (object.kind_case() == SavedObject::kConstant) {
3380       const SavedConstant& constant = object.constant();
3381       const TensorProto* value = ExtractConstTensorFromGraph(
3382           saved_model->meta_graph_def().graph_def(), constant.operation());
3383       if (!value) {
3384         return errors::FailedPrecondition(
3385             "Unable to find const node referenced in object graph: ",
3386             constant.operation());
3387       }
3388       TF_ASSIGN_OR_RETURN(auto value_attr,
3389                           ConvertTensorProto(*value, &builder));
3390       auto op = builder.create<GlobalTensorOp>(
3391           builder.getUnknownLoc(),
3392           builder.getStringAttr(object_names.GetSymbolTableName(node_id)),
3393           value_attr,
3394           /*type=*/mlir::TypeAttr::get(value_attr.Attribute::getType()),
3395           /*is_mutable=*/nullptr);
3396       op->setAttr(
3397           "tf_saved_model.exported_names",
3398           builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
3399     }
3400   }
3401   AdjustBoundInputArgTypes(module);
3402   module->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
3403   SortSavedModelModule(module);
3404   MarkSavedModelFunctionVisibility(module);
3405   return Status::OK();
3406 }
3407 
// Imports a TF2 SavedModel (object-graph based) into an MLIR module in the
// tf_saved_model dialect.
//
// Pipeline: (1) build a tensorflow::Graph from the bundle's GraphDef
// (optionally adding default attributes first), (2) import every library
// function into MLIR, (3) decorate the module with tf_saved_model constructs
// derived from the SavedModel's object graph.
StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
    SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
    mlir::MLIRContext* context, bool add_default_attributes) {
  LoadImporterDialects(*context);
  // Fall back to an empty GraphDebugInfo when the bundle carries none.
  GraphDebugInfo dummy_debug_info;
  const GraphDebugInfo& debug_info =
      saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info;

  GraphImportConfig specs;
  specs.prune_unused_nodes = true;
  mlir::OwningModuleRef module =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
  // Maps original TF function names to their (possibly uniquified) MLIR
  // function names; filled in by the importer below.
  std::unordered_map<std::string, std::string> tf_name_to_mlir_name;

  const auto& graphdef = saved_model->meta_graph_def().graph_def();
  PopulateTfVersions(module.get(), graphdef.versions());

  GraphConstructorOptions options;
  options.allow_internal_ops = true;
  options.add_default_attributes = add_default_attributes;
  Graph graph(OpRegistry::Global());

  // Work on a copy so the bundle's GraphDef stays untouched.
  GraphDef preprocessed_graphdef(graphdef);
  if (add_default_attributes) {
    TF_RETURN_IF_ERROR(PreprocessGraphDef(nullptr, &preprocessed_graphdef));
  }

  TF_RETURN_IF_ERROR(
      ConvertGraphDefToGraph(options, preprocessed_graphdef, &graph));

  NameUniquifier function_name_uniquifier(graph.flib_def());
  SavedModelObjectGraphImporter importer(graph.flib_def(), debug_info, specs,
                                         module.get(), &tf_name_to_mlir_name,
                                         &function_name_uniquifier);

  TF_RETURN_IF_ERROR(importer.PrepareConvert(graph));

  // Import every function from the graph's function library.
  auto fn_names = graph.flib_def().ListFunctionNames();
  for (const auto& fn_name : fn_names) {
    TF_RETURN_IF_ERROR(importer.ConvertLibFunction(fn_name));
  }
  TF_RETURN_IF_ERROR(importer.ConvertDeferredFunctions());

  // An object graph is what distinguishes a TF2 SavedModel; without one this
  // importer cannot proceed.
  if (!saved_model->meta_graph_def().has_object_graph_def()) {
    return errors::InvalidArgument(
        "SavedModel does not have an object graph. Please use TF2.");
  }
  auto& object_graph = saved_model->meta_graph_def().object_graph_def();
  ObjectNames object_names(object_graph, exported_names);

  // Clean up a couple of functions that always seem to be present when
  // importing a SavedModel. This is not strictly needed, as there is a
  // separate pass that will clean them up, but this makes staring at the
  // raw IR of minimal examples quite a bit nicer.
  for (auto func : llvm::make_early_inc_range(module->getOps<mlir::FuncOp>())) {
    if (func.getName().startswith("__inference__traced_save_") ||
        func.getName().startswith("__inference__traced_restore_") ||
        func.getName().startswith("__inference_signature_wrapper_")) {
      func.erase();
    }
  }

  // Diagnose SavedFunction's with multiple input signatures.
  TF_RETURN_IF_ERROR(
      DiagnoseMultipleConcreteFunctions(object_graph, object_names));

  // Construct the SavedModel IR.
  TF_RETURN_IF_ERROR(CreateSavedModelIR(object_names, module.get(),
                                        object_graph, tf_name_to_mlir_name,
                                        saved_model));
  // Debug-only consistency check: `verify` runs only when asserts are enabled.
  assert(mlir::succeeded(mlir::verify(module.get())));

  return module;
}
3482 
// A SavedModelMLIRImportInput backed by a single in-memory Graph built once
// from the MetaGraphDef. Every GetSubGraph() request returns that same whole
// graph; callers are expected to prune it via GraphImportConfig.
class SimpleSavedModelMLIRImportInput : public SavedModelMLIRImportInput {
 public:
  // Builds the input from `meta_graph_def`: optionally runs grappler
  // (best-effort), optionally applies legacy-graph upgrades, then converts the
  // GraphDef into a Graph.
  static StatusOr<SimpleSavedModelMLIRImportInput> Create(
      const MLIRImportOptions& import_options,
      const MetaGraphDef* meta_graph_def, const GraphDebugInfo& debug_info) {
    DCHECK(meta_graph_def);
    GraphDef graph_def;
    if (import_options.enable_grappler) {
      // Grappler is best-effort.
      auto statusor = RunGrappler(*meta_graph_def);
      if (statusor.ok()) {
        graph_def = std::move(statusor).ValueOrDie();
      } else {
        // If the grappler fails, use the original graph def.
        LOG(WARNING) << "SimpleSavedModelMLIRImportInput: grappler failed: "
                     << statusor.status();
        graph_def = meta_graph_def->graph_def();
      }
    } else {
      graph_def = meta_graph_def->graph_def();
    }

    auto graph = std::make_unique<Graph>(OpRegistry::Global());

    // Must run on the GraphDef, before graph construction below.
    if (import_options.upgrade_legacy) {
      TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(
          graph_def, graph->flib_def().default_registry()));
    }

    GraphConstructorOptions graph_ctor_options;
    graph_ctor_options.allow_internal_ops = true;
    graph_ctor_options.add_default_attributes = true;
    TF_RETURN_IF_ERROR(
        ConvertGraphDefToGraph(graph_ctor_options, graph_def, graph.get()));

    if (import_options.upgrade_legacy) {
      // TODO(jpienaar): Remove need to const_cast.
      TF_RETURN_IF_ERROR(UpgradeLegacyGraph(
          graph.get(),
          const_cast<FunctionLibraryDefinition*>(&graph->flib_def()),
          /*restrict_functionalization_to_tpu_nodes=*/false));
    }

    return SimpleSavedModelMLIRImportInput(meta_graph_def, debug_info,
                                           std::move(graph));
  }

  SimpleSavedModelMLIRImportInput(const MetaGraphDef* meta_graph_def,
                                  const GraphDebugInfo& debug_info,
                                  std::unique_ptr<Graph> graph)
      : SavedModelMLIRImportInput(meta_graph_def, debug_info),
        graph_(std::move(graph)) {}

  // Always returns the full graph; the DCHECKs below (debug builds only)
  // validate that `name` and the feeds/fetches in `specs` make sense for it.
  StatusOr<const Graph*> GetSubGraph(absl::string_view name,
                                     const GraphImportConfig& specs) override {
    DCHECK(CheckGraphNameValidity(name));
    DCHECK(CheckGraphContainsFeedsAndFetches(specs));
    return graph_.get();
  }

 private:
  // Returns true iff every feed, fetch and control output named in `specs`
  // corresponds to an op node present in `graph_`.
  bool CheckGraphContainsFeedsAndFetches(const GraphImportConfig& specs) const {
    absl::flat_hash_set<std::string> feed_fetch_nodes;
    for (const auto& iter : specs.inputs) {
      TensorId tensor_id = ParseTensorName(iter.first);
      feed_fetch_nodes.insert(std::string(tensor_id.node()));
    }
    for (const auto& output : llvm::concat<const std::string>(
             specs.outputs, specs.control_outputs)) {
      TensorId tensor_id = ParseTensorName(output);
      feed_fetch_nodes.insert(std::string(tensor_id.node()));
    }

    // Any name left in the set after this loop has no matching node.
    for (Node* node : graph_->op_nodes()) {
      feed_fetch_nodes.erase(node->name());
    }

    return feed_fetch_nodes.empty();
  }

  // Returns true iff `name` is one of the graph names this input can serve.
  bool CheckGraphNameValidity(absl::string_view name) const {
    // If it is one of the signature names, it is valid.
    const auto& signature_defs = meta_graph_def().signature_def();
    if (signature_defs.contains(std::string(name))) return true;

    // If it is the restore graph name, it is valid.
    if (meta_graph_def().has_saver_def() &&
        meta_graph_def().saver_def().restore_op_name() == name)
      return true;

    // If it is the init graph name, it is valid.
    std::string init_op_name;
    if (internal::GetInitOp("", meta_graph_def(), &init_op_name).ok()) {
      if (init_op_name == name) return true;
    }

    return false;
  }

  // `graph_` contains the entire graph in the original MetaGraphDef.
  std::unique_ptr<Graph> graph_;
};
3585 
GetOriginalTfFuncNamesFromGraphDef(const GraphDef & graph_def)3586 static absl::flat_hash_set<std::string> GetOriginalTfFuncNamesFromGraphDef(
3587     const GraphDef& graph_def) {
3588   absl::flat_hash_set<std::string> original_func_tf_names;
3589   for (const auto& function : graph_def.library().function()) {
3590     original_func_tf_names.insert(function.signature().name());
3591   }
3592   return original_func_tf_names;
3593 }
3594 
// A helper class to import a TensorFlow model expressed in SavedModel V1 into
// an MLIR Module in SavedModel dialect.
//
// TODO(b/179683149): Rename this class to avoid confusion with TFLite.
class SavedModelSignatureDefImporterLite {
 public:
  // Main entry point: converts all functions (specified by SignatureDefs) in
  // the given meta graph to an MLIR Module.
  //
  // `import_restore` is introduced to control whether the restore graph
  // is imported in e.g. SavedModelSignatureDefImporter. Ideally, we don't need
  // this option as the restore graph should always be imported.
  // However, right now, SavedModelSignatureDefImporter cannot handle the
  // restore graph correctly.
  //
  // TODO(chky): Remove import_restore once the restore graph is correctly
  // handled in SavedModelSignatureDefImporter.
  static StatusOr<mlir::OwningModuleRef> Convert(
      SavedModelMLIRImportInput& input,
      absl::optional<absl::Span<const std::string>> exported_names,
      mlir::MLIRContext* context, bool import_restore = true) {
    SavedModelSignatureDefImporterLite importer(input, exported_names, context,
                                                import_restore);
    return importer.ConvertSignatures();
  }

 private:
  SavedModelSignatureDefImporterLite(
      SavedModelMLIRImportInput& input,
      absl::optional<absl::Span<const std::string>> exported_names,
      mlir::MLIRContext* context, bool import_restore)
      : input_(input),
        original_func_tf_names_(GetOriginalTfFuncNamesFromGraphDef(
            input.meta_graph_def().graph_def())),
        exported_names_(exported_names),
        module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))),
        symbol_table_(module_.get()),
        import_restore_(import_restore) {}

  // Converts the SavedModel to the SavedModel dialect. Creates an MLIR function
  // for each signature.
  StatusOr<mlir::OwningModuleRef> ConvertSignatures();
  // Converts a single SignatureDef (keyed by `sig_def_key`) into an exported
  // MLIR function annotated with tf_saved_model attributes.
  Status ConvertSignature(const std::string& sig_def_key,
                          const SignatureDef& signature_def);

  // Pairs an asset's feed tensor name with the AssetOp created for it.
  struct AssetInfo {
    std::string tensor_name;
    mlir::tf_saved_model::AssetOp op;
  };
  // Creates one AssetOp per asset file referenced by the MetaGraphDef.
  StatusOr<std::vector<AssetInfo>> ConvertAssets();
  // Converts the initialization graph in the SavedModel to an MLIR function.
  Status ConvertInitializer(const std::string& target_node_name,
                            const std::vector<AssetInfo>& assets);

  // Converts a graph with feeds and fetches to an MLIR function.
  StatusOr<mlir::OwningModuleRef> ConvertGraph(
      const std::string& name,
      const std::vector<std::pair<std::string, TensorInfo>>& inputs,
      const std::vector<std::pair<std::string, TensorInfo>>& outputs,
      const std::vector<std::string> control_outputs,
      std::unordered_map<std::string, std::string>& tf_name_to_mlir_name);

  // Moves the functions in `sub_module` to `module_` and skips the duplicate
  // functions.
  Status MoveConvertedFunctionsToModule(
      absl::string_view name, mlir::ModuleOp sub_module,
      const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name);

  // Translates signature (name, TensorInfo) input pairs into the feed
  // specification consumed by the graph importer.
  StatusOr<GraphImportConfig::InputArrays> ParseInputArrays(
      llvm::ArrayRef<std::pair<std::string, TensorInfo>> inputs);

 private:
  // Source of the MetaGraphDef, debug info and (sub-)graphs; not owned.
  SavedModelMLIRImportInput& input_;
  // TF names of the functions present in the original graphdef library.
  absl::flat_hash_set<std::string> original_func_tf_names_;
  // If set, only signatures whose keys are listed here are imported.
  absl::optional<absl::Span<const std::string>> exported_names_;
  // The module being built; all converted functions end up here.
  mlir::OwningModuleRef module_;
  mlir::SymbolTable symbol_table_;
  // Whether to import the restore (checkpoint-loading) graph.
  bool import_restore_ = true;
};
3674 
3675 StatusOr<std::vector<SavedModelSignatureDefImporterLite::AssetInfo>>
ConvertAssets()3676 SavedModelSignatureDefImporterLite::ConvertAssets() {
3677   std::vector<AssetFileDef> asset_file_defs;
3678   TF_RETURN_IF_ERROR(
3679       internal::GetAssetFileDefs(input_.meta_graph_def(), &asset_file_defs));
3680 
3681   std::vector<AssetInfo> results;
3682   results.reserve(asset_file_defs.size());
3683 
3684   mlir::OpBuilder builder(module_->getBodyRegion());
3685   unsigned i = 0;  // Use to generate unique sym_name(s) for duplicate assets.
3686   for (const auto& asset : asset_file_defs) {
3687     auto asset_op = builder.create<mlir::tf_saved_model::AssetOp>(
3688         module_->getLoc(),
3689         /*sym_name=*/
3690         builder.getStringAttr(
3691             absl::StrCat("__tf_saved_model_asset", i++, "_", asset.filename())),
3692         /*filename=*/
3693         builder.getStringAttr(
3694             io::JoinPath(kSavedModelAssetsDirectory, asset.filename())));
3695 
3696     results.push_back({asset.tensor_info().name(), asset_op});
3697   }
3698 
3699   return results;
3700 }
3701 
MoveConvertedFunctionsToModule(absl::string_view name,mlir::ModuleOp sub_module,const std::unordered_map<std::string,std::string> & tf_name_to_mlir_name)3702 Status SavedModelSignatureDefImporterLite::MoveConvertedFunctionsToModule(
3703     absl::string_view name, mlir::ModuleOp sub_module,
3704     const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name) {
3705   mlir::Builder builder(sub_module.getContext());
3706   mlir::SymbolTable sub_module_symbol_table(sub_module);
3707 
3708   // Functions originally from graphdef library might have a different name
3709   // after conversion, we build the set of the converted names
3710   absl::flat_hash_set<std::string> original_func_mlir_names;
3711   for (const auto& kv : tf_name_to_mlir_name) {
3712     if (original_func_tf_names_.contains(kv.first))
3713       original_func_mlir_names.insert(kv.second);
3714   }
3715 
3716   // Prefix private functions with the unique signature name, so that it cannot
3717   // collide with private functions used in the other signatures.
3718   for (auto func : sub_module.getOps<mlir::FuncOp>()) {
3719     if (mlir::tf_saved_model::IsExported(func)) continue;
3720 
3721     // Skip the original functions from graphdef library
3722     if (original_func_mlir_names.count(func.sym_name().str())) continue;
3723 
3724     std::string new_sym_name = absl::StrCat(name, "/", func.sym_name().str());
3725     if (mlir::failed(sub_module_symbol_table.replaceAllSymbolUses(
3726             func, new_sym_name, sub_module)))
3727       return tensorflow::errors::InvalidArgument(absl::StrCat(
3728           "SavedModelSignatureDefImporterLite: failed to assign a unique "
3729           "name to the private function used in a signature: ",
3730           func.sym_name().str()));
3731 
3732     mlir::SymbolTable::setSymbolName(func, new_sym_name);
3733   }
3734 
3735   // Copy all functions used by this signature to the final MLIR module.
3736   for (auto func : sub_module.getOps<mlir::FuncOp>()) {
3737     // The insert here is a NO-OP if the function already exists.
3738     symbol_table_.insert(func.clone());
3739   }
3740 
3741   return Status::OK();
3742 }
3743 
ConvertInitializer(const std::string & target_node_name,const std::vector<AssetInfo> & assets)3744 Status SavedModelSignatureDefImporterLite::ConvertInitializer(
3745     const std::string& target_node_name, const std::vector<AssetInfo>& assets) {
3746   std::vector<std::pair<std::string, TensorInfo>> inputs;
3747   inputs.reserve(assets.size());
3748   for (const auto& asset : assets) {
3749     TensorInfo tensor_info;
3750     tensor_info.set_name(asset.tensor_name);
3751     tensor_info.set_dtype(DT_STRING);
3752     tensor_info.mutable_tensor_shape();
3753     inputs.push_back({asset.tensor_name, tensor_info});
3754   }
3755 
3756   std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
3757   TF_ASSIGN_OR_RETURN(auto sub_module,
3758                       ConvertGraph(target_node_name, inputs, {},
3759                                    {target_node_name}, tf_name_to_mlir_name));
3760 
3761   mlir::SymbolTable sub_symbol_table(*sub_module);
3762 
3763   auto init_func_op = sub_symbol_table.lookup<mlir::FuncOp>(target_node_name);
3764   init_func_op->removeAttr("tf.entry_function");
3765 
3766   mlir::OpBuilder builder(module_->getBodyRegion());
3767 
3768   // Bind asset inputs to asset ops.
3769   DCHECK_EQ(init_func_op.getNumArguments(), assets.size());
3770   for (const auto& iter : llvm::enumerate(assets)) {
3771     auto asset_op = iter.value().op;
3772     init_func_op.setArgAttr(iter.index(), "tf_saved_model.bound_input",
3773                             builder.getSymbolRefAttr(asset_op.getName()));
3774   }
3775 
3776   // Set the exported name of init function to an reserved name for
3777   // tf_saved_model.
3778   init_func_op->setAttr(
3779       "tf_saved_model.exported_names",
3780       builder.getStrArrayAttr({absl::StrCat(
3781           "__tf_saved_model_session_initializer_", target_node_name)}));
3782 
3783   // Move the converted functions to top level MLIR module.
3784   return MoveConvertedFunctionsToModule(target_node_name, *sub_module,
3785                                         tf_name_to_mlir_name);
3786 }
3787 
3788 StatusOr<mlir::OwningModuleRef>
ConvertGraph(const std::string & name,const std::vector<std::pair<std::string,TensorInfo>> & inputs,const std::vector<std::pair<std::string,TensorInfo>> & outputs,const std::vector<std::string> control_outputs,std::unordered_map<std::string,std::string> & tf_name_to_mlir_name)3789 SavedModelSignatureDefImporterLite::ConvertGraph(
3790     const std::string& name,
3791     const std::vector<std::pair<std::string, TensorInfo>>& inputs,
3792     const std::vector<std::pair<std::string, TensorInfo>>& outputs,
3793     const std::vector<std::string> control_outputs,
3794     std::unordered_map<std::string, std::string>& tf_name_to_mlir_name) {
3795   VLOG(1) << "Importing Signature: " << name;
3796 
3797   GraphImportConfig specs;
3798   specs.graph_func_name = name;
3799   specs.prune_unused_nodes = true;
3800   TF_ASSIGN_OR_RETURN(specs.inputs, ParseInputArrays(inputs));
3801   for (auto& output : outputs) specs.outputs.push_back(output.second.name());
3802   specs.control_outputs = control_outputs;
3803   specs.enable_shape_inference = false;
3804 
3805   TF_ASSIGN_OR_RETURN(const auto* subgraph, input_.GetSubGraph(name, specs));
3806 
3807   // Convert sub-graph to MLIR module.
3808   return GraphDefImporter::Convert(module_->getContext(), *subgraph,
3809                                    input_.debug_info(), subgraph->flib_def(),
3810                                    specs, tf_name_to_mlir_name);
3811 }
3812 
ConvertSignature(const std::string & sig_def_key,const SignatureDef & signature_def)3813 Status SavedModelSignatureDefImporterLite::ConvertSignature(
3814     const std::string& sig_def_key, const SignatureDef& signature_def) {
3815   // Create local vectors for the input and output and sort them to be
3816   // deterministic. We don't want anyone to really depend on the order, client
3817   // should lookup argument/result mapping by attribute name.
3818   // To avoid accidentally depending on the order we use an unintuitive sorting.
3819   std::vector<std::pair<std::string, TensorInfo>> inputs(
3820       signature_def.inputs().begin(), signature_def.inputs().end());
3821   llvm::sort(inputs, [](const auto& lhs, const auto& rhs) {
3822     return tensorflow::Fingerprint64(lhs.first) <
3823            tensorflow::Fingerprint64(rhs.first);
3824   });
3825   std::vector<std::pair<std::string, TensorInfo>> outputs(
3826       signature_def.outputs().begin(), signature_def.outputs().end());
3827   llvm::sort(outputs, [](const auto& lhs, const auto& rhs) {
3828     return tensorflow::Fingerprint64(lhs.first) <
3829            tensorflow::Fingerprint64(rhs.first);
3830   });
3831 
3832   std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
3833 
3834   // Convert sub-graph to MLIR module.
3835   TF_ASSIGN_OR_RETURN(
3836       auto sub_module,
3837       ConvertGraph(sig_def_key, inputs, outputs, {}, tf_name_to_mlir_name));
3838   mlir::OpBuilder builder(sub_module->getBodyRegion());
3839 
3840   // Find the FuncOp which corresponds to current SignatureDef.
3841   mlir::SymbolTable sub_symbol_table(*sub_module);
3842   auto func_op = sub_symbol_table.lookup<mlir::FuncOp>(sig_def_key);
3843   TF_RET_CHECK(func_op)
3844       << "Graphdef importer should have created a function named "
3845       << sig_def_key << ".";
3846 
3847   // Use unique SignatureDef key as exported name.
3848   func_op->setAttr("tf_saved_model.exported_names",
3849                    builder.getStrArrayAttr({sig_def_key}));
3850 
3851   // Transfer input and output parameter names to index_path attributes.
3852   for (auto input_and_idx : llvm::enumerate(inputs)) {
3853     func_op.setArgAttr(input_and_idx.index(), "tf_saved_model.index_path",
3854                        builder.getStrArrayAttr({input_and_idx.value().first}));
3855   }
3856   for (auto output_and_idx : llvm::enumerate(outputs)) {
3857     func_op.setResultAttr(
3858         output_and_idx.index(), "tf_saved_model.index_path",
3859         builder.getStrArrayAttr({output_and_idx.value().first}));
3860   }
3861 
3862   // Move the converted functions to top level MLIR module.
3863   return MoveConvertedFunctionsToModule(sig_def_key, *sub_module,
3864                                         tf_name_to_mlir_name);
3865 }
3866 
3867 StatusOr<GraphImportConfig::InputArrays>
ParseInputArrays(llvm::ArrayRef<std::pair<std::string,TensorInfo>> inputs)3868 SavedModelSignatureDefImporterLite::ParseInputArrays(
3869     llvm::ArrayRef<std::pair<std::string, TensorInfo>> inputs) {
3870   GraphImportConfig::InputArrays results;
3871   for (const auto& iter : inputs) {
3872     const auto& tensor_info = iter.second;
3873 
3874     // TODO(b/184675681): Support other encoding cases.
3875     //
3876     // TODO(b/184679394): Add unit test for this check.
3877     TF_RET_CHECK(tensor_info.encoding_case() == tensorflow::TensorInfo::kName)
3878         << "Only dense tensor is supported, but got encoding case "
3879         << tensor_info.encoding_case();
3880 
3881     VLOG(1) << "Importing Signature Input: input_name = " << iter.first
3882             << ", tensor_info = " << tensor_info.DebugString();
3883 
3884     ArrayInfo array_info;
3885     array_info.imported_dtype = tensor_info.dtype();
3886 
3887     if (tensor_info.has_tensor_shape()) {
3888       array_info.shape = tensor_info.tensor_shape();
3889     } else {
3890       // If there is no tensor shape in the tensor info, conservatively set
3891       // unknown_rank to true.
3892       array_info.shape.set_unknown_rank(true);
3893     }
3894 
3895     results.insert(std::pair<std::string, ArrayInfo>(tensor_info.name(),
3896                                                      std::move(array_info)));
3897   }
3898   return results;
3899 }
3900 
// Converts every user-visible SignatureDef (or only those in
// `exported_names_`, when set) into functions in `module_`, then appends the
// restore/init initializer functions and the SavedModel module attributes.
// Consumes and returns `module_`.
StatusOr<mlir::OwningModuleRef>
SavedModelSignatureDefImporterLite::ConvertSignatures() {
  LoadImporterDialects(*module_->getContext());

  const auto& signatures = input_.meta_graph_def().signature_def();
  PopulateTfVersions(module_.get(),
                     input_.meta_graph_def().graph_def().versions());

  // When `exported_names_` is absent every signature is imported; otherwise
  // only signatures whose key appears in the set are imported.
  llvm::DenseSet<llvm::StringRef> exported_name_set;
  bool import_all_signatures = !exported_names_.has_value();
  if (exported_names_.has_value()) {
    exported_name_set.insert(exported_names_->begin(), exported_names_->end());
  }

  for (const auto& key_and_signature_def : signatures) {
    const std::string& sig_def_key = key_and_signature_def.first;
    const SignatureDef& signature_def = key_and_signature_def.second;

    // It is safe to skip "__saved_model_init_op" since it is an internal
    // signature that is not user-accessible. This signature will be handled in
    // ConvertInitializer().
    if (sig_def_key == "__saved_model_init_op") {
      continue;
    }
    if (!import_all_signatures && exported_name_set.count(sig_def_key) == 0) {
      continue;
    }

    TF_RETURN_IF_ERROR(ConvertSignature(sig_def_key, signature_def));
  }

  TF_ASSIGN_OR_RETURN(auto assets, ConvertAssets());

  mlir::OpBuilder builder(module_->getBodyRegion());
  // Symbol references to the initializer functions, in execution order
  // (restore first, then the user init op), recorded in the
  // SessionInitializerOp below.
  llvm::SmallVector<mlir::Attribute, 2> init_sym_refs;

  if (import_restore_ && input_.meta_graph_def().has_saver_def()) {
    std::vector<AssetInfo> variable_and_assets;

    // Create an AssetOp for the variable checkpoint files. The relative
    // filename is used here.
    auto variable_filename_op = builder.create<mlir::tf_saved_model::AssetOp>(
        module_->getLoc(),
        /*sym_name=*/
        builder.getStringAttr("__tf_saved_model_variables"),
        /*filename=*/
        builder.getStringAttr(io::JoinPath(kSavedModelVariablesDirectory,
                                           kSavedModelVariablesFilename)));
    // The restore graph feeds the checkpoint filename through the saver's
    // filename tensor, so bind that tensor name to the AssetOp above; the
    // regular asset bindings follow.
    variable_and_assets.push_back(
        {input_.meta_graph_def().saver_def().filename_tensor_name(),
         variable_filename_op});
    variable_and_assets.insert(variable_and_assets.end(), assets.begin(),
                               assets.end());

    const auto& restore_op_name =
        input_.meta_graph_def().saver_def().restore_op_name();
    TF_RETURN_IF_ERROR(
        ConvertInitializer(restore_op_name, variable_and_assets));
    init_sym_refs.push_back(builder.getSymbolRefAttr(restore_op_name));
  }

  // Import the user-provided init op (e.g. the "__saved_model_init_op"
  // signature skipped above), if the meta graph declares one.
  std::string init_op_name;
  TF_RETURN_IF_ERROR(
      internal::GetInitOp("", input_.meta_graph_def(), &init_op_name));
  if (!init_op_name.empty()) {
    TF_RETURN_IF_ERROR(ConvertInitializer(init_op_name, assets));
    init_sym_refs.push_back(builder.getSymbolRefAttr(init_op_name));
  }

  builder.create<mlir::tf_saved_model::SessionInitializerOp>(
      module_->getLoc(), builder.getArrayAttr(init_sym_refs));

  // Marks the module as carrying tf_saved_model semantics so downstream
  // passes enforce the dialect's invariants.
  (*module_)->setAttr("tf_saved_model.semantics", builder.getUnitAttr());

  SortSavedModelModule(*module_);
  MarkSavedModelFunctionVisibility(*module_);

  return std::move(module_);
}
3980 
3981 // A helper class to import a TensorFlow model expressed in SavedModel V1 into
3982 // an MLIR Module in SavedModel dialect. In addition to importing the model, it
3983 // performs a few graph transformations, including:
3984 //  1) Convert read-only ref variables to resource variables
3985 //  2) Lift resource variables to global_tensors by using a TF session.
3986 class SavedModelSignatureDefImporter {
3987  public:
3988   // Main entry point: converts all functions (specified by SignatureDefs) in
3989   // the given meta graph to an MLIR Module.
Convert(const SavedModelBundle & bundle,absl::optional<absl::Span<const std::string>> exported_names,mlir::MLIRContext * context,tensorflow::MLIRImportOptions options,bool lift_varhandle_ops_to_args=true)3990   static StatusOr<mlir::OwningModuleRef> Convert(
3991       const SavedModelBundle& bundle,
3992       absl::optional<absl::Span<const std::string>> exported_names,
3993       mlir::MLIRContext* context, tensorflow::MLIRImportOptions options,
3994       bool lift_varhandle_ops_to_args = true) {
3995     // debug_info might not be loaded with loader_lite.
3996     GraphDebugInfo debug_info;
3997     if (bundle.debug_info != nullptr) debug_info = *bundle.debug_info;
3998 
3999     TF_ASSIGN_OR_RETURN(auto input,
4000                         SimpleSavedModelMLIRImportInput::Create(
4001                             options, &bundle.meta_graph_def, debug_info));
4002 
4003     TF_ASSIGN_OR_RETURN(auto module,
4004                         SavedModelSignatureDefImporterLite::Convert(
4005                             input, exported_names, context,
4006                             /*import_restore=*/false));
4007 
4008     mlir::OpBuilder builder(module->getContext());
4009     (*module)->setAttr("tf_saved_model.under_construction",
4010                        builder.getUnitAttr());
4011     TF_RETURN_IF_ERROR(
4012         LiftVariables(bundle, *module, lift_varhandle_ops_to_args));
4013     (*module)->removeAttr("tf_saved_model.under_construction");
4014 
4015     return module;
4016   }
4017 
4018  private:
4019   // Lifts the variables in `module`.
4020   static Status LiftVariables(const SavedModelBundle& bundle,
4021                               mlir::ModuleOp module,
4022                               bool lift_varhandle_ops_to_args);
4023 };
4024 
// Runs the pass pipeline that converts the imported executor-dialect module
// into functional form and replaces session-held variable state with
// tf_saved_model global tensors (or session-initialized variables). The pass
// order below is load-bearing: pruning and functionalization must precede the
// variable passes, which in turn consult the bundle's live session.
Status SavedModelSignatureDefImporter::LiftVariables(
    const SavedModelBundle& bundle, mlir::ModuleOp module,
    bool lift_varhandle_ops_to_args) {
  mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());

  mlir::PassManager pm(module.getContext());
  SetCrashReproducer(pm);
  // Drop unreachable ops, then lower tf_executor dialect to plain functions.
  pm.addNestedPass<mlir::FuncOp>(
      mlir::tf_executor::CreateTFExecutorGraphPruningPass());
  pm.addNestedPass<mlir::FuncOp>(
      mlir::CreateExecutorDialectToFunctionalConversionPass());
  pm.addPass(
      mlir::tf_saved_model::CreateRemoveVariablesInSessionInitializerPass());
  pm.addNestedPass<mlir::FuncOp>(
      mlir::TF::
          CreateConvertReadonlyReferenceVariablesToResourceVariablesPass());
  if (lift_varhandle_ops_to_args) {
    // Read variable values out of the session and promote VarHandleOps to
    // function arguments bound to global tensors.
    pm.addNestedPass<mlir::FuncOp>(
        mlir::tf_saved_model::CreateMarkInitializedVariablesPass(
            bundle.GetSession()));
    pm.addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass());
    pm.addPass(
        mlir::tf_saved_model::CreateLiftVariablesPass(bundle.GetSession()));
  } else {
    // Keep variables as VarHandleOps and initialize them from the session in
    // the session initializer function instead.
    pm.addPass(
        mlir::tf_saved_model::CreateInitializeVariablesInSessionInitializerPass(
            bundle.GetSession()));
  }
  pm.addNestedPass<mlir::FuncOp>(
      mlir::tf_saved_model::CreateDedupBoundInputBindingPass());
  if (mlir::failed(pm.run(module)))
    return diag_handler.Combine(errors::Internal("Failed to lift variables."));

  return Status::OK();
}
4060 
4061 }  // namespace
4062 
~SavedModelMLIRImportInput()4063 SavedModelMLIRImportInput::~SavedModelMLIRImportInput() {}
4064 
ConvertGraphdefToMlir(const GraphDef & graphdef,const GraphDebugInfo & debug_info,const GraphImportConfig & specs,mlir::MLIRContext * context,bool add_default_attributes)4065 StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(
4066     const GraphDef& graphdef, const GraphDebugInfo& debug_info,
4067     const GraphImportConfig& specs, mlir::MLIRContext* context,
4068     bool add_default_attributes) {
4069   GraphConstructorOptions options;
4070   options.allow_internal_ops = true;
4071   options.add_default_attributes = add_default_attributes;
4072   Graph graph(OpRegistry::Global());
4073 
4074   GraphDef preprocessed_graphdef(graphdef);
4075   if (add_default_attributes) {
4076     TF_RETURN_IF_ERROR(PreprocessGraphDef(&specs, &preprocessed_graphdef));
4077   }
4078   if (specs.upgrade_legacy) {
4079     TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(
4080         preprocessed_graphdef, graph.flib_def().default_registry()));
4081   }
4082   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
4083       options, std::move(preprocessed_graphdef), &graph));
4084   return ConvertGraphToMlir(graph, debug_info, graph.flib_def(), specs,
4085                             context);
4086 }
4087 
ConvertGraphToMlir(const Graph & graph,const GraphDebugInfo & debug_info,const FunctionLibraryDefinition & flib_def,const GraphImportConfig & specs,mlir::MLIRContext * context)4088 StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
4089     const Graph& graph, const GraphDebugInfo& debug_info,
4090     const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
4091     mlir::MLIRContext* context) {
4092   // TODO(jpienaar): Remove need to const_cast.
4093   if (specs.upgrade_legacy) {
4094     TF_RETURN_IF_ERROR(
4095         UpgradeLegacyGraph(const_cast<Graph*>(&graph),
4096                            const_cast<FunctionLibraryDefinition*>(&flib_def),
4097                            specs.restrict_functionalization_to_tpu_nodes));
4098   }
4099   std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
4100   return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs,
4101                                    tf_name_to_mlir_name);
4102 }
4103 
ConvertFunctionToMlir(const FunctionBody * fbody,const FunctionLibraryDefinition & flib_def,mlir::MLIRContext * context)4104 stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
4105     const FunctionBody* fbody, const FunctionLibraryDefinition& flib_def,
4106     mlir::MLIRContext* context) {
4107   tensorflow::GraphDebugInfo dummy_debug_info;
4108   tensorflow::GraphImportConfig specs;
4109   specs.graph_func_name = fbody->fdef.signature().name();
4110   specs.enable_shape_inference = false;
4111   specs.graph_as_function = true;
4112   for (const auto* control_ret_node : fbody->control_ret_nodes)
4113     specs.control_outputs.push_back(control_ret_node->name());
4114   std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
4115   return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info,
4116                                    flib_def, specs, tf_name_to_mlir_name);
4117 }
4118 
ConvertSavedModelToMlir(SavedModelV2Bundle * saved_model,mlir::MLIRContext * context,absl::Span<std::string> exported_names,bool add_default_attributes)4119 StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
4120     SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
4121     absl::Span<std::string> exported_names, bool add_default_attributes) {
4122   return SavedModelObjectGraphImporter::Convert(
4123       saved_model, exported_names, context, add_default_attributes);
4124 }
4125 
ConvertSavedModelV1ToMlir(const SavedModelBundle & saved_model,absl::Span<std::string> exported_names,mlir::MLIRContext * context,MLIRImportOptions options,bool lift_variables)4126 StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir(
4127     const SavedModelBundle& saved_model, absl::Span<std::string> exported_names,
4128     mlir::MLIRContext* context, MLIRImportOptions options,
4129     bool lift_variables) {
4130   absl::optional<absl::Span<const std::string>> optional_exported_names;
4131   // TODO(b/187062560): Change ConvertSavedModelV1ToMlir() to take an optional
4132   // `exported_names` so that it can be configured to import only restore/init
4133   // graphs.
4134   if (!exported_names.empty()) optional_exported_names = exported_names;
4135   return SavedModelSignatureDefImporter::Convert(
4136       saved_model, optional_exported_names, context, options, lift_variables);
4137 }
4138 
ConvertSavedModelV1ToMlirLite(const MetaGraphDef & meta_graph_def,const GraphDebugInfo & debug_info,absl::optional<absl::Span<const std::string>> exported_names,mlir::MLIRContext * context,MLIRImportOptions options)4139 StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlirLite(
4140     const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info,
4141     absl::optional<absl::Span<const std::string>> exported_names,
4142     mlir::MLIRContext* context, MLIRImportOptions options) {
4143   TF_ASSIGN_OR_RETURN(auto input, SimpleSavedModelMLIRImportInput::Create(
4144                                       options, &meta_graph_def, debug_info));
4145   return ConvertSavedModelV1ToMlirLite(input, exported_names, context);
4146 }
4147 
ConvertSavedModelV1ToMlirLite(SavedModelMLIRImportInput & input,absl::optional<absl::Span<const std::string>> exported_names,mlir::MLIRContext * context)4148 StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlirLite(
4149     SavedModelMLIRImportInput& input,
4150     absl::optional<absl::Span<const std::string>> exported_names,
4151     mlir::MLIRContext* context) {
4152   return SavedModelSignatureDefImporterLite::Convert(input, exported_names,
4153                                                      context);
4154 }
4155 
MlirModuleToString(mlir::ModuleOp module,mlir::OpPrintingFlags flags)4156 std::string MlirModuleToString(mlir::ModuleOp module,
4157                                mlir::OpPrintingFlags flags) {
4158   std::string txt_module;
4159   {
4160     llvm::raw_string_ostream os{txt_module};
4161     module.print(os, flags);
4162   }
4163   return txt_module;
4164 }
4165 
MlirModuleToString(mlir::ModuleOp module,bool show_debug_info)4166 std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) {
4167   mlir::OpPrintingFlags flags;
4168   if (show_debug_info) flags.enableDebugInfo();
4169   return MlirModuleToString(module, flags);
4170 }
4171 
4172 }  // namespace tensorflow
4173