• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
17 
18 #include <utility>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/container/inlined_vector.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/string_view.h"
25 #include "absl/types/optional.h"
26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/StringRef.h"
30 #include "llvm/Support/Casting.h"
31 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
32 #include "mlir/IR/Attributes.h"  // from @llvm-project
33 #include "mlir/IR/Builders.h"  // from @llvm-project
34 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
35 #include "mlir/IR/Identifier.h"  // from @llvm-project
36 #include "mlir/IR/Location.h"  // from @llvm-project
37 #include "mlir/IR/Operation.h"  // from @llvm-project
38 #include "mlir/IR/Types.h"  // from @llvm-project
39 #include "mlir/Pass/Pass.h"  // from @llvm-project
40 #include "mlir/Pass/PassManager.h"  // from @llvm-project
41 #include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
42 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
43 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
45 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
46 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
47 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
48 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
49 #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
50 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
51 #include "tensorflow/compiler/mlir/utils/name_utils.h"
52 #include "tensorflow/compiler/xla/status_macros.h"
53 #include "tensorflow/core/framework/graph.pb.h"
54 #include "tensorflow/core/framework/graph_to_functiondef.h"
55 #include "tensorflow/core/framework/node_def.pb.h"
56 #include "tensorflow/core/framework/node_def_util.h"
57 #include "tensorflow/core/framework/op.h"
58 #include "tensorflow/core/framework/types.pb.h"
59 #include "tensorflow/core/framework/versions.pb.h"
60 #include "tensorflow/core/graph/algorithm.h"
61 #include "tensorflow/core/graph/graph.h"
62 #include "tensorflow/core/graph/tensor_id.h"
63 #include "tensorflow/core/lib/core/errors.h"
64 #include "tensorflow/core/lib/core/status.h"
65 
66 namespace tensorflow {
67 using llvm::dyn_cast;
68 using llvm::isa;
69 using mlir::BlockArgument;
70 using mlir::Dialect;
71 using mlir::Operation;
72 using mlir::Value;
73 using stream_executor::port::StatusOr;
74 
75 namespace {
76 
// Prefix shared by every error reporting that a function does not have the
// single tf_executor.graph / single-op island structure required for export.
constexpr char kInvalidExecutorGraphMsg[]{
    "Functions must be of a single Graph with single op Islands: "};

// Names of function argument/result attributes read by the exporter.
constexpr char kDeviceAttr[]{"tf.device"};
constexpr char kResourceArgUniqueIdAttr[]{"tf._resource_arg_unique_id"};
82 
83 // OpOrArgLocNameMapper that legalizes the returned name.
84 class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper {
85  private:
GetName(OpOrVal op_or_val)86   std::string GetName(OpOrVal op_or_val) override {
87     std::string name = OpOrArgLocNameMapper::GetName(op_or_val);
88     assert(!name.empty() && "expected non-empty name");
89     mlir::LegalizeNodeName(name);
90     return name;
91   }
92 };
93 
94 // Checks functions in module are of single tf_executor.graph and each
95 // tf_executor.island in tf_executor.graph only has a single op.
HasSingleGraphSingleOpIslandsFunctions(mlir::ModuleOp module)96 Status HasSingleGraphSingleOpIslandsFunctions(mlir::ModuleOp module) {
97   Status status = Status::OK();
98   module.walk([&](mlir::FuncOp function) {
99     if (!llvm::hasSingleElement(function)) {
100       status = errors::FailedPrecondition(
101           kInvalidExecutorGraphMsg,
102           "only single block functions are supported.");
103       return mlir::WalkResult::interrupt();
104     }
105 
106     auto block = function.front().without_terminator();
107     auto graph = llvm::dyn_cast<mlir::tf_executor::GraphOp>(block.begin());
108     if (!graph) {
109       status = errors::FailedPrecondition(
110           kInvalidExecutorGraphMsg,
111           "first op in function is not a tf_executor.graph.");
112       return mlir::WalkResult::interrupt();
113     }
114 
115     if (!hasSingleElement(block)) {
116       status = errors::FailedPrecondition(
117           kInvalidExecutorGraphMsg,
118           "function does not only contain a single tf_executor.graph.");
119       return mlir::WalkResult::interrupt();
120     }
121 
122     for (Operation& op : graph.GetBody()) {
123       auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(op);
124       if (!island) continue;
125 
126       if (!island.WrapsSingleOp()) {
127         status = errors::FailedPrecondition(
128             kInvalidExecutorGraphMsg,
129             "tf_executor.island must perfectly wrap a single op.");
130         return mlir::WalkResult::interrupt();
131       }
132     }
133 
134     return mlir::WalkResult::advance();
135   });
136 
137   return status;
138 }
139 
140 // Finds first inner op if `op` is a tf_executor.island. Otherwise `op` is
141 // returned.
GetIslandInnerOpOrSelf(mlir::Operation * op)142 Operation* GetIslandInnerOpOrSelf(mlir::Operation* op) {
143   auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(op);
144   if (island) return &island.GetBody().front();
145   return op;
146 }
147 
148 // Stateful helper class to export a function into a Graph.
149 class Exporter {
150  public:
151   // Converts the given Module to a Graph. The given module should only contain
152   // one entry function, which is identified by name "main". This entry function
153   // is converted to the base of the graph graph. The rest of the functions are
154   // converted to the library functions in that graph.
155   static Status Convert(mlir::ModuleOp module, const GraphExportConfig& configs,
156                         std::unique_ptr<Graph>* graph,
157                         FunctionLibraryDefinition* flib_def,
158                         absl::flat_hash_set<Node*>* control_ret_nodes);
159 
160   // Converts a given FuncOp to a FunctionDef and adds it to the function
161   // definition library
162   static Status ConvertLibFunction(const GraphExportConfig& configs,
163                                    const Dialect* tf_dialect,
164                                    mlir::FuncOp function,
165                                    FunctionDefLibrary* flib);
166   // Converts the given FuncOp to a Graph. The arguments and returns of
167   // function are added to the graph with special op names kArgOp and kRetOp.
168   // Later on, this graph can be converted a function definition and added to
169   // another graph.
170   static StatusOr<std::unique_ptr<Graph>> Convert(
171       const GraphExportConfig& configs, const Dialect* tf_dialect,
172       mlir::FuncOp function, FunctionDefLibrary* flib,
173       absl::flat_hash_set<Node*>* control_ret_nodes);
174 
175  private:
Exporter(Graph * graph,const Dialect * tf_dialect)176   explicit Exporter(Graph* graph, const Dialect* tf_dialect)
177       : graph_(graph), tf_dialect_(tf_dialect) {}
178 
179   Status AddArgumentNode(BlockArgument arg, unsigned index,
180                          llvm::StringRef name);
181   Status AddFetchNode(mlir::FuncOp function, mlir::tf_executor::FetchOp fetch,
182                       llvm::ArrayRef<llvm::StringRef> names);
183   Status AddInstructionNode(Operation* inst);
184   Status AddEdge(Operation* inst);
185 
186   StatusOr<std::unique_ptr<NodeDef>> GetArgumentNode(BlockArgument arg,
187                                                      unsigned index,
188                                                      llvm::StringRef name);
189   StatusOr<std::unique_ptr<NodeDef>> GetReturnNode(mlir::FuncOp function,
190                                                    Value operand,
191                                                    unsigned index,
192                                                    llvm::StringRef name);
193   Status GetControlRetNodes(mlir::tf_executor::FetchOp fetch,
194                             absl::flat_hash_set<Node*>* control_ret_nodes);
195   // Adds one edge between src_node and dst_node. If it is not a control edge,
196   // an index is used to find out the right operand of the dst_node.
197   Status AddEdgeBetweenNodes(Value src, Node* dst_node, unsigned dst_index);
198 
199   Graph* graph_;
200   LegalizedOpOrValLocNameMapper op_to_name_;
201   absl::flat_hash_map<Operation*, Node*> nodes_;
202   llvm::DenseMap<BlockArgument, Node*> args_;
203   // One single return operation can return multiple results, and each of them
204   // will be converted to one node in the graph.
205   typedef absl::InlinedVector<Node*, 4> NodeVector;
206   absl::flat_hash_map<Operation*, NodeVector> returns_;
207   const mlir::Dialect* tf_dialect_;
208 };
209 
GetArgumentNode(BlockArgument arg,unsigned index,llvm::StringRef name)210 StatusOr<std::unique_ptr<NodeDef>> Exporter::GetArgumentNode(
211     BlockArgument arg, unsigned index, llvm::StringRef name) {
212   auto func = arg.getParentRegion()->getParentOfType<mlir::FuncOp>();
213 
214   auto node_def = absl::make_unique<NodeDef>();
215   if (!name.empty())
216     node_def->set_name(name.str());
217   else
218     node_def->set_name(
219         std::string(op_to_name_.GetUniqueName(func.getName().str())));
220 
221   node_def->set_op(FunctionLibraryDefinition::kArgOp);
222 
223   TF_RETURN_IF_ERROR(SetShapeAttribute("_output_shapes",
224                                        arg.getType().cast<mlir::ShapedType>(),
225                                        node_def->mutable_attr()));
226 
227   DataType dtype;
228   TF_RETURN_IF_ERROR(ConvertToDataType(
229       arg.getType().cast<mlir::TensorType>().getElementType(), &dtype));
230   AttrValue type_attr;
231   type_attr.set_type(dtype);
232   (*node_def->mutable_attr())["T"] = type_attr;
233 
234   AttrValue index_attr;
235   index_attr.set_i(index);
236   (*node_def->mutable_attr())["index"] = index_attr;
237 
238   if (auto device_attr =
239           func.getArgAttrOfType<mlir::StringAttr>(index, kDeviceAttr))
240     *node_def->mutable_device() = device_attr.getValue().str();
241 
242   llvm::ArrayRef<mlir::NamedAttribute> func_arg_i_attrs =
243       func.getArgAttrs(index);
244   absl::flat_hash_set<absl::string_view> attrs_to_ignore = {kDeviceAttr};
245   TF_RETURN_IF_ERROR(ConvertAttributes(func_arg_i_attrs, attrs_to_ignore,
246                                        /*remove_ref_type=*/false,
247                                        node_def->mutable_attr()));
248 
249   return node_def;
250 }
251 
GetReturnNode(mlir::FuncOp function,Value operand,unsigned index,llvm::StringRef name)252 StatusOr<std::unique_ptr<NodeDef>> Exporter::GetReturnNode(
253     mlir::FuncOp function, Value operand, unsigned index,
254     llvm::StringRef name) {
255   auto node_def = absl::make_unique<NodeDef>();
256   if (!name.empty())
257     node_def->set_name(name.str());
258   else
259     node_def->set_name(
260         std::string(op_to_name_.GetUniqueName(function.getName().str())));
261 
262   node_def->set_op(FunctionLibraryDefinition::kRetOp);
263   DataType dtype;
264   TF_RETURN_IF_ERROR(ConvertToDataType(
265       operand.getType().cast<mlir::TensorType>().getElementType(), &dtype));
266   AttrValue type_attr;
267   type_attr.set_type(dtype);
268   (*node_def->mutable_attr())["T"] = type_attr;
269   AttrValue index_attr;
270   index_attr.set_i(index);
271   (*node_def->mutable_attr())["index"] = index_attr;
272 
273   if (auto device_attr =
274           function.getResultAttrOfType<mlir::StringAttr>(index, kDeviceAttr))
275     *node_def->mutable_device() = device_attr.getValue().str();
276 
277   llvm::ArrayRef<mlir::NamedAttribute> func_res_i_attrs =
278       function.getResultAttrs(index);
279   absl::flat_hash_set<absl::string_view> attrs_to_ignore = {kDeviceAttr};
280   TF_RETURN_IF_ERROR(ConvertAttributes(func_res_i_attrs, attrs_to_ignore,
281                                        /*remove_ref_type=*/false,
282                                        node_def->mutable_attr()));
283 
284   return node_def;
285 }
286 
AddEdgeBetweenNodes(Value src,Node * dst_node,unsigned dst_index)287 Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node,
288                                      unsigned dst_index) {
289   if (auto input_result = src.dyn_cast<mlir::OpResult>()) {
290     auto* input_inst = GetIslandInnerOpOrSelf(input_result.getOwner());
291     // Replaces the input node with NextIteration sink if it is a NextIteration
292     // source.
293     if (auto next_iter_source =
294             llvm::dyn_cast<mlir::tf_executor::NextIterationSourceOp>(
295                 input_inst))
296       input_inst = next_iter_source.GetSink();
297 
298     auto node_it = nodes_.find(input_inst);
299     TF_RET_CHECK(node_it != nodes_.end())
300         << "Use of OpResult encountered before def!";
301     if (input_result.getType().isa<mlir::tf_executor::ControlType>()) {
302       graph_->AddControlEdge(node_it->second, dst_node);
303     } else {
304       graph_->AddEdge(node_it->second, input_result.getResultNumber(), dst_node,
305                       dst_index);
306     }
307     return Status::OK();
308   }
309 
310   auto input_arg = src.cast<BlockArgument>();
311   auto input_node_it = args_.find(input_arg);
312   TF_RET_CHECK(input_node_it != args_.end())
313       << "Use of BlockArgument encounted before def!";
314   // For argument, there is only one result output, so the index is always 0.
315   graph_->AddEdge(input_node_it->second, 0, dst_node, dst_index);
316   return Status::OK();
317 }
318 
AddEdge(Operation * inst)319 Status Exporter::AddEdge(Operation* inst) {
320   // For tf_executor.fetch, add only its data edges. Control edges are captured
321   // later.
322   if (auto fetch = llvm::dyn_cast<mlir::tf_executor::FetchOp>(inst)) {
323     for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) {
324       Value operand = operand_and_idx.value();
325       if (operand.getType().isa<mlir::tf_executor::ControlType>()) break;
326 
327       auto* dst_node = returns_[fetch][operand_and_idx.index()];
328       TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(operand, dst_node, 0));
329     }
330 
331     return Status::OK();
332   }
333 
334   // For tf_executor.NextIteration.Sink, skip its token operand and add data and
335   // control edges with their index offset by 1.
336   if (auto next_iter_sink =
337           llvm::dyn_cast<mlir::tf_executor::NextIterationSinkOp>(inst)) {
338     auto* dst_node = nodes_[inst];
339     TF_RETURN_IF_ERROR(
340         AddEdgeBetweenNodes(next_iter_sink.input(), dst_node, 0));
341     for (auto control_and_idx : llvm::enumerate(next_iter_sink.controlInputs()))
342       TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(control_and_idx.value(), dst_node,
343                                              control_and_idx.index() + 1));
344 
345     return Status::OK();
346   }
347 
348   // For tf_executor.NextIteration.Source, op can be skipped as it is assumed
349   // there are no operands.
350   if (llvm::isa<mlir::tf_executor::NextIterationSourceOp>(inst)) {
351     assert(inst->getNumOperands() == 0);
352     return Status::OK();
353   }
354 
355   Operation* op = GetIslandInnerOpOrSelf(inst);
356   auto* dst_node = nodes_[op];
357   int operand_offset = 0;
358   // For tf_executor.island, add data edges from its wrapped op before control
359   // edges.
360   if (auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(inst)) {
361     for (auto operand_and_idx : llvm::enumerate(op->getOperands()))
362       TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(operand_and_idx.value(), dst_node,
363                                              operand_and_idx.index()));
364 
365     operand_offset = op->getNumOperands();
366   }
367 
368   // For all other ops (including tf_executor.island), add remaining edges.
369   for (auto operand_and_idx : llvm::enumerate(inst->getOperands()))
370     TF_RETURN_IF_ERROR(
371         AddEdgeBetweenNodes(operand_and_idx.value(), dst_node,
372                             operand_and_idx.index() + operand_offset));
373 
374   return Status::OK();
375 }
376 
AddInstructionNode(Operation * inst)377 Status Exporter::AddInstructionNode(Operation* inst) {
378   std::unique_ptr<NodeDef> node_def;
379   auto name = op_to_name_.GetUniqueName(inst);
380   // Convert registered TF ops to NodeDef. Only registered ops are handled to
381   // ensure that PopulateDerivedAttrs adds the correct attributes.
382   TF_ASSIGN_OR_RETURN(node_def,
383                       ConvertTFDialectOpToNodeDef(
384                           inst, name, /*ignore_unregistered_attrs=*/false));
385 
386   Status status;
387   Node* node = graph_->AddNode(*node_def, &status);
388   TF_RETURN_IF_ERROR(status);
389   DCHECK(node != nullptr);
390   nodes_[inst] = node;
391   return Status::OK();
392 }
393 
IsEntryFunctionArg(BlockArgument arg)394 bool IsEntryFunctionArg(BlockArgument arg) {
395   return arg.getParentRegion()->getParentOfType<mlir::FuncOp>().getName() ==
396          "main";
397 }
398 
399 // Creates argument nodes from Block argument. If a name is supplied, that
400 // name will be used instead of generating a unique name.
AddArgumentNode(BlockArgument arg,unsigned index,llvm::StringRef name)401 Status Exporter::AddArgumentNode(BlockArgument arg, unsigned index,
402                                  llvm::StringRef name) {
403   TF_ASSIGN_OR_RETURN(auto node_def, GetArgumentNode(arg, index, name));
404   Status status;
405   Node* node = graph_->AddNode(*node_def, &status);
406   TF_RETURN_IF_ERROR(status);
407   args_[arg] = node;
408   return Status::OK();
409 }
410 
411 // Creates return nodes per operand of a FetchOp. If names is supplied, those
412 // names will be used per node in order instead of generating a unique name.
AddFetchNode(mlir::FuncOp function,mlir::tf_executor::FetchOp fetch,llvm::ArrayRef<llvm::StringRef> names)413 Status Exporter::AddFetchNode(mlir::FuncOp function,
414                               mlir::tf_executor::FetchOp fetch,
415                               llvm::ArrayRef<llvm::StringRef> names) {
416   Status status;
417   auto& return_nodes = returns_[fetch];
418   for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) {
419     if (operand_and_idx.value().getType().isa<mlir::tf_executor::ControlType>())
420       break;
421 
422     TF_ASSIGN_OR_RETURN(
423         auto node_def,
424         GetReturnNode(function, operand_and_idx.value(),
425                       operand_and_idx.index(),
426                       names.empty() ? "" : names[operand_and_idx.index()]));
427     Node* node = graph_->AddNode(*node_def, &status);
428     TF_RETURN_IF_ERROR(status);
429     return_nodes.push_back(node);
430   }
431   return Status::OK();
432 }
433 
434 // Collects control ret Nodes based on tf_executor.graph's associated
435 // tf_executor.fetch control inputs.
GetControlRetNodes(mlir::tf_executor::FetchOp fetch,absl::flat_hash_set<Node * > * control_ret_nodes)436 Status Exporter::GetControlRetNodes(
437     mlir::tf_executor::FetchOp fetch,
438     absl::flat_hash_set<Node*>* control_ret_nodes) {
439   for (Value fetch_operand : fetch.getOperands()) {
440     if (fetch_operand.getType().isa<mlir::tf_executor::ControlType>()) {
441       Operation* defining_op =
442           GetIslandInnerOpOrSelf(fetch_operand.getDefiningOp());
443       auto node_it = nodes_.find(defining_op);
444       TF_RET_CHECK(node_it != nodes_.end());
445       control_ret_nodes->insert(node_it->second);
446     }
447   }
448   return Status::OK();
449 }
450 
Convert(const GraphExportConfig & configs,const Dialect * tf_dialect,mlir::FuncOp function,FunctionDefLibrary * flib,absl::flat_hash_set<Node * > * control_ret_nodes)451 StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
452     const GraphExportConfig& configs, const Dialect* tf_dialect,
453     mlir::FuncOp function, FunctionDefLibrary* flib,
454     absl::flat_hash_set<Node*>* control_ret_nodes) {
455   mlir::Block& block = function.front();
456 
457   // Extract input & output names if set.
458   llvm::SmallVector<llvm::StringRef, 2> input_names;
459   llvm::SmallVector<llvm::StringRef, 2> output_names;
460   auto dict_attr =
461       function->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
462   if (dict_attr) {
463     TF_RET_CHECK(dict_attr.get("inputs").isa<mlir::StringAttr>())
464         << "inputs missing in entry function attribute";
465     TF_RET_CHECK(dict_attr.get("outputs").isa<mlir::StringAttr>())
466         << "outputs missing in entry function attribute";
467     dict_attr.get("inputs").cast<mlir::StringAttr>().getValue().split(
468         input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
469     dict_attr.get("outputs").cast<mlir::StringAttr>().getValue().split(
470         output_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
471   }
472 
473   auto graph = absl::make_unique<Graph>(OpRegistry::Global());
474 
475   // Extract version info.
476   VersionDef versions;
477   auto module = function->getParentOfType<mlir::ModuleOp>();
478   if (mlir::succeeded(ExtractTfVersions(module, &versions))) {
479     graph->set_versions(versions);
480   }
481 
482   // We have to add the function library here, so a custom operation, which is
483   // defined in the function library can be added to the graph.
484   TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
485   Exporter exporter(graph.get(), tf_dialect);
486 
487   auto graph_op = llvm::cast<mlir::tf_executor::GraphOp>(block.front());
488 
489   // Set input and output names and increment the use counter for them to help
490   // generate unique names.
491   if (!output_names.empty()) {
492     const int num_data_results = graph_op.getNumResults();
493     const int64 output_names_size = output_names.size();
494     TF_RET_CHECK(output_names_size == num_data_results)
495         << "output names (" << output_names.size()
496         << ") != terminator operands (" << num_data_results << ")";
497     llvm::DenseMap<Operation*, llvm::StringRef> output_op_to_name;
498     llvm::StringMap<Operation*> name_to_op;
499     for (const auto& it : llvm::enumerate(graph_op.GetFetch().getOperands())) {
500       // Skip control rets.
501       const int64 index = it.index();
502       if (index >= num_data_results) break;
503       // TODO(jpienaar): If there is a result index specified, ensure only one
504       // and that it matches the result index of the op.
505       std::string name(output_names[index]);
506       auto tensor_id = ParseTensorName(name);
507       std::string tensor_id_node(tensor_id.node());
508       assert(!tensor_id_node.empty() && "expected non-empty name");
509       mlir::LegalizeNodeName(tensor_id_node);
510 
511       // Ensure name does not get reused.
512       (void)exporter.op_to_name_.GetUniqueName(tensor_id_node);
513     }
514   }
515 
516   if (!input_names.empty()) {
517     TF_RET_CHECK(input_names.size() == block.getNumArguments());
518     for (const auto& it : llvm::enumerate(function.getArguments())) {
519       // TODO(lyandy): Update when changing feed/fetch import.
520       std::string name(input_names[it.index()]);
521       assert(!name.empty() && "expected non-empty name");
522       mlir::LegalizeNodeName(name);
523       auto tensor_id = ParseTensorName(name);
524       TF_RET_CHECK(tensor_id.index() == 0)
525           << "input port designation not supported";
526       // Only assign user of argument the input name if the main graph did not
527       // have its _Arg nodes lifted into the functions arguments.
528       // Ensure name does not get reused.
529       (void)exporter.op_to_name_.GetUniqueName(name);
530     }
531   }
532 
533   // Adds nodes for basic block (function) arguments.
534   for (auto it : llvm::enumerate(block.getArguments())) {
535     int index = it.index();
536     auto arg = it.value();
537     mlir::Type type = arg.getType();
538     if (!type.isa<mlir::TensorType>()) {
539       return errors::InvalidArgument(
540           "FuncOps arguments must have tensor types. Found ",
541           mlir::debugString(type), " in function ", function.getName().str());
542     }
543 
544     TF_RETURN_IF_ERROR(exporter.AddArgumentNode(
545         arg, index, !input_names.empty() ? input_names[index] : ""));
546   }
547 
548   auto convert_called_function = [&](llvm::StringRef name) {
549     auto func =
550         function->getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
551             name);
552     if (func != nullptr) {
553       TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, func, flib));
554       TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
555     }
556     return Status::OK();
557   };
558 
559   // Adds nodes for operations.
560   for (Operation& inst : graph_op.GetBody()) {
561     for (auto type : inst.getResultTypes())
562       if (!type.isa<mlir::TensorType, mlir::tf_executor::ControlType,
563                     mlir::tf_executor::TokenType>())
564         return errors::InvalidArgument(
565             "Values must be of tensor type, TensorFlow control type, or "
566             "TensorFlow token type. Found ",
567             mlir::debugString(type));
568 
569     if (llvm::isa<mlir::tf_executor::NextIterationSourceOp>(inst)) {
570       // Skip tf_executor.NextIteration.Source as associated
571       // tf_executor.NextIteration.Sink will be used instead.
572       continue;
573     } else if (auto fetch = llvm::dyn_cast<mlir::tf_executor::FetchOp>(inst)) {
574       TF_RETURN_IF_ERROR(exporter.AddFetchNode(function, fetch, output_names));
575     } else if (auto island =
576                    llvm::dyn_cast<mlir::tf_executor::IslandOp>(inst)) {
577       Operation& inner_op = island.GetBody().front();
578       auto op_name = GetTensorFlowOpName(inner_op.getName().getStringRef());
579       if (op_name.ok()) {
580         // If it is TF Control dialect specific op, look up custom operation
581         // in the module and first convert that, then add it to function
582         // definition library
583         // TODO(prakalps): If two functions have cyclic dependence, this will
584         // introduce an infinite loop.
585         TF_RETURN_IF_ERROR(convert_called_function(op_name.ValueOrDie().str()));
586       }
587 
588       if (IsLegacyCallInstruction(&inner_op)) {
589         TF_RETURN_IF_ERROR(convert_called_function(
590             inner_op.getAttrOfType<mlir::SymbolRefAttr>("f")
591                 .getLeafReference()));
592       }
593 
594       TF_RETURN_IF_ERROR(exporter.AddInstructionNode(&inner_op));
595     } else {
596       TF_RETURN_IF_ERROR(exporter.AddInstructionNode(&inst));
597     }
598   }
599   // Adds edges between the argument, operation and return nodes.
600   for (Operation& inst : graph_op.GetBody()) {
601     TF_RETURN_IF_ERROR(exporter.AddEdge(&inst));
602   }
603   // Fixes the edges between the inserted nodes and special "_SOURCE" and
604   // "_SINK".
605   FixupSourceAndSinkEdges(graph.get());
606 
607   TF_RETURN_IF_ERROR(
608       exporter.GetControlRetNodes(graph_op.GetFetch(), control_ret_nodes));
609 
610   return graph;
611 }
612 
ConvertLibFunction(const GraphExportConfig & configs,const Dialect * tf_dialect,mlir::FuncOp function,FunctionDefLibrary * flib)613 Status Exporter::ConvertLibFunction(const GraphExportConfig& configs,
614                                     const Dialect* tf_dialect,
615                                     mlir::FuncOp function,
616                                     FunctionDefLibrary* flib) {
617   // First look for the function in the current function library. If found,
618   // nothing needs to be done.
619   OpRegistry empty_registry;
620   FunctionLibraryDefinition flib_def(&empty_registry, *flib);
621   auto function_name = function.getName().str();
622   if (flib_def.Find(function_name)) return Status::OK();
623 
624   // TODO(fengliuai): use a small flib_def to reduce overhead
625   absl::flat_hash_set<Node*> control_ret_nodes;
626   TF_ASSIGN_OR_RETURN(auto sub_graph,
627                       Exporter::Convert(configs, tf_dialect, function, flib,
628                                         &control_ret_nodes));
629   const auto control_ret = [&](const Node* n) -> absl::optional<string> {
630     return control_ret_nodes.contains(n)
631                ? absl::make_optional<string>(n->name())
632                : absl::nullopt;
633   };
634   FunctionDef func_def;
635   TF_RETURN_IF_ERROR(
636       GraphToFunctionDef(*sub_graph, function_name, control_ret, &func_def));
637 
638   // The node defs in FunctionDef might contain debug info which was added
639   // by the GraphToFunctionDef method. We should remove it if we don't want
640   // to export them to avoid failing the roundtrip test.
641   if (!configs.export_debug_info) {
642     for (auto& node_def : *func_def.mutable_node_def()) {
643       node_def.clear_experimental_debug_info();
644     }
645   }
646 
647   // Checks for gradient attribute. If present converts the gradient function
648   // and populates the GradientDef.
649   auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
650   if (auto attr =
651           function->getAttrOfType<mlir::FlatSymbolRefAttr>(grad_string)) {
652     auto grad_func =
653         function->getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
654             attr.getValue());
655     TF_RETURN_IF_ERROR(
656         ConvertLibFunction(configs, tf_dialect, grad_func, flib));
657     GradientDef grad;
658     grad.set_function_name(function_name);
659     grad.set_gradient_func(grad_func.getName().str());
660     *flib->add_gradient() = grad;
661   }
662 
663   auto stateful_string = mlir::TF::TensorFlowDialect::GetStatefulAttrName();
664   if (auto attr = function->getAttrOfType<mlir::UnitAttr>(stateful_string)) {
665     func_def.mutable_signature()->set_is_stateful(true);
666   }
667 
668   // Ignore the gradient and is_stateful attribute on the function as they have
669   // been handled above.
670   absl::flat_hash_set<absl::string_view> attrs_to_ignore = {
671       grad_string.data(), stateful_string.data()};
672   llvm::SmallVector<mlir::NamedAttribute, 8> funcAttrs(
673       function->getDialectAttrs());
674   TF_RETURN_IF_ERROR(ConvertAttributes(funcAttrs, attrs_to_ignore,
675                                        /*remove_ref_type=*/false,
676                                        func_def.mutable_attr()));
677 
678   for (int i = 0, e = function.getNumArguments(); i < e; ++i) {
679     if (auto resource_arg_unique_id_attr =
680             function.getArgAttrOfType<mlir::IntegerAttr>(
681                 i, kResourceArgUniqueIdAttr)) {
682       (*func_def.mutable_resource_arg_unique_id())[i] =
683           resource_arg_unique_id_attr.getInt();
684     }
685 
686     llvm::ArrayRef<mlir::NamedAttribute> func_arg_i_attrs =
687         function.getArgAttrs(i);
688     if (func_arg_i_attrs.empty()) continue;
689     absl::flat_hash_set<absl::string_view> attrs_to_ignore = {
690         kDeviceAttr, kResourceArgUniqueIdAttr};
691     FunctionDef::ArgAttrs func_def_arg_i_attrs;
692     TF_RETURN_IF_ERROR(ConvertAttributes(func_arg_i_attrs, attrs_to_ignore,
693                                          /*remove_ref_type=*/false,
694                                          func_def_arg_i_attrs.mutable_attr()));
695     if (func_def_arg_i_attrs.attr().empty()) continue;
696     (*func_def.mutable_arg_attr())[i] = std::move(func_def_arg_i_attrs);
697   }
698 
699   (*flib->add_function()) = std::move(func_def);
700   return Status::OK();
701 }
702 
Convert(mlir::ModuleOp module,const GraphExportConfig & configs,std::unique_ptr<Graph> * graph,FunctionLibraryDefinition * flib_def,absl::flat_hash_set<Node * > * control_ret_nodes)703 Status Exporter::Convert(mlir::ModuleOp module,
704                          const GraphExportConfig& configs,
705                          std::unique_ptr<Graph>* graph,
706                          FunctionLibraryDefinition* flib_def,
707                          absl::flat_hash_set<Node*>* control_ret_nodes) {
708   mlir::Identifier entry_func_id =
709       mlir::Identifier::get("main", module.getContext());
710   absl::optional<mlir::FuncOp> entry_func;
711   FunctionDefLibrary flib;
712   auto tf_dialect = module.getContext()->getLoadedDialect("tf");
713   for (auto function : module.getOps<mlir::FuncOp>()) {
714     if (function.isExternal())
715       return errors::FailedPrecondition("External functions not supported");
716 
717     if (function.getName() == entry_func_id) {
718       entry_func.emplace(function);
719     } else {
720       TF_RETURN_IF_ERROR(
721           ConvertLibFunction(configs, tf_dialect, function, &flib));
722     }
723   }
724 
725   if (!entry_func.has_value())
726     return errors::FailedPrecondition("entry function `main` must be present");
727 
728   // Updates the graph and the function library definition.
729   TF_ASSIGN_OR_RETURN(
730       *graph, Exporter::Convert(configs, tf_dialect, entry_func.value(), &flib,
731                                 control_ret_nodes));
732   for (auto& func_def : flib.function()) {
733     TF_RETURN_IF_ERROR(flib_def->AddFunctionDef(func_def));
734   }
735   for (auto& grad_def : flib.gradient()) {
736     TF_RETURN_IF_ERROR(flib_def->AddGradientDef(grad_def));
737   }
738   return Status::OK();
739 }
740 }  // namespace
741 
ConvertMlirToGraph(mlir::ModuleOp module,const GraphExportConfig & configs,std::unique_ptr<Graph> * graph,FunctionLibraryDefinition * flib_def,absl::flat_hash_set<Node * > * control_ret_nodes)742 Status ConvertMlirToGraph(mlir::ModuleOp module,
743                           const GraphExportConfig& configs,
744                           std::unique_ptr<Graph>* graph,
745                           FunctionLibraryDefinition* flib_def,
746                           absl::flat_hash_set<Node*>* control_ret_nodes) {
747   TF_RETURN_IF_ERROR(HasSingleGraphSingleOpIslandsFunctions(module));
748   return Exporter::Convert(module, configs, graph, flib_def, control_ret_nodes);
749 }
750 
ConvertMlirToGraph(mlir::ModuleOp module,const GraphExportConfig & configs,std::unique_ptr<Graph> * graph,FunctionLibraryDefinition * flib_def)751 Status ConvertMlirToGraph(mlir::ModuleOp module,
752                           const GraphExportConfig& configs,
753                           std::unique_ptr<Graph>* graph,
754                           FunctionLibraryDefinition* flib_def) {
755   absl::flat_hash_set<Node*> control_ret_nodes;
756   return ConvertMlirToGraph(module, configs, graph, flib_def,
757                             &control_ret_nodes);
758 }
759 
ConvertMlirToGraphdef(mlir::ModuleOp module,const GraphExportConfig & configs)760 StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
761     mlir::ModuleOp module, const GraphExportConfig& configs) {
762   FunctionLibraryDefinition flib_def(OpRegistry::Global(),
763                                      FunctionDefLibrary());
764   auto graph = absl::make_unique<Graph>(flib_def);
765   TF_RETURN_IF_ERROR(ConvertMlirToGraph(module, configs, &graph, &flib_def));
766   auto graphdef = absl::make_unique<GraphDef>();
767   graph->ToGraphDef(graphdef.get());
768   if (!configs.export_library) graphdef->clear_library();
769   if (!configs.export_shapes) {
770     for (auto& node_def : *graphdef->mutable_node()) {
771       node_def.mutable_attr()->erase("shape");
772     }
773   }
774   if (!configs.export_debug_info) {
775     for (auto& node_def : *graphdef->mutable_node()) {
776       node_def.clear_experimental_debug_info();
777     }
778   }
779   return graphdef;
780 }
781 
ConvertMlirFunctionToFunctionLibraryDef(mlir::FuncOp func,const GraphExportConfig & configs,FunctionDef * function_def)782 stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef(
783     mlir::FuncOp func, const GraphExportConfig& configs,
784     FunctionDef* function_def) {
785   Dialect* tf_dialect = func.getContext()->getLoadedDialect("tf");
786   FunctionDefLibrary flib;
787   TF_RETURN_IF_ERROR(
788       Exporter::ConvertLibFunction(configs, tf_dialect, func, &flib));
789   for (auto& func_def : flib.function()) {
790     if (func_def.signature().name() == func.getName()) {
791       *function_def = func_def;
792       return Status::OK();
793     }
794   }
795   return errors::InvalidArgument(
796       "Function couldn't be found in the FunctionDefLibrary after converting "
797       "from MLIR");
798 }
799 
800 }  // namespace tensorflow
801