• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/container/inlined_vector.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/string_view.h"
26 #include "absl/types/optional.h"
27 #include "llvm/ADT/ArrayRef.h"
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallVector.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/Support/Casting.h"
33 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
34 #include "mlir/IR/Attributes.h"  // from @llvm-project
35 #include "mlir/IR/Builders.h"  // from @llvm-project
36 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
37 #include "mlir/IR/Identifier.h"  // from @llvm-project
38 #include "mlir/IR/Location.h"  // from @llvm-project
39 #include "mlir/IR/Operation.h"  // from @llvm-project
40 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
41 #include "mlir/IR/Types.h"  // from @llvm-project
42 #include "mlir/Pass/Pass.h"  // from @llvm-project
43 #include "mlir/Pass/PassManager.h"  // from @llvm-project
44 #include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
45 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
46 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
48 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
49 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
50 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
51 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
52 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
53 #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
54 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
55 #include "tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h"
56 #include "tensorflow/compiler/mlir/utils/name_utils.h"
57 #include "tensorflow/compiler/xla/status_macros.h"
58 #include "tensorflow/core/framework/graph.pb.h"
59 #include "tensorflow/core/framework/graph_to_functiondef.h"
60 #include "tensorflow/core/framework/node_def.pb.h"
61 #include "tensorflow/core/framework/node_def_util.h"
62 #include "tensorflow/core/framework/op.h"
63 #include "tensorflow/core/framework/types.pb.h"
64 #include "tensorflow/core/framework/versions.pb.h"
65 #include "tensorflow/core/graph/algorithm.h"
66 #include "tensorflow/core/graph/graph.h"
67 #include "tensorflow/core/graph/tensor_id.h"
68 #include "tensorflow/core/lib/core/errors.h"
69 #include "tensorflow/core/lib/core/status.h"
70 
71 namespace tensorflow {
72 using llvm::dyn_cast;
73 using llvm::isa;
74 using mlir::BlockArgument;
75 using mlir::Dialect;
76 using mlir::FuncOp;
77 using mlir::Operation;
78 using mlir::SymbolTable;
79 using mlir::Value;
80 using stream_executor::port::StatusOr;
81 
82 namespace {
83 
// Dictionary attribute on a function whose "inputs"/"outputs" string entries
// hold comma-separated names for the exported argument and return nodes.
constexpr char kEntryFuncAttr[] = "tf.entry_function";
// Per-argument / per-result string attribute exported as the node's device.
constexpr char kDeviceAttr[] = "tf.device";
// Argument attribute that is skipped when converting argument attributes.
constexpr char kAliasingAttr[] = "tf.aliasing_output";
// Integer argument attribute copied into FunctionDef::resource_arg_unique_id.
constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
88 
89 // OpOrArgLocNameMapper that legalizes the returned name.
90 class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper {
91  private:
GetName(OpOrVal op_or_val)92   std::string GetName(OpOrVal op_or_val) override {
93     std::string name = OpOrArgLocNameMapper::GetName(op_or_val);
94     assert(!name.empty() && "expected non-empty name");
95     mlir::LegalizeNodeName(name);
96     return name;
97   }
98 };
99 
100 // Finds first inner op if `op` is a tf_executor.island. Otherwise `op` is
101 // returned.
GetIslandInnerOpOrSelf(mlir::Operation * op)102 Operation* GetIslandInnerOpOrSelf(mlir::Operation* op) {
103   auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(op);
104   if (island) return &island.GetBody().front();
105   return op;
106 }
107 
108 // Stateful helper class to export a function into a Graph.
109 class Exporter {
110  public:
111   // Converts the given Module to a Graph. The given module should only contain
112   // one entry function, which is identified by name "main". This entry function
113   // is converted to the base of the graph graph. The rest of the functions are
114   // converted to the library functions in that graph.
115   static Status Convert(mlir::ModuleOp module, const GraphExportConfig& configs,
116                         std::unique_ptr<Graph>* graph,
117                         FunctionLibraryDefinition* flib_def,
118                         absl::flat_hash_set<Node*>* control_ret_nodes);
119 
120   // Converts a given FuncOp to a FunctionDef and adds it to the function
121   // definition library
122   static Status ConvertLibFunction(
123       const GraphExportConfig& configs, const Dialect* tf_dialect,
124       const SymbolTable& symbol_table, FuncOp function,
125       FunctionDefLibrary* flib, llvm::SmallDenseSet<FuncOp>& visited_functions);
126 
127   // Converts the given FuncOp to a Graph. The arguments and returns of
128   // function are added to the graph with special op names kArgOp and kRetOp.
129   // Later on, this graph can be converted a function definition and added to
130   // another graph.
131   static StatusOr<std::unique_ptr<Graph>> Convert(
132       const GraphExportConfig& configs, const Dialect* tf_dialect,
133       const SymbolTable& symbol_table, FuncOp function,
134       FunctionDefLibrary* flib, llvm::SmallDenseSet<FuncOp>& visited_functions,
135       absl::flat_hash_set<Node*>* control_ret_nodes);
136 
137  private:
Exporter(Graph * graph,const Dialect * tf_dialect)138   explicit Exporter(Graph* graph, const Dialect* tf_dialect)
139       : graph_(graph), tf_dialect_(tf_dialect) {}
140 
141   Status AddArgumentNode(BlockArgument arg, unsigned index,
142                          llvm::StringRef name);
143   Status AddFetchNode(FuncOp function, mlir::tf_executor::FetchOp fetch,
144                       llvm::ArrayRef<llvm::StringRef> names);
145   Status AddInstructionNode(Operation* inst);
146   Status AddEdge(Operation* inst);
147 
148   StatusOr<std::unique_ptr<NodeDef>> GetArgumentNode(BlockArgument arg,
149                                                      unsigned index,
150                                                      llvm::StringRef name);
151   StatusOr<std::unique_ptr<NodeDef>> GetReturnNode(FuncOp function,
152                                                    Value operand,
153                                                    unsigned index,
154                                                    llvm::StringRef name);
155   Status GetControlRetNodes(mlir::tf_executor::FetchOp fetch,
156                             absl::flat_hash_set<Node*>* control_ret_nodes);
157   // Adds one edge between src_node and dst_node. If it is not a control edge,
158   // an index is used to find out the right operand of the dst_node.
159   Status AddEdgeBetweenNodes(Value src, Node* dst_node, unsigned dst_index);
160 
161   Graph* graph_;
162   LegalizedOpOrValLocNameMapper op_to_name_;
163   absl::flat_hash_map<Operation*, Node*> nodes_;
164   llvm::DenseMap<BlockArgument, Node*> args_;
165   // One single return operation can return multiple results, and each of them
166   // will be converted to one node in the graph.
167   typedef absl::InlinedVector<Node*, 4> NodeVector;
168   absl::flat_hash_map<Operation*, NodeVector> returns_;
169   const mlir::Dialect* tf_dialect_;
170 };
171 
GetArgumentNode(BlockArgument arg,unsigned index,llvm::StringRef name)172 StatusOr<std::unique_ptr<NodeDef>> Exporter::GetArgumentNode(
173     BlockArgument arg, unsigned index, llvm::StringRef name) {
174   auto func = arg.getParentRegion()->getParentOfType<FuncOp>();
175 
176   auto node_def = absl::make_unique<NodeDef>();
177   if (!name.empty())
178     node_def->set_name(name.str());
179   else
180     node_def->set_name(
181         std::string(op_to_name_.GetUniqueName(func.getName().str())));
182 
183   node_def->set_op(FunctionLibraryDefinition::kArgOp);
184 
185   mlir::TensorType arg_type = arg.getType().cast<mlir::TensorType>();
186   if (auto resource_type =
187           arg_type.getElementType().dyn_cast<mlir::TF::ResourceType>()) {
188     llvm::ArrayRef<mlir::TensorType> subtypes = resource_type.getSubtypes();
189     if (!subtypes.empty()) {
190       AttrValue handle_dtypes_attr;
191       AttrValue handle_shapes_attr;
192       for (mlir::TensorType subtype : subtypes) {
193         DataType dtype;
194         TF_RETURN_IF_ERROR(ConvertToDataType(subtype.getElementType(), &dtype));
195         handle_dtypes_attr.mutable_list()->add_type(dtype);
196 
197         SetTensorShapeProto(subtype,
198                             handle_shapes_attr.mutable_list()->add_shape());
199       }
200 
201       (*node_def->mutable_attr())["_handle_dtypes"] = handle_dtypes_attr;
202       (*node_def->mutable_attr())["_handle_shapes"] = handle_shapes_attr;
203     }
204   }
205 
206   TF_RETURN_IF_ERROR(
207       SetShapeAttribute("_output_shapes", arg_type, node_def->mutable_attr()));
208 
209   DataType dtype;
210   TF_RETURN_IF_ERROR(ConvertToDataType(arg_type.getElementType(), &dtype));
211   AttrValue type_attr;
212   type_attr.set_type(dtype);
213   (*node_def->mutable_attr())["T"] = type_attr;
214 
215   AttrValue index_attr;
216   index_attr.set_i(index);
217   (*node_def->mutable_attr())["index"] = index_attr;
218 
219   if (auto device_attr =
220           func.getArgAttrOfType<mlir::StringAttr>(index, kDeviceAttr))
221     *node_def->mutable_device() = device_attr.getValue().str();
222 
223   llvm::ArrayRef<mlir::NamedAttribute> func_arg_i_attrs =
224       func.getArgAttrs(index);
225   absl::flat_hash_set<absl::string_view> attrs_to_ignore = {kDeviceAttr,
226                                                             kAliasingAttr};
227   TF_RETURN_IF_ERROR(ConvertAttributes(func_arg_i_attrs, attrs_to_ignore,
228                                        /*remove_ref_type=*/false,
229                                        node_def->mutable_attr()));
230 
231   return node_def;
232 }
233 
GetReturnNode(FuncOp function,Value operand,unsigned index,llvm::StringRef name)234 StatusOr<std::unique_ptr<NodeDef>> Exporter::GetReturnNode(
235     FuncOp function, Value operand, unsigned index, llvm::StringRef name) {
236   auto node_def = absl::make_unique<NodeDef>();
237   if (!name.empty())
238     node_def->set_name(name.str());
239   else
240     node_def->set_name(
241         std::string(op_to_name_.GetUniqueName(function.getName().str())));
242 
243   node_def->set_op(FunctionLibraryDefinition::kRetOp);
244   DataType dtype;
245   TF_RETURN_IF_ERROR(ConvertToDataType(
246       operand.getType().cast<mlir::TensorType>().getElementType(), &dtype));
247   AttrValue type_attr;
248   type_attr.set_type(dtype);
249   (*node_def->mutable_attr())["T"] = type_attr;
250   AttrValue index_attr;
251   index_attr.set_i(index);
252   (*node_def->mutable_attr())["index"] = index_attr;
253 
254   if (auto device_attr =
255           function.getResultAttrOfType<mlir::StringAttr>(index, kDeviceAttr))
256     *node_def->mutable_device() = device_attr.getValue().str();
257 
258   llvm::ArrayRef<mlir::NamedAttribute> func_res_i_attrs =
259       function.getResultAttrs(index);
260   absl::flat_hash_set<absl::string_view> attrs_to_ignore = {kDeviceAttr};
261   TF_RETURN_IF_ERROR(ConvertAttributes(func_res_i_attrs, attrs_to_ignore,
262                                        /*remove_ref_type=*/false,
263                                        node_def->mutable_attr()));
264 
265   return node_def;
266 }
267 
AddEdgeBetweenNodes(Value src,Node * dst_node,unsigned dst_index)268 Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node,
269                                      unsigned dst_index) {
270   if (auto input_result = src.dyn_cast<mlir::OpResult>()) {
271     auto* input_inst = GetIslandInnerOpOrSelf(input_result.getOwner());
272     // Replaces the input node with NextIteration sink if it is a NextIteration
273     // source.
274     if (auto next_iter_source =
275             llvm::dyn_cast<mlir::tf_executor::NextIterationSourceOp>(
276                 input_inst))
277       input_inst = next_iter_source.GetSink();
278 
279     auto node_it = nodes_.find(input_inst);
280     TF_RET_CHECK(node_it != nodes_.end())
281         << "Use of OpResult encountered before def!";
282     if (input_result.getType().isa<mlir::tf_executor::ControlType>()) {
283       graph_->AddControlEdge(node_it->second, dst_node);
284     } else {
285       graph_->AddEdge(node_it->second, input_result.getResultNumber(), dst_node,
286                       dst_index);
287     }
288     return Status::OK();
289   }
290 
291   auto input_arg = src.cast<BlockArgument>();
292   auto input_node_it = args_.find(input_arg);
293   TF_RET_CHECK(input_node_it != args_.end())
294       << "Use of BlockArgument encounted before def!";
295   // For argument, there is only one result output, so the index is always 0.
296   graph_->AddEdge(input_node_it->second, 0, dst_node, dst_index);
297   return Status::OK();
298 }
299 
AddEdge(Operation * inst)300 Status Exporter::AddEdge(Operation* inst) {
301   // For tf_executor.fetch, add only its data edges. Control edges are captured
302   // later.
303   if (auto fetch = llvm::dyn_cast<mlir::tf_executor::FetchOp>(inst)) {
304     for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) {
305       Value operand = operand_and_idx.value();
306       if (operand.getType().isa<mlir::tf_executor::ControlType>()) break;
307 
308       auto* dst_node = returns_[fetch][operand_and_idx.index()];
309       TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(operand, dst_node, 0));
310     }
311 
312     return Status::OK();
313   }
314 
315   // For tf_executor.NextIteration.Sink, skip its token operand and add data and
316   // control edges with their index offset by 1.
317   if (auto next_iter_sink =
318           llvm::dyn_cast<mlir::tf_executor::NextIterationSinkOp>(inst)) {
319     auto* dst_node = nodes_[inst];
320     TF_RETURN_IF_ERROR(
321         AddEdgeBetweenNodes(next_iter_sink.input(), dst_node, 0));
322     for (auto control_and_idx : llvm::enumerate(next_iter_sink.controlInputs()))
323       TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(control_and_idx.value(), dst_node,
324                                              control_and_idx.index() + 1));
325 
326     return Status::OK();
327   }
328 
329   // For tf_executor.NextIteration.Source, op can be skipped as it is assumed
330   // there are no operands.
331   if (llvm::isa<mlir::tf_executor::NextIterationSourceOp>(inst)) {
332     assert(inst->getNumOperands() == 0);
333     return Status::OK();
334   }
335 
336   Operation* op = GetIslandInnerOpOrSelf(inst);
337   auto* dst_node = nodes_[op];
338   int operand_offset = 0;
339   // For tf_executor.island, add data edges from its wrapped op before control
340   // edges.
341   if (auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(inst)) {
342     for (auto operand_and_idx : llvm::enumerate(op->getOperands()))
343       TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(operand_and_idx.value(), dst_node,
344                                              operand_and_idx.index()));
345 
346     operand_offset = op->getNumOperands();
347   }
348 
349   // For all other ops (including tf_executor.island), add remaining edges.
350   for (auto operand_and_idx : llvm::enumerate(inst->getOperands()))
351     TF_RETURN_IF_ERROR(
352         AddEdgeBetweenNodes(operand_and_idx.value(), dst_node,
353                             operand_and_idx.index() + operand_offset));
354 
355   return Status::OK();
356 }
357 
AddInstructionNode(Operation * inst)358 Status Exporter::AddInstructionNode(Operation* inst) {
359   std::unique_ptr<NodeDef> node_def;
360   auto name = op_to_name_.GetUniqueName(inst);
361   // Convert registered TF ops to NodeDef. Only registered ops are handled to
362   // ensure that PopulateDerivedAttrs adds the correct attributes.
363   TF_ASSIGN_OR_RETURN(node_def,
364                       ConvertTFDialectOpToNodeDef(
365                           inst, name, /*ignore_unregistered_attrs=*/false));
366 
367   Status status;
368   Node* node = graph_->AddNode(*node_def, &status);
369   TF_RETURN_IF_ERROR(status);
370   DCHECK(node != nullptr);
371   nodes_[inst] = node;
372   return Status::OK();
373 }
374 
IsEntryFunctionArg(BlockArgument arg)375 bool IsEntryFunctionArg(BlockArgument arg) {
376   return arg.getParentRegion()->getParentOfType<FuncOp>().getName() == "main";
377 }
378 
379 // Creates argument nodes from Block argument. If a name is supplied, that
380 // name will be used instead of generating a unique name.
AddArgumentNode(BlockArgument arg,unsigned index,llvm::StringRef name)381 Status Exporter::AddArgumentNode(BlockArgument arg, unsigned index,
382                                  llvm::StringRef name) {
383   TF_ASSIGN_OR_RETURN(auto node_def, GetArgumentNode(arg, index, name));
384   Status status;
385   Node* node = graph_->AddNode(*node_def, &status);
386   TF_RETURN_IF_ERROR(status);
387   args_[arg] = node;
388   return Status::OK();
389 }
390 
391 // Creates return nodes per operand of a FetchOp. If names is supplied, those
392 // names will be used per node in order instead of generating a unique name.
AddFetchNode(FuncOp function,mlir::tf_executor::FetchOp fetch,llvm::ArrayRef<llvm::StringRef> names)393 Status Exporter::AddFetchNode(FuncOp function, mlir::tf_executor::FetchOp fetch,
394                               llvm::ArrayRef<llvm::StringRef> names) {
395   Status status;
396   auto& return_nodes = returns_[fetch];
397   for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) {
398     if (operand_and_idx.value().getType().isa<mlir::tf_executor::ControlType>())
399       break;
400 
401     TF_ASSIGN_OR_RETURN(
402         auto node_def,
403         GetReturnNode(function, operand_and_idx.value(),
404                       operand_and_idx.index(),
405                       names.empty() ? "" : names[operand_and_idx.index()]));
406     Node* node = graph_->AddNode(*node_def, &status);
407     TF_RETURN_IF_ERROR(status);
408     return_nodes.push_back(node);
409   }
410   return Status::OK();
411 }
412 
413 // Collects control ret Nodes based on tf_executor.graph's associated
414 // tf_executor.fetch control inputs.
GetControlRetNodes(mlir::tf_executor::FetchOp fetch,absl::flat_hash_set<Node * > * control_ret_nodes)415 Status Exporter::GetControlRetNodes(
416     mlir::tf_executor::FetchOp fetch,
417     absl::flat_hash_set<Node*>* control_ret_nodes) {
418   for (Value fetch_operand : fetch.getOperands()) {
419     if (fetch_operand.getType().isa<mlir::tf_executor::ControlType>()) {
420       Operation* defining_op =
421           GetIslandInnerOpOrSelf(fetch_operand.getDefiningOp());
422       auto node_it = nodes_.find(defining_op);
423       TF_RET_CHECK(node_it != nodes_.end());
424       control_ret_nodes->insert(node_it->second);
425     }
426   }
427   return Status::OK();
428 }
429 
Convert(const GraphExportConfig & configs,const Dialect * tf_dialect,const SymbolTable & symbol_table,FuncOp function,FunctionDefLibrary * flib,llvm::SmallDenseSet<FuncOp> & visited_functions,absl::flat_hash_set<Node * > * control_ret_nodes)430 StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
431     const GraphExportConfig& configs, const Dialect* tf_dialect,
432     const SymbolTable& symbol_table, FuncOp function, FunctionDefLibrary* flib,
433     llvm::SmallDenseSet<FuncOp>& visited_functions,
434     absl::flat_hash_set<Node*>* control_ret_nodes) {
435   mlir::Block& block = function.front();
436 
437   // Extract input & output names if set.
438   llvm::SmallVector<llvm::StringRef, 2> input_names;
439   llvm::SmallVector<llvm::StringRef, 2> output_names;
440   auto dict_attr =
441       function->getAttrOfType<mlir::DictionaryAttr>(kEntryFuncAttr);
442   if (dict_attr) {
443     TF_RET_CHECK(dict_attr.get("inputs").isa<mlir::StringAttr>())
444         << "inputs missing in entry function attribute";
445     TF_RET_CHECK(dict_attr.get("outputs").isa<mlir::StringAttr>())
446         << "outputs missing in entry function attribute";
447     dict_attr.get("inputs").cast<mlir::StringAttr>().getValue().split(
448         input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
449     dict_attr.get("outputs").cast<mlir::StringAttr>().getValue().split(
450         output_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
451   }
452 
453   auto graph = absl::make_unique<Graph>(OpRegistry::Global());
454 
455   // Extract version info.
456   VersionDef versions;
457   auto module = function->getParentOfType<mlir::ModuleOp>();
458   if (mlir::succeeded(ExtractTfVersions(module, &versions))) {
459     graph->set_versions(versions);
460   }
461 
462   Exporter exporter(graph.get(), tf_dialect);
463 
464   auto graph_op = llvm::cast<mlir::tf_executor::GraphOp>(block.front());
465 
466   // Set input and output names and increment the use counter for them to help
467   // generate unique names.
468   if (!output_names.empty()) {
469     const int num_data_results = graph_op.getNumResults();
470     const int64_t output_names_size = output_names.size();
471     TF_RET_CHECK(output_names_size == num_data_results)
472         << "output names (" << output_names.size()
473         << ") != terminator operands (" << num_data_results << ")";
474     llvm::DenseMap<Operation*, llvm::StringRef> output_op_to_name;
475     llvm::StringMap<Operation*> name_to_op;
476     for (const auto& it : llvm::enumerate(graph_op.GetFetch().getOperands())) {
477       // Skip control rets.
478       const int64_t index = it.index();
479       if (index >= num_data_results) break;
480       // TODO(jpienaar): If there is a result index specified, ensure only one
481       // and that it matches the result index of the op.
482       std::string name(output_names[index]);
483       auto tensor_id = ParseTensorName(name);
484       std::string tensor_id_node(tensor_id.node());
485       assert(!tensor_id_node.empty() && "expected non-empty name");
486       mlir::LegalizeNodeName(tensor_id_node);
487 
488       // Ensure name does not get reused.
489       (void)exporter.op_to_name_.GetUniqueName(tensor_id_node);
490     }
491   }
492 
493   if (!input_names.empty()) {
494     TF_RET_CHECK(input_names.size() == block.getNumArguments());
495     for (const auto& it : llvm::enumerate(function.getArguments())) {
496       // TODO(lyandy): Update when changing feed/fetch import.
497       std::string name(input_names[it.index()]);
498       assert(!name.empty() && "expected non-empty name");
499       mlir::LegalizeNodeName(name);
500       auto tensor_id = ParseTensorName(name);
501       TF_RET_CHECK(tensor_id.index() == 0)
502           << "input port designation not supported";
503       // Only assign user of argument the input name if the main graph did not
504       // have its _Arg nodes lifted into the functions arguments.
505       // Ensure name does not get reused.
506       (void)exporter.op_to_name_.GetUniqueName(name);
507     }
508   }
509 
510   // Adds nodes for basic block (function) arguments.
511   for (auto it : llvm::enumerate(block.getArguments())) {
512     int index = it.index();
513     auto arg = it.value();
514     mlir::Type type = arg.getType();
515     if (!type.isa<mlir::TensorType>()) {
516       return errors::InvalidArgument(
517           "FuncOps arguments must have tensor types. Found ",
518           mlir::debugString(type), " in function ", function.getName().str());
519     }
520 
521     TF_RETURN_IF_ERROR(exporter.AddArgumentNode(
522         arg, index, !input_names.empty() ? input_names[index] : ""));
523   }
524 
525   auto convert_called_function = [&](llvm::StringRef name) {
526     auto func = symbol_table.lookup<FuncOp>(name);
527     if (func != nullptr) {
528       TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, symbol_table,
529                                             func, flib, visited_functions));
530       // TODO(prakalps): Optimize to only add the requested function to graph
531       // library rather than the all the functions exported so far.
532       TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
533     }
534     return Status::OK();
535   };
536 
537   // Adds nodes for operations.
538   for (Operation& inst : graph_op.GetBody()) {
539     for (auto type : inst.getResultTypes())
540       if (!type.isa<mlir::TensorType, mlir::tf_executor::ControlType,
541                     mlir::tf_executor::TokenType>())
542         return errors::InvalidArgument(
543             "Values must be of tensor type, TensorFlow control type, or "
544             "TensorFlow token type. Found ",
545             mlir::debugString(type));
546 
547     if (llvm::isa<mlir::tf_executor::NextIterationSourceOp>(inst)) {
548       // Skip tf_executor.NextIteration.Source as associated
549       // tf_executor.NextIteration.Sink will be used instead.
550       continue;
551     } else if (auto fetch = llvm::dyn_cast<mlir::tf_executor::FetchOp>(inst)) {
552       TF_RETURN_IF_ERROR(exporter.AddFetchNode(function, fetch, output_names));
553     } else if (auto island =
554                    llvm::dyn_cast<mlir::tf_executor::IslandOp>(inst)) {
555       Operation& inner_op = island.GetBody().front();
556       auto op_name = GetTensorFlowOpName(inner_op.getName().getStringRef());
557       if (op_name.ok()) {
558         // If it is TF Control dialect specific op, look up custom operation
559         // in the module and first convert that, then add it to function
560         // definition library
561         // TODO(prakalps): If two functions have cyclic dependence, this will
562         // introduce an infinite loop.
563         TF_RETURN_IF_ERROR(convert_called_function(op_name.ValueOrDie().str()));
564       }
565 
566       if (IsLegacyCallInstruction(&inner_op)) {
567         TF_RETURN_IF_ERROR(convert_called_function(
568             inner_op.getAttrOfType<mlir::SymbolRefAttr>("f")
569                 .getLeafReference()));
570       }
571 
572       TF_RETURN_IF_ERROR(exporter.AddInstructionNode(&inner_op));
573     } else {
574       TF_RETURN_IF_ERROR(exporter.AddInstructionNode(&inst));
575     }
576   }
577   // Adds edges between the argument, operation and return nodes.
578   for (Operation& inst : graph_op.GetBody()) {
579     TF_RETURN_IF_ERROR(exporter.AddEdge(&inst));
580   }
581   // Fixes the edges between the inserted nodes and special "_SOURCE" and
582   // "_SINK".
583   FixupSourceAndSinkEdges(graph.get());
584 
585   TF_RETURN_IF_ERROR(
586       exporter.GetControlRetNodes(graph_op.GetFetch(), control_ret_nodes));
587 
588   return graph;
589 }
590 
ConvertLibFunction(const GraphExportConfig & configs,const Dialect * tf_dialect,const SymbolTable & symbol_table,FuncOp function,FunctionDefLibrary * flib,llvm::SmallDenseSet<FuncOp> & visited_functions)591 Status Exporter::ConvertLibFunction(
592     const GraphExportConfig& configs, const Dialect* tf_dialect,
593     const SymbolTable& symbol_table, FuncOp function, FunctionDefLibrary* flib,
594     llvm::SmallDenseSet<FuncOp>& visited_functions) {
595   // Return early if the function has already been exported.
596   bool is_new_function = visited_functions.insert(function).second;
597   if (!is_new_function) return Status::OK();
598 
599   auto function_name = function.getName().str();
600 
601   // TODO(fengliuai): use a small flib_def to reduce overhead
602   absl::flat_hash_set<Node*> control_ret_nodes;
603   TF_ASSIGN_OR_RETURN(
604       auto sub_graph,
605       Exporter::Convert(configs, tf_dialect, symbol_table, function, flib,
606                         visited_functions, &control_ret_nodes));
607   const auto control_ret = [&](const Node* n) -> absl::optional<string> {
608     return control_ret_nodes.contains(n)
609                ? absl::make_optional<string>(n->name())
610                : absl::nullopt;
611   };
612   FunctionDef func_def;
613   TF_RETURN_IF_ERROR(
614       GraphToFunctionDef(*sub_graph, function_name, control_ret, &func_def));
615 
616   // The node defs in FunctionDef might contain debug info which was added
617   // by the GraphToFunctionDef method. We should remove it if we don't want
618   // to export them to avoid failing the roundtrip test.
619   if (!configs.export_debug_info) {
620     for (auto& node_def : *func_def.mutable_node_def()) {
621       node_def.clear_experimental_debug_info();
622     }
623   }
624 
625   // Checks for gradient attribute. If present converts the gradient function
626   // and populates the GradientDef.
627   auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
628   if (auto attr =
629           function->getAttrOfType<mlir::FlatSymbolRefAttr>(grad_string)) {
630     auto grad_func = symbol_table.lookup<FuncOp>(attr.getValue());
631     TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, symbol_table,
632                                           grad_func, flib, visited_functions));
633     GradientDef grad;
634     grad.set_function_name(function_name);
635     grad.set_gradient_func(grad_func.getName().str());
636     *flib->add_gradient() = grad;
637   }
638 
639   auto stateful_string = mlir::TF::TensorFlowDialect::GetStatefulAttrName();
640   if (auto attr = function->getAttrOfType<mlir::UnitAttr>(stateful_string)) {
641     func_def.mutable_signature()->set_is_stateful(true);
642   }
643 
644   // Ignore the gradient and is_stateful attribute on the function as they have
645   // been handled above. Ignore the entry func attribute as it is an MLIR
646   // metadata attribute and is not required in the function definition.
647   absl::flat_hash_set<absl::string_view> attrs_to_ignore = {
648       grad_string.data(), stateful_string.data(), kEntryFuncAttr};
649   llvm::SmallVector<mlir::NamedAttribute, 8> funcAttrs(
650       function->getDialectAttrs());
651   TF_RETURN_IF_ERROR(ConvertAttributes(funcAttrs, attrs_to_ignore,
652                                        /*remove_ref_type=*/false,
653                                        func_def.mutable_attr()));
654 
655   for (int i = 0, e = function.getNumArguments(); i < e; ++i) {
656     if (auto resource_arg_unique_id_attr =
657             function.getArgAttrOfType<mlir::IntegerAttr>(
658                 i, kResourceArgUniqueIdAttr)) {
659       (*func_def.mutable_resource_arg_unique_id())[i] =
660           resource_arg_unique_id_attr.getInt();
661     }
662 
663     llvm::ArrayRef<mlir::NamedAttribute> func_arg_i_attrs =
664         function.getArgAttrs(i);
665     if (func_arg_i_attrs.empty()) continue;
666     absl::flat_hash_set<absl::string_view> attrs_to_ignore = {
667         kDeviceAttr, kResourceArgUniqueIdAttr};
668     FunctionDef::ArgAttrs func_def_arg_i_attrs;
669     TF_RETURN_IF_ERROR(ConvertAttributes(func_arg_i_attrs, attrs_to_ignore,
670                                          /*remove_ref_type=*/false,
671                                          func_def_arg_i_attrs.mutable_attr()));
672     if (func_def_arg_i_attrs.attr().empty()) continue;
673     (*func_def.mutable_arg_attr())[i] = std::move(func_def_arg_i_attrs);
674   }
675 
676   (*flib->add_function()) = std::move(func_def);
677   return Status::OK();
678 }
679 
Convert(mlir::ModuleOp module,const GraphExportConfig & configs,std::unique_ptr<Graph> * graph,FunctionLibraryDefinition * flib_def,absl::flat_hash_set<Node * > * control_ret_nodes)680 Status Exporter::Convert(mlir::ModuleOp module,
681                          const GraphExportConfig& configs,
682                          std::unique_ptr<Graph>* graph,
683                          FunctionLibraryDefinition* flib_def,
684                          absl::flat_hash_set<Node*>* control_ret_nodes) {
685   mlir::Identifier entry_func_id =
686       mlir::Identifier::get("main", module.getContext());
687   absl::optional<FuncOp> entry_func;
688   FunctionDefLibrary flib;
689   llvm::SmallDenseSet<FuncOp> visited_functions;
690   auto tf_dialect = module.getContext()->getLoadedDialect("tf");
691   // Construct SymbolTable to enable cheap function lookups. The cost
692   // of constructing the table is offset by the number of queries.
693   SymbolTable symbol_table(module);
694   for (auto function : module.getOps<FuncOp>()) {
695     if (function.isExternal())
696       return errors::FailedPrecondition("External functions not supported");
697 
698     if (function.getName() == entry_func_id &&
699         !configs.export_entry_func_to_flib) {
700       entry_func.emplace(function);
701     } else {
702       TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, symbol_table,
703                                             function, &flib,
704                                             visited_functions));
705     }
706   }
707 
708   if (!configs.export_entry_func_to_flib) {
709     if (!entry_func.has_value())
710       return errors::FailedPrecondition(
711           "entry function `main` must be present");
712 
713     // Updates the graph and the function library definition.
714     TF_ASSIGN_OR_RETURN(
715         *graph,
716         Exporter::Convert(configs, tf_dialect, symbol_table, entry_func.value(),
717                           &flib, visited_functions, control_ret_nodes));
718     // Add FunctionDefs and GradientDefs of MLIR functions to graph's function
719     // library. If duplicate FunctionDefs already exist (can happen if exporter
720     // had already added some FunctionDefs to the library to support legacy
721     // calls), they are ignored.
722     TF_RETURN_IF_ERROR(graph->get()->AddFunctionLibrary(flib));
723   }
724 
725   for (auto& func_def : flib.function()) {
726     TF_RETURN_IF_ERROR(flib_def->AddFunctionDef(func_def));
727   }
728   for (auto& grad_def : flib.gradient()) {
729     TF_RETURN_IF_ERROR(flib_def->AddGradientDef(grad_def));
730   }
731   return Status::OK();
732 }
733 }  // namespace
734 
ConvertMlirToGraph(mlir::ModuleOp module,const GraphExportConfig & configs,std::unique_ptr<Graph> * graph,FunctionLibraryDefinition * flib_def,absl::flat_hash_set<Node * > * control_ret_nodes)735 Status ConvertMlirToGraph(mlir::ModuleOp module,
736                           const GraphExportConfig& configs,
737                           std::unique_ptr<Graph>* graph,
738                           FunctionLibraryDefinition* flib_def,
739                           absl::flat_hash_set<Node*>* control_ret_nodes) {
740   mlir::StatusScopedDiagnosticHandler sh(module.getContext());
741   if (failed(VerifyExportSuitable(module))) return sh.ConsumeStatus();
742   return sh.Combine(
743       Exporter::Convert(module, configs, graph, flib_def, control_ret_nodes));
744 }
745 
ConvertMlirToGraph(mlir::ModuleOp module,const GraphExportConfig & configs,std::unique_ptr<Graph> * graph,FunctionLibraryDefinition * flib_def)746 Status ConvertMlirToGraph(mlir::ModuleOp module,
747                           const GraphExportConfig& configs,
748                           std::unique_ptr<Graph>* graph,
749                           FunctionLibraryDefinition* flib_def) {
750   absl::flat_hash_set<Node*> control_ret_nodes;
751   return ConvertMlirToGraph(module, configs, graph, flib_def,
752                             &control_ret_nodes);
753 }
754 
ConvertMlirToGraphdef(mlir::ModuleOp module,const GraphExportConfig & configs)755 StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
756     mlir::ModuleOp module, const GraphExportConfig& configs) {
757   FunctionLibraryDefinition flib_def(OpRegistry::Global(),
758                                      FunctionDefLibrary());
759   std::unique_ptr<Graph> graph;
760   TF_RETURN_IF_ERROR(ConvertMlirToGraph(module, configs, &graph, &flib_def));
761 
762   // If the entry function is exported to flib, then no graph is constructed.
763   // Construct one in that case.
764   if (configs.export_entry_func_to_flib) {
765     graph = std::make_unique<Graph>(OpRegistry::Global());
766     // TODO(hinsu): Avoid Proto -> Memory -> Proto conversion here.
767     FunctionDefLibrary flib = flib_def.ToProto();
768     TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(flib));
769   }
770 
771   auto graphdef = absl::make_unique<GraphDef>();
772   graph->ToGraphDef(graphdef.get());
773   if (!configs.export_library) graphdef->clear_library();
774   if (!configs.export_shapes) {
775     for (auto& node_def : *graphdef->mutable_node()) {
776       node_def.mutable_attr()->erase("shape");
777     }
778   }
779   if (!configs.export_debug_info) {
780     for (auto& node_def : *graphdef->mutable_node()) {
781       node_def.clear_experimental_debug_info();
782     }
783   }
784   return graphdef;
785 }
786 
ConvertMlirFunctionToFunctionLibraryDef(FuncOp func,const GraphExportConfig & configs,FunctionDef * function_def)787 stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef(
788     FuncOp func, const GraphExportConfig& configs, FunctionDef* function_def) {
789   Dialect* tf_dialect = func.getContext()->getLoadedDialect("tf");
790   FunctionDefLibrary flib;
791   llvm::SmallDenseSet<FuncOp> visited_functions;
792   // Construct SymbolTable to enable cheap function lookups. The cost
793   // of constructing the table is offset by the number of queries. Even
794   // though this only converts one function in theory, this function
795   // may have gradient associated which would result in a lookup. This
796   // could be made lazy if we find this to be broad.
797   SymbolTable symbol_table(func->getParentOfType<mlir::ModuleOp>());
798   TF_RETURN_IF_ERROR(Exporter::ConvertLibFunction(
799       configs, tf_dialect, symbol_table, func, &flib, visited_functions));
800   for (auto& func_def : flib.function()) {
801     if (func_def.signature().name() == func.getName()) {
802       *function_def = func_def;
803       return Status::OK();
804     }
805   }
806   return errors::InvalidArgument(
807       "Function couldn't be found in the FunctionDefLibrary after converting "
808       "from MLIR");
809 }
810 
811 }  // namespace tensorflow
812