• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
17 
18 #include <utility>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/container/inlined_vector.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/string_view.h"
25 #include "absl/types/optional.h"
26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/StringRef.h"
30 #include "llvm/Support/Casting.h"
31 #include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
32 #include "mlir/IR/Attributes.h"  // TF:llvm-project
33 #include "mlir/IR/Builders.h"  // TF:llvm-project
34 #include "mlir/IR/Function.h"  // TF:llvm-project
35 #include "mlir/IR/Identifier.h"  // TF:llvm-project
36 #include "mlir/IR/Location.h"  // TF:llvm-project
37 #include "mlir/IR/Module.h"  // TF:llvm-project
38 #include "mlir/IR/Operation.h"  // TF:llvm-project
39 #include "mlir/IR/Types.h"  // TF:llvm-project
40 #include "mlir/Pass/Pass.h"  // TF:llvm-project
41 #include "mlir/Pass/PassManager.h"  // TF:llvm-project
42 #include "mlir/Support/DebugStringHelper.h"  // TF:llvm-project
43 #include "mlir/Support/LogicalResult.h"  // TF:llvm-project
44 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
45 #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
46 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
48 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
49 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
50 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
51 #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
52 #include "tensorflow/compiler/xla/status_macros.h"
53 #include "tensorflow/core/framework/graph.pb.h"
54 #include "tensorflow/core/framework/graph_to_functiondef.h"
55 #include "tensorflow/core/framework/node_def.pb.h"
56 #include "tensorflow/core/framework/node_def_util.h"
57 #include "tensorflow/core/framework/op.h"
58 #include "tensorflow/core/framework/types.pb.h"
59 #include "tensorflow/core/framework/versions.pb.h"
60 #include "tensorflow/core/graph/algorithm.h"
61 #include "tensorflow/core/graph/graph.h"
62 #include "tensorflow/core/graph/tensor_id.h"
63 #include "tensorflow/core/lib/core/errors.h"
64 #include "tensorflow/core/lib/core/status.h"
65 
66 namespace tensorflow {
67 using llvm::dyn_cast;
68 using llvm::isa;
69 using mlir::BlockArgument;
70 using mlir::Dialect;
71 using mlir::Operation;
72 using mlir::OperationState;
73 using mlir::Value;
74 using stream_executor::port::StatusOr;
75 
76 namespace {
77 
// Common prefix of the FailedPrecondition messages reported when a function
// is not a single tf_executor.graph built from single-op islands.
constexpr char kInvalidExecutorGraphMsg[] =
    "Functions must be of a single Graph "
    "with single op Islands: ";
80 
// Returns true if `c` may appear in a TensorFlow node name. `first_char`
// selects the stricter rule for the leading character: only an ASCII letter,
// digit, '.' or '_' is accepted there. Later characters may additionally be
// '/' or '-'.
//
// Letters and digits are classified with explicit ASCII ranges rather than
// <cctype>. isalpha/isdigit have undefined behavior for negative char values
// (non-ASCII bytes where char is signed) and depend on the current locale.
bool IsLegalChar(char c, bool first_char) {
  const bool is_ascii_letter = (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
  const bool is_ascii_digit = c >= '0' && c <= '9';
  if (is_ascii_letter || is_ascii_digit || c == '.' || c == '_') return true;

  // First character of a node name can only be a letter, digit, dot or
  // underscore.
  if (first_char) return false;

  return c == '/' || c == '-';
}
96 
97 // Convert characters in name that are considered illegal in TensorFlow Node
98 // name to '.'.
LegalizeNodeName(llvm::StringRef name)99 std::string LegalizeNodeName(llvm::StringRef name) {
100   assert(!name.empty() && "expected non-empty name");
101 
102   std::string legalized_name;
103   bool first = true;
104   for (auto c : name) {
105     if (IsLegalChar(c, first)) {
106       legalized_name += c;
107     } else {
108       legalized_name += '.';
109     }
110     first = false;
111   }
112 
113   return legalized_name;
114 }
115 
116 // OpOrArgLocNameMapper that legalizes the returned name.
117 class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper {
118  private:
GetName(OpOrVal op_or_val)119   std::string GetName(OpOrVal op_or_val) override {
120     return LegalizeNodeName(OpOrArgLocNameMapper::GetName(op_or_val));
121   }
122 };
123 
124 // Checks functions in module are of single tf_executor.graph and each
125 // tf_executor.island in tf_executor.graph only has a single op.
HasSingleGraphSingleOpIslandsFunctions(mlir::ModuleOp module)126 Status HasSingleGraphSingleOpIslandsFunctions(mlir::ModuleOp module) {
127   Status status = Status::OK();
128   module.walk([&](mlir::FuncOp function) {
129     if (function.getBlocks().size() != 1) {
130       status = errors::FailedPrecondition(
131           kInvalidExecutorGraphMsg,
132           "only single block functions are supported.");
133       return mlir::WalkResult::interrupt();
134     }
135 
136     auto block = function.front().without_terminator();
137     auto graph = llvm::dyn_cast<mlir::tf_executor::GraphOp>(block.begin());
138     if (!graph) {
139       status = errors::FailedPrecondition(
140           kInvalidExecutorGraphMsg,
141           "first op in function is not a tf_executor.graph.");
142       return mlir::WalkResult::interrupt();
143     }
144 
145     if (!has_single_element(block)) {
146       status = errors::FailedPrecondition(
147           kInvalidExecutorGraphMsg,
148           "function does not only contain a single tf_executor.graph.");
149       return mlir::WalkResult::interrupt();
150     }
151 
152     for (Operation& op : graph.GetBody()) {
153       auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(op);
154       if (!island) continue;
155 
156       if (!island.WrapsSingleOp()) {
157         status = errors::FailedPrecondition(
158             kInvalidExecutorGraphMsg,
159             "tf_executor.island must perfectly wrap a single op.");
160         return mlir::WalkResult::interrupt();
161       }
162     }
163 
164     return mlir::WalkResult::advance();
165   });
166 
167   return status;
168 }
169 
170 // Finds first inner op if `op` is a tf_executor.island. Otherwise `op` is
171 // returned.
GetIslandInnerOpOrSelf(mlir::Operation * op)172 Operation* GetIslandInnerOpOrSelf(mlir::Operation* op) {
173   auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(op);
174   if (island) return &island.GetBody().front();
175   return op;
176 }
177 
178 // Stateful helper class to export a function into a Graph.
179 class Exporter {
180  public:
181   // Converts the given Module to a Graph. The given module should only contain
182   // one entry function, which is identified by name "main". This entry function
183   // is converted to the base of the graph graph. The rest of the functions are
184   // converted to the library functions in that graph.
185   static Status Convert(mlir::ModuleOp module, const GraphExportConfig& configs,
186                         std::unique_ptr<Graph>* graph,
187                         FunctionLibraryDefinition* flib_def,
188                         absl::flat_hash_set<Node*>* control_ret_nodes);
189 
190   // Converts a given FuncOp to a FunctionDef and adds it to the function
191   // definition library
192   static Status ConvertLibFunction(const GraphExportConfig& configs,
193                                    const Dialect* tf_dialect,
194                                    mlir::FuncOp function,
195                                    FunctionDefLibrary* flib);
196   // Converts the given FuncOp to a Graph. The arguments and returns of
197   // function are added to the graph with special op names kArgOp and kRetOp.
198   // Later on, this graph can be converted a function definition and added to
199   // another graph.
200   static StatusOr<std::unique_ptr<Graph>> Convert(
201       const GraphExportConfig& configs, const Dialect* tf_dialect,
202       mlir::FuncOp function, FunctionDefLibrary* flib,
203       absl::flat_hash_set<Node*>* control_ret_nodes);
204 
205  private:
Exporter(Graph * graph,const Dialect * tf_dialect)206   explicit Exporter(Graph* graph, const Dialect* tf_dialect)
207       : graph_(graph), tf_dialect_(tf_dialect) {}
208 
209   Status AddArgumentNode(BlockArgument arg, unsigned index,
210                          llvm::StringRef name);
211   Status AddFetchNode(mlir::FuncOp function, mlir::tf_executor::FetchOp fetch,
212                       llvm::ArrayRef<llvm::StringRef> names);
213   Status AddInstructionNode(Operation* inst);
214   Status AddEdge(Operation* inst);
215 
216   StatusOr<std::unique_ptr<NodeDef>> GetArgumentNode(BlockArgument arg,
217                                                      unsigned index,
218                                                      llvm::StringRef name);
219   StatusOr<std::unique_ptr<NodeDef>> GetReturnNode(mlir::FuncOp function,
220                                                    Value operand,
221                                                    unsigned index,
222                                                    llvm::StringRef name);
223   Status GetControlRetNodes(mlir::tf_executor::FetchOp fetch,
224                             absl::flat_hash_set<Node*>* control_ret_nodes);
225   // Adds one edge between src_node and dst_node. If it is not a control edge,
226   // an index is used to find out the right operand of the dst_node.
227   Status AddEdgeBetweenNodes(Value src, Node* dst_node, unsigned dst_index);
228 
229   Graph* graph_;
230   LegalizedOpOrValLocNameMapper op_to_name_;
231   absl::flat_hash_map<Operation*, Node*> nodes_;
232   llvm::DenseMap<BlockArgument, Node*> args_;
233   // One single return operation can return multiple results, and each of them
234   // will be converted to one node in the graph.
235   typedef absl::InlinedVector<Node*, 4> NodeVector;
236   absl::flat_hash_map<Operation*, NodeVector> returns_;
237   const mlir::Dialect* tf_dialect_;
238 };
239 
GetArgumentNode(BlockArgument arg,unsigned index,llvm::StringRef name)240 StatusOr<std::unique_ptr<NodeDef>> Exporter::GetArgumentNode(
241     BlockArgument arg, unsigned index, llvm::StringRef name) {
242   auto func = arg.getParentRegion()->getParentOfType<mlir::FuncOp>();
243 
244   auto node_def = absl::make_unique<NodeDef>();
245   if (!name.empty())
246     node_def->set_name(name.str());
247   else
248     node_def->set_name(
249         std::string(op_to_name_.GetUniqueName(func.getName().str())));
250 
251   node_def->set_op(FunctionLibraryDefinition::kArgOp);
252 
253   DataType dtype;
254   TF_RETURN_IF_ERROR(ConvertToDataType(
255       arg.getType().cast<mlir::TensorType>().getElementType(), &dtype));
256   AttrValue type_attr;
257   type_attr.set_type(dtype);
258   (*node_def->mutable_attr())["T"] = type_attr;
259 
260   AttrValue index_attr;
261   index_attr.set_i(index);
262   (*node_def->mutable_attr())["index"] = index_attr;
263 
264   if (auto device_attr =
265           func.getArgAttrOfType<mlir::StringAttr>(index, "tf.device")) {
266     *node_def->mutable_device() = device_attr.getValue().str();
267   }
268 
269   if (auto resource_arg_unique_id_attr =
270           func.getArgAttrOfType<mlir::IntegerAttr>(
271               index, "tf.resource_arg_unique_id")) {
272     AttrValue unique_id_attr;
273     unique_id_attr.set_i(resource_arg_unique_id_attr.getInt());
274     (*node_def->mutable_attr())["_resource_arg_unique_id"] = unique_id_attr;
275   }
276 
277   return node_def;
278 }
279 
GetReturnNode(mlir::FuncOp function,Value operand,unsigned index,llvm::StringRef name)280 StatusOr<std::unique_ptr<NodeDef>> Exporter::GetReturnNode(
281     mlir::FuncOp function, Value operand, unsigned index,
282     llvm::StringRef name) {
283   auto node_def = absl::make_unique<NodeDef>();
284   if (!name.empty())
285     node_def->set_name(name.str());
286   else
287     node_def->set_name(
288         std::string(op_to_name_.GetUniqueName(function.getName().str())));
289 
290   node_def->set_op(FunctionLibraryDefinition::kRetOp);
291   DataType dtype;
292   TF_RETURN_IF_ERROR(ConvertToDataType(
293       operand.getType().cast<mlir::TensorType>().getElementType(), &dtype));
294   AttrValue type_attr;
295   type_attr.set_type(dtype);
296   (*node_def->mutable_attr())["T"] = type_attr;
297   AttrValue index_attr;
298   index_attr.set_i(index);
299   (*node_def->mutable_attr())["index"] = index_attr;
300   return node_def;
301 }
302 
AddEdgeBetweenNodes(Value src,Node * dst_node,unsigned dst_index)303 Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node,
304                                      unsigned dst_index) {
305   if (auto input_result = src.dyn_cast<mlir::OpResult>()) {
306     auto* input_inst = GetIslandInnerOpOrSelf(input_result.getOwner());
307     // Replaces the input node with NextIteration sink if it is a NextIteration
308     // source.
309     if (auto next_iter_source =
310             llvm::dyn_cast<mlir::tf_executor::NextIterationSourceOp>(
311                 input_inst))
312       input_inst = next_iter_source.GetSink();
313 
314     auto node_it = nodes_.find(input_inst);
315     TF_RET_CHECK(node_it != nodes_.end())
316         << "Use of OpResult encountered before def!";
317     if (input_result.getType().isa<mlir::tf_executor::ControlType>()) {
318       graph_->AddControlEdge(node_it->second, dst_node);
319     } else {
320       graph_->AddEdge(node_it->second, input_result.getResultNumber(), dst_node,
321                       dst_index);
322     }
323     return Status::OK();
324   }
325 
326   auto input_arg = src.cast<BlockArgument>();
327   auto input_node_it = args_.find(input_arg);
328   TF_RET_CHECK(input_node_it != args_.end())
329       << "Use of BlockArgument encounted before def!";
330   // For argument, there is only one result output, so the index is always 0.
331   graph_->AddEdge(input_node_it->second, 0, dst_node, dst_index);
332   return Status::OK();
333 }
334 
AddEdge(Operation * inst)335 Status Exporter::AddEdge(Operation* inst) {
336   // For tf_executor.fetch, add only its data edges. Control edges are captured
337   // later.
338   if (auto fetch = llvm::dyn_cast<mlir::tf_executor::FetchOp>(inst)) {
339     for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) {
340       Value operand = operand_and_idx.value();
341       if (operand.getType().isa<mlir::tf_executor::ControlType>()) break;
342 
343       auto* dst_node = returns_[fetch][operand_and_idx.index()];
344       TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(operand, dst_node, 0));
345     }
346 
347     return Status::OK();
348   }
349 
350   // For tf_executor.NextIteration.Sink, skip its token operand and add data and
351   // control edges with their index offset by 1.
352   if (auto next_iter_sink =
353           llvm::dyn_cast<mlir::tf_executor::NextIterationSinkOp>(inst)) {
354     auto* dst_node = nodes_[inst];
355     TF_RETURN_IF_ERROR(
356         AddEdgeBetweenNodes(next_iter_sink.input(), dst_node, 0));
357     for (auto control_and_idx : llvm::enumerate(next_iter_sink.controlInputs()))
358       TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(control_and_idx.value(), dst_node,
359                                              control_and_idx.index() + 1));
360 
361     return Status::OK();
362   }
363 
364   // For tf_executor.NextIteration.Source, op can be skipped as it is assumed
365   // there are no operands.
366   if (llvm::isa<mlir::tf_executor::NextIterationSourceOp>(inst)) {
367     assert(inst->getNumOperands() == 0);
368     return Status::OK();
369   }
370 
371   Operation* op = GetIslandInnerOpOrSelf(inst);
372   auto* dst_node = nodes_[op];
373   int operand_offset = 0;
374   // For tf_executor.island, add data edges from its wrapped op before control
375   // edges.
376   if (auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(inst)) {
377     for (auto operand_and_idx : llvm::enumerate(op->getOperands()))
378       TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(operand_and_idx.value(), dst_node,
379                                              operand_and_idx.index()));
380 
381     operand_offset = op->getNumOperands();
382   }
383 
384   // For all other ops (including tf_executor.island), add remaining edges.
385   for (auto operand_and_idx : llvm::enumerate(inst->getOperands()))
386     TF_RETURN_IF_ERROR(
387         AddEdgeBetweenNodes(operand_and_idx.value(), dst_node,
388                             operand_and_idx.index() + operand_offset));
389 
390   return Status::OK();
391 }
392 
AddInstructionNode(Operation * inst)393 Status Exporter::AddInstructionNode(Operation* inst) {
394   std::unique_ptr<NodeDef> node_def;
395   auto name = op_to_name_.GetUniqueName(inst);
396   // Convert registered TF ops to NodeDef. Only registered ops are handled to
397   // ensure that PopulateDerivedAttrs adds the correct attributes.
398   TF_ASSIGN_OR_RETURN(node_def,
399                       ConvertTFDialectOpToNodeDef(
400                           inst, name, /*ignore_unregistered_attrs=*/false));
401 
402   Status status;
403   Node* node = graph_->AddNode(*node_def, &status);
404   TF_RETURN_IF_ERROR(status);
405   DCHECK(node != nullptr);
406   nodes_[inst] = node;
407   return Status::OK();
408 }
409 
IsEntryFunctionArg(BlockArgument arg)410 bool IsEntryFunctionArg(BlockArgument arg) {
411   return arg.getParentRegion()->getParentOfType<mlir::FuncOp>().getName() ==
412          "main";
413 }
414 
415 // Creates argument nodes from Block argument. If a name is supplied, that
416 // name will be used instead of generating a unique name.
AddArgumentNode(BlockArgument arg,unsigned index,llvm::StringRef name)417 Status Exporter::AddArgumentNode(BlockArgument arg, unsigned index,
418                                  llvm::StringRef name) {
419   if (!IsEntryFunctionArg(arg) || !name.empty()) {
420     TF_ASSIGN_OR_RETURN(auto node_def, GetArgumentNode(arg, index, name));
421     Status status;
422     Node* node = graph_->AddNode(*node_def, &status);
423     TF_RETURN_IF_ERROR(status);
424     args_[arg] = node;
425     return status;
426   }
427 
428   // If it is an argument from the "main" function, it has only one user, which
429   // is an input node. We recover the original input node and skip adding the
430   // argument node. The new input node will be handled as normal in the
431   // following steps.
432   if (!arg.hasOneUse()) {
433     return errors::FailedPrecondition(
434         "Arg in 'main' should only have one user.");
435   }
436   auto* input = *arg.user_begin();
437   auto* parent = input->getParentOp();
438   auto island = llvm::dyn_cast_or_null<mlir::tf_executor::IslandOp>(parent);
439   if (!island)
440     return errors::FailedPrecondition(
441         "User of arg in 'main' must be in an inner op of a "
442         "tf_executor.island.");
443 
444   if (!island.control().use_empty())
445     return errors::FailedPrecondition(
446         "tf_executor.island of user of arg in 'main' must have no control "
447         "output users.");
448 
449   auto input_name = input->getName().getStringRef();
450   input_name.consume_back(".input");
451 
452   mlir::OpBuilder builder(island.getContext());
453   builder.setInsertionPointToStart(&island.GetBody());
454   auto loc = mlir::NameLoc::get(
455       builder.getIdentifier(op_to_name_.GetUniqueName(input)),
456       builder.getContext());
457   OperationState state(loc, input_name.str());
458   state.attributes.append(input->getAttrs().begin(), input->getAttrs().end());
459   for (auto op : input->getOperands()) {
460     // Skip the argument in the new operation.
461     if (op.isa<BlockArgument>()) continue;
462     state.operands.push_back(op);
463   }
464   state.types.append(input->getResultTypes().begin(),
465                      input->getResultTypes().end());
466   auto* inst = builder.createOperation(state);
467   // If it is one of the specified input names, then the new instruction should
468   // have the same name.
469   op_to_name_.InitOpName(inst, op_to_name_.GetUniqueName(input));
470   for (int index : llvm::seq<int>(0, input->getNumResults())) {
471     input->getResult(index).replaceAllUsesWith(inst->getResult(index));
472   }
473   input->dropAllReferences();
474   input->erase();
475   return Status::OK();
476 }
477 
478 // Creates return nodes per operand of a FetchOp. If names is supplied, those
479 // names will be used per node in order instead of generating a unique name.
AddFetchNode(mlir::FuncOp function,mlir::tf_executor::FetchOp fetch,llvm::ArrayRef<llvm::StringRef> names)480 Status Exporter::AddFetchNode(mlir::FuncOp function,
481                               mlir::tf_executor::FetchOp fetch,
482                               llvm::ArrayRef<llvm::StringRef> names) {
483   Status status;
484   auto& return_nodes = returns_[fetch];
485   for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) {
486     if (operand_and_idx.value().getType().isa<mlir::tf_executor::ControlType>())
487       break;
488 
489     TF_ASSIGN_OR_RETURN(
490         auto node_def,
491         GetReturnNode(function, operand_and_idx.value(),
492                       operand_and_idx.index(),
493                       names.empty() ? "" : names[operand_and_idx.index()]));
494     Node* node = graph_->AddNode(*node_def, &status);
495     TF_RETURN_IF_ERROR(status);
496     return_nodes.push_back(node);
497   }
498   return Status::OK();
499 }
500 
501 // Collects control ret Nodes based on tf_executor.graph's associated
502 // tf_executor.fetch control inputs.
GetControlRetNodes(mlir::tf_executor::FetchOp fetch,absl::flat_hash_set<Node * > * control_ret_nodes)503 Status Exporter::GetControlRetNodes(
504     mlir::tf_executor::FetchOp fetch,
505     absl::flat_hash_set<Node*>* control_ret_nodes) {
506   for (Value fetch_operand : fetch.getOperands()) {
507     if (fetch_operand.getType().isa<mlir::tf_executor::ControlType>()) {
508       Operation* defining_op =
509           GetIslandInnerOpOrSelf(fetch_operand.getDefiningOp());
510       auto node_it = nodes_.find(defining_op);
511       TF_RET_CHECK(node_it != nodes_.end());
512       control_ret_nodes->insert(node_it->second);
513     }
514   }
515   return Status::OK();
516 }
517 
Convert(const GraphExportConfig & configs,const Dialect * tf_dialect,mlir::FuncOp function,FunctionDefLibrary * flib,absl::flat_hash_set<Node * > * control_ret_nodes)518 StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
519     const GraphExportConfig& configs, const Dialect* tf_dialect,
520     mlir::FuncOp function, FunctionDefLibrary* flib,
521     absl::flat_hash_set<Node*>* control_ret_nodes) {
522   mlir::Block& block = function.front();
523 
524   // Determine if _Arg and _Retval nodes should use input and output names.
525   bool graph_as_function = false;
526 
527   // Extract input & output names if set.
528   llvm::SmallVector<llvm::StringRef, 2> input_names;
529   llvm::SmallVector<llvm::StringRef, 2> output_names;
530   auto dict_attr =
531       function.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
532   if (dict_attr) {
533     TF_RET_CHECK(dict_attr.get("inputs").isa<mlir::StringAttr>())
534         << "inputs missing in entry function attribute";
535     TF_RET_CHECK(dict_attr.get("outputs").isa<mlir::StringAttr>())
536         << "outputs missing in entry function attribute";
537     dict_attr.get("inputs").cast<mlir::StringAttr>().getValue().split(
538         input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
539     dict_attr.get("outputs").cast<mlir::StringAttr>().getValue().split(
540         output_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
541     graph_as_function = configs.graph_as_function;
542   }
543 
544   auto graph = absl::make_unique<Graph>(OpRegistry::Global());
545 
546   // Extract version info.
547   auto version_attr = function.getParentOfType<mlir::ModuleOp>()
548                           .getAttrOfType<mlir::DictionaryAttr>("tf.versions");
549   if (version_attr) {
550     VersionDef versions;
551     versions.set_producer(
552         version_attr.get("producer").cast<mlir::IntegerAttr>().getInt());
553     versions.set_min_consumer(
554         version_attr.get("min_consumer").cast<mlir::IntegerAttr>().getInt());
555     for (auto bad_consumer :
556          version_attr.get("bad_consumers").cast<mlir::ArrayAttr>()) {
557       versions.mutable_bad_consumers()->Add(
558           bad_consumer.cast<mlir::IntegerAttr>().getInt());
559     }
560     graph->set_versions(versions);
561   }
562 
563   // We have to add the function library here, so a custom operation, which is
564   // defined in the function library can be added to the graph.
565   TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
566   Exporter exporter(graph.get(), tf_dialect);
567 
568   auto graph_op = llvm::cast<mlir::tf_executor::GraphOp>(block.front());
569 
570   // Set input and output names and increment the use counter for them to help
571   // generate unique names.
572   if (!output_names.empty()) {
573     const int num_data_results = graph_op.getNumResults();
574     TF_RET_CHECK(output_names.size() == num_data_results)
575         << "output names (" << output_names.size()
576         << ") != terminator operands (" << num_data_results << ")";
577     llvm::DenseMap<Operation*, llvm::StringRef> output_op_to_name;
578     llvm::StringMap<Operation*> name_to_op;
579     for (auto it : llvm::enumerate(graph_op.GetFetch().getOperands())) {
580       // Skip control rets.
581       if (it.index() >= num_data_results) break;
582       // If there is a result index specified, ensure only one and that it
583       // matches the result index of the op.
584       auto result = it.value().cast<mlir::OpResult>();
585       std::string orig_name(output_names[it.index()]);
586       auto tensor_id = ParseTensorName(orig_name);
587       auto name = LegalizeNodeName(
588           llvm::StringRef(tensor_id.node().data(), tensor_id.node().size()));
589 
590       if (graph_as_function) {
591         // Ensure name does not get reused.
592         (void)exporter.op_to_name_.GetUniqueName(name);
593         continue;
594       }
595 
596       TF_RET_CHECK(result.getResultNumber() == tensor_id.index());
597       Operation* defining_op = GetIslandInnerOpOrSelf(result.getDefiningOp());
598       if (output_op_to_name.insert({defining_op, name}).second) {
599         TF_RET_CHECK(name_to_op.insert({name, defining_op}).second)
600             << "multiple operations associated with the same name";
601         exporter.op_to_name_.InitOpName(defining_op, name);
602       } else {
603         TF_RET_CHECK(output_op_to_name[defining_op] == name)
604             << "associating multiple names with the same op not supported";
605       }
606     }
607   }
608 
609   if (!input_names.empty()) {
610     TF_RET_CHECK(input_names.size() == block.getNumArguments());
611     for (auto it : llvm::enumerate(function.getArguments())) {
612       // TODO(lyandy): Update when changing feed/fetch import.
613       std::string orig_name(input_names[it.index()]);
614       std::string name = LegalizeNodeName(orig_name);
615       auto tensor_id = ParseTensorName(name);
616       TF_RET_CHECK(tensor_id.index() == 0)
617           << "input port designation not supported";
618       // Only assign user of argument the input name if the main graph did not
619       // have its _Arg nodes lifted into the functions arguments.
620       if (graph_as_function) {
621         // Ensure name does not get reused.
622         (void)exporter.op_to_name_.GetUniqueName(name);
623       } else {
624         Operation* defining_op =
625             GetIslandInnerOpOrSelf(*it.value().user_begin());
626         exporter.op_to_name_.InitOpName(defining_op, name);
627       }
628     }
629   }
630 
631   // Adds nodes for basic block (function) arguments.
632   for (auto it : llvm::enumerate(block.getArguments())) {
633     int index = it.index();
634     auto arg = it.value();
635     mlir::Type type = arg.getType();
636     if (!type.isa<mlir::TensorType>()) {
637       return errors::InvalidArgument(
638           "FuncOps arguments must have tensor types. Found ",
639           mlir::debugString(type), " in function ", function.getName().str());
640     }
641 
642     TF_RETURN_IF_ERROR(exporter.AddArgumentNode(
643         arg, index,
644         graph_as_function && !input_names.empty() ? input_names[index] : ""));
645   }
646 
647   auto convert_called_function = [&](llvm::StringRef name) {
648     auto func =
649         function.getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
650             name);
651     if (func != nullptr) {
652       TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, func, flib));
653       TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
654     }
655     return Status::OK();
656   };
657 
658   // Adds nodes for operations.
659   for (Operation& inst : graph_op.GetBody()) {
660     for (auto type : inst.getResultTypes())
661       if (!type.isa<mlir::TensorType>() &&
662           !type.isa<mlir::tf_executor::ControlType>() &&
663           !type.isa<mlir::tf_executor::TokenType>())
664         return errors::InvalidArgument(
665             "Values must be of tensor type, TensorFlow control type, or "
666             "TensorFlow token type. Found ",
667             mlir::debugString(type));
668 
669     if (llvm::isa<mlir::tf_executor::NextIterationSourceOp>(inst)) {
670       // Skip tf_executor.NextIteration.Source as associated
671       // tf_executor.NextIteration.Sink will be used instead.
672       continue;
673     } else if (auto fetch = llvm::dyn_cast<mlir::tf_executor::FetchOp>(inst)) {
674       TF_RETURN_IF_ERROR(exporter.AddFetchNode(
675           function, fetch,
676           graph_as_function ? output_names
677                             : llvm::ArrayRef<llvm::StringRef>()));
678     } else if (auto island =
679                    llvm::dyn_cast<mlir::tf_executor::IslandOp>(inst)) {
680       Operation& inner_op = island.GetBody().front();
681       auto op_name = GetTensorFlowOpName(inner_op.getName().getStringRef());
682       if (op_name.ok()) {
683         // If it is TF Control dialect specific op, look up custom operation
684         // in the module and first convert that, then add it to function
685         // definition library
686         // TODO(prakalps): If two functions have cyclic dependence, this will
687         // introduce an infinite loop.
688         TF_RETURN_IF_ERROR(convert_called_function(op_name.ValueOrDie().str()));
689       }
690 
691       if (IsLegacyCallInstruction(&inner_op)) {
692         TF_RETURN_IF_ERROR(convert_called_function(
693             inner_op.getAttrOfType<mlir::SymbolRefAttr>("f")
694                 .getLeafReference()));
695       }
696 
697       TF_RETURN_IF_ERROR(exporter.AddInstructionNode(&inner_op));
698     } else {
699       TF_RETURN_IF_ERROR(exporter.AddInstructionNode(&inst));
700     }
701   }
702   // Adds edges between the argument, operation and return nodes.
703   for (Operation& inst : graph_op.GetBody()) {
704     TF_RETURN_IF_ERROR(exporter.AddEdge(&inst));
705   }
706   // Fixes the edges between the inserted nodes and special "_SOURCE" and
707   // "_SINK".
708   FixupSourceAndSinkEdges(graph.get());
709 
710   TF_RETURN_IF_ERROR(
711       exporter.GetControlRetNodes(graph_op.GetFetch(), control_ret_nodes));
712 
713   return graph;
714 }
715 
ConvertLibFunction(const GraphExportConfig & configs,const Dialect * tf_dialect,mlir::FuncOp function,FunctionDefLibrary * flib)716 Status Exporter::ConvertLibFunction(const GraphExportConfig& configs,
717                                     const Dialect* tf_dialect,
718                                     mlir::FuncOp function,
719                                     FunctionDefLibrary* flib) {
720   // First look for the function in the current function library. If found,
721   // nothing needs to be done.
722   OpRegistry empty_registry;
723   FunctionLibraryDefinition flib_def(&empty_registry, *flib);
724   auto function_name = function.getName().str();
725   if (flib_def.Find(function_name)) return Status::OK();
726 
727   // TODO(fengliuai): use a small flib_def to reduce overhead
728   absl::flat_hash_set<Node*> control_ret_nodes;
729   TF_ASSIGN_OR_RETURN(auto sub_graph,
730                       Exporter::Convert(configs, tf_dialect, function, flib,
731                                         &control_ret_nodes));
732   const auto control_ret = [&](const Node* n) -> absl::optional<string> {
733     return control_ret_nodes.contains(n)
734                ? absl::make_optional<string>(n->name())
735                : absl::nullopt;
736   };
737   FunctionDef func_def;
738   TF_RETURN_IF_ERROR(
739       GraphToFunctionDef(*sub_graph, function_name, control_ret, &func_def));
740 
741   // The node defs in FunctionDef might contain debug info which was added
742   // by the GraphToFunctionDef method. We should remove it if we don't want
743   // to export them to avoid failing the roundtrip test.
744   if (!configs.export_debug_info) {
745     for (auto& node_def : *func_def.mutable_node_def()) {
746       node_def.clear_experimental_debug_info();
747     }
748   }
749 
750   // Checks for gradient attribute. If present converts the gradient function
751   // and populates the GradientDef.
752   auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
753   if (auto attr =
754           function.getAttrOfType<mlir::FlatSymbolRefAttr>(grad_string)) {
755     auto grad_func =
756         function.getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
757             attr.getValue());
758     TF_RETURN_IF_ERROR(
759         ConvertLibFunction(configs, tf_dialect, grad_func, flib));
760     GradientDef grad;
761     grad.set_function_name(function_name);
762     grad.set_gradient_func(grad_func.getName().str());
763     *flib->add_gradient() = grad;
764   }
765 
766   auto stateful_string = mlir::TF::TensorFlowDialect::GetStatefulAttrName();
767   if (auto attr = function.getAttrOfType<mlir::UnitAttr>(stateful_string)) {
768     func_def.mutable_signature()->set_is_stateful(true);
769   }
770   for (int64 i = 0; i < function.getNumArguments(); ++i) {
771     if (auto resource_arg_unique_id_attr =
772             function.getArgAttrOfType<mlir::IntegerAttr>(
773                 i, "tf.resource_arg_unique_id")) {
774       (*func_def.mutable_resource_arg_unique_id())[i] =
775           resource_arg_unique_id_attr.getInt();
776     }
777   }
778 
779   // Ignore the gradient and is_stateful attribute on the function as they have
780   // been handled above.
781   absl::flat_hash_set<absl::string_view> attrs_to_ignore = {
782       grad_string.data(), stateful_string.data()};
783   llvm::SmallVector<mlir::NamedAttribute, 8> funcAttrs(
784       function.getDialectAttrs());
785   TF_RETURN_IF_ERROR(
786       ConvertAttributes(funcAttrs, attrs_to_ignore, func_def.mutable_attr()));
787   (*flib->add_function()) = func_def;
788   return Status::OK();
789 }
790 
Convert(mlir::ModuleOp module,const GraphExportConfig & configs,std::unique_ptr<Graph> * graph,FunctionLibraryDefinition * flib_def,absl::flat_hash_set<Node * > * control_ret_nodes)791 Status Exporter::Convert(mlir::ModuleOp module,
792                          const GraphExportConfig& configs,
793                          std::unique_ptr<Graph>* graph,
794                          FunctionLibraryDefinition* flib_def,
795                          absl::flat_hash_set<Node*>* control_ret_nodes) {
796   mlir::Identifier entry_func_id =
797       mlir::Identifier::get("main", module.getContext());
798   absl::optional<mlir::FuncOp> entry_func;
799   FunctionDefLibrary flib;
800   auto tf_dialect = module.getContext()->getRegisteredDialect("tf");
801   for (auto function : module.getOps<mlir::FuncOp>()) {
802     if (function.isExternal())
803       return errors::FailedPrecondition("External functions not supported");
804 
805     if (function.getName() == entry_func_id) {
806       entry_func.emplace(function);
807     } else {
808       TF_RETURN_IF_ERROR(
809           ConvertLibFunction(configs, tf_dialect, function, &flib));
810     }
811   }
812 
813   if (!entry_func.has_value())
814     return errors::FailedPrecondition("entry function `main` must be present");
815 
816   // Updates the graph and the function library definition.
817   TF_ASSIGN_OR_RETURN(
818       *graph, Exporter::Convert(configs, tf_dialect, entry_func.value(), &flib,
819                                 control_ret_nodes));
820   for (auto& func_def : flib.function()) {
821     TF_RETURN_IF_ERROR(flib_def->AddFunctionDef(func_def));
822   }
823   for (auto& grad_def : flib.gradient()) {
824     TF_RETURN_IF_ERROR(flib_def->AddGradientDef(grad_def));
825   }
826   return Status::OK();
827 }
828 }  // namespace
829 
ConvertMlirToGraph(mlir::ModuleOp module,const GraphExportConfig & configs,std::unique_ptr<Graph> * graph,FunctionLibraryDefinition * flib_def,absl::flat_hash_set<Node * > * control_ret_nodes)830 Status ConvertMlirToGraph(mlir::ModuleOp module,
831                           const GraphExportConfig& configs,
832                           std::unique_ptr<Graph>* graph,
833                           FunctionLibraryDefinition* flib_def,
834                           absl::flat_hash_set<Node*>* control_ret_nodes) {
835   TF_RETURN_IF_ERROR(HasSingleGraphSingleOpIslandsFunctions(module));
836   return Exporter::Convert(module, configs, graph, flib_def, control_ret_nodes);
837 }
838 
ConvertMlirToGraph(mlir::ModuleOp module,const GraphExportConfig & configs,std::unique_ptr<Graph> * graph,FunctionLibraryDefinition * flib_def)839 Status ConvertMlirToGraph(mlir::ModuleOp module,
840                           const GraphExportConfig& configs,
841                           std::unique_ptr<Graph>* graph,
842                           FunctionLibraryDefinition* flib_def) {
843   absl::flat_hash_set<Node*> control_ret_nodes;
844   return ConvertMlirToGraph(module, configs, graph, flib_def,
845                             &control_ret_nodes);
846 }
847 
ConvertMlirToGraphdef(mlir::ModuleOp module,const GraphExportConfig & configs)848 StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
849     mlir::ModuleOp module, const GraphExportConfig& configs) {
850   FunctionLibraryDefinition flib_def(OpRegistry::Global(),
851                                      FunctionDefLibrary());
852   auto graph = absl::make_unique<Graph>(flib_def);
853   TF_RETURN_IF_ERROR(ConvertMlirToGraph(module, configs, &graph, &flib_def));
854   auto graphdef = absl::make_unique<GraphDef>();
855   graph->ToGraphDef(graphdef.get());
856   if (!configs.export_library) graphdef->clear_library();
857   if (!configs.export_shapes) {
858     for (auto& node_def : *graphdef->mutable_node()) {
859       node_def.mutable_attr()->erase("shape");
860     }
861   }
862   if (!configs.export_debug_info) {
863     for (auto& node_def : *graphdef->mutable_node()) {
864       node_def.clear_experimental_debug_info();
865     }
866   }
867   return graphdef;
868 }
869 
870 }  // namespace tensorflow
871