/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"

#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Identifier.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
#include "tensorflow/compiler/mlir/utils/name_utils.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
using llvm::dyn_cast;
using llvm::isa;
using mlir::BlockArgument;
using mlir::Dialect;
using mlir::Operation;
using mlir::Value;
using stream_executor::port::StatusOr;

namespace {

constexpr char kInvalidExecutorGraphMsg[] =
    "Functions must be of a single Graph with single op Islands: ";

constexpr char kDeviceAttr[] = "tf.device";
constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";

// OpOrArgLocNameMapper that legalizes the returned name.
class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper {
 private:
  std::string GetName(OpOrVal op_or_val) override {
    std::string name = OpOrArgLocNameMapper::GetName(op_or_val);
    assert(!name.empty() && "expected non-empty name");
    mlir::LegalizeNodeName(name);
    return name;
  }
};

// Checks that each function in the module contains a single tf_executor.graph
// and that each tf_executor.island in that graph wraps only a single op.
Status HasSingleGraphSingleOpIslandsFunctions(mlir::ModuleOp module) {
  Status status = Status::OK();
  module.walk([&](mlir::FuncOp function) {
    if (!llvm::hasSingleElement(function)) {
      status = errors::FailedPrecondition(
          kInvalidExecutorGraphMsg,
          "only single block functions are supported.");
      return mlir::WalkResult::interrupt();
    }

    auto block = function.front().without_terminator();
    auto graph = llvm::dyn_cast<mlir::tf_executor::GraphOp>(block.begin());
    if (!graph) {
      status = errors::FailedPrecondition(
          kInvalidExecutorGraphMsg,
          "first op in function is not a tf_executor.graph.");
      return mlir::WalkResult::interrupt();
    }

    if (!hasSingleElement(block)) {
      status = errors::FailedPrecondition(
          kInvalidExecutorGraphMsg,
          "function does not only contain a single tf_executor.graph.");
      return mlir::WalkResult::interrupt();
    }

    for (Operation& op : graph.GetBody()) {
      auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(op);
      if (!island) continue;

      if (!island.WrapsSingleOp()) {
        status = errors::FailedPrecondition(
            kInvalidExecutorGraphMsg,
            "tf_executor.island must perfectly wrap a single op.");
        return mlir::WalkResult::interrupt();
      }
    }

    return mlir::WalkResult::advance();
  });

  return status;
}

// Finds first inner op if `op` is a tf_executor.island. Otherwise `op` is
// returned.
Operation* GetIslandInnerOpOrSelf(mlir::Operation* op) {
  auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(op);
  if (island) return &island.GetBody().front();
  return op;
}

// Stateful helper class to export a function into a Graph.
class Exporter {
 public:
  // Converts the given Module to a Graph. The given module should only contain
  // one entry function, which is identified by the name "main". This entry
  // function is converted to the base of the graph. The rest of the functions
  // are converted to library functions in that graph.
  static Status Convert(mlir::ModuleOp module, const GraphExportConfig& configs,
                        std::unique_ptr<Graph>* graph,
                        FunctionLibraryDefinition* flib_def,
                        absl::flat_hash_set<Node*>* control_ret_nodes);

  // Converts a given FuncOp to a FunctionDef and adds it to the function
  // definition library.
  static Status ConvertLibFunction(const GraphExportConfig& configs,
                                   const Dialect* tf_dialect,
                                   mlir::FuncOp function,
                                   FunctionDefLibrary* flib);
  // Converts the given FuncOp to a Graph. The arguments and returns of the
  // function are added to the graph with the special op names kArgOp and
  // kRetOp. Later on, this graph can be converted to a function definition and
  // added to another graph.
  static StatusOr<std::unique_ptr<Graph>> Convert(
      const GraphExportConfig& configs, const Dialect* tf_dialect,
      mlir::FuncOp function, FunctionDefLibrary* flib,
      absl::flat_hash_set<Node*>* control_ret_nodes);

 private:
  explicit Exporter(Graph* graph, const Dialect* tf_dialect)
      : graph_(graph), tf_dialect_(tf_dialect) {}

  Status AddArgumentNode(BlockArgument arg, unsigned index,
                         llvm::StringRef name);
  Status AddFetchNode(mlir::FuncOp function, mlir::tf_executor::FetchOp fetch,
                      llvm::ArrayRef<llvm::StringRef> names);
  Status AddInstructionNode(Operation* inst);
  Status AddEdge(Operation* inst);

  StatusOr<std::unique_ptr<NodeDef>> GetArgumentNode(BlockArgument arg,
                                                     unsigned index,
                                                     llvm::StringRef name);
  StatusOr<std::unique_ptr<NodeDef>> GetReturnNode(mlir::FuncOp function,
                                                   Value operand,
                                                   unsigned index,
                                                   llvm::StringRef name);
  Status GetControlRetNodes(mlir::tf_executor::FetchOp fetch,
                            absl::flat_hash_set<Node*>* control_ret_nodes);
  // Adds one edge between src_node and dst_node. If it is not a control edge,
  // an index is used to find out the right operand of dst_node.
  Status AddEdgeBetweenNodes(Value src, Node* dst_node, unsigned dst_index);

  Graph* graph_;
  LegalizedOpOrValLocNameMapper op_to_name_;
  absl::flat_hash_map<Operation*, Node*> nodes_;
  llvm::DenseMap<BlockArgument, Node*> args_;
  // A single return operation can return multiple results, and each of them
  // will be converted to one node in the graph.
  typedef absl::InlinedVector<Node*, 4> NodeVector;
  absl::flat_hash_map<Operation*, NodeVector> returns_;
  const mlir::Dialect* tf_dialect_;
};

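// Creates a NodeDef for an argument node (kArgOp), with shape, type, index,
// and device attributes derived from the block argument and its function arg
// attributes.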
StatusOr<std::unique_ptr<NodeDef>> Exporter::GetArgumentNode(
    BlockArgument arg, unsigned index, llvm::StringRef name) {
  auto func = arg.getParentRegion()->getParentOfType<mlir::FuncOp>();

  auto node_def = absl::make_unique<NodeDef>();
  if (!name.empty())
    node_def->set_name(name.str());
  else
    node_def->set_name(
        std::string(op_to_name_.GetUniqueName(func.getName().str())));

  node_def->set_op(FunctionLibraryDefinition::kArgOp);

  TF_RETURN_IF_ERROR(SetShapeAttribute("_output_shapes",
                                       arg.getType().cast<mlir::ShapedType>(),
                                       node_def->mutable_attr()));

  DataType dtype;
  TF_RETURN_IF_ERROR(ConvertToDataType(
      arg.getType().cast<mlir::TensorType>().getElementType(), &dtype));
  AttrValue type_attr;
  type_attr.set_type(dtype);
  (*node_def->mutable_attr())["T"] = type_attr;

  AttrValue index_attr;
  index_attr.set_i(index);
  (*node_def->mutable_attr())["index"] = index_attr;

  if (auto device_attr =
          func.getArgAttrOfType<mlir::StringAttr>(index, kDeviceAttr))
    *node_def->mutable_device() = device_attr.getValue().str();

  llvm::ArrayRef<mlir::NamedAttribute> func_arg_i_attrs =
      func.getArgAttrs(index);
  absl::flat_hash_set<absl::string_view> attrs_to_ignore = {kDeviceAttr};
  TF_RETURN_IF_ERROR(ConvertAttributes(func_arg_i_attrs, attrs_to_ignore,
                                       /*remove_ref_type=*/false,
                                       node_def->mutable_attr()));

  return node_def;
}

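// Creates a NodeDef for a return node (kRetOp), with type, index, and device
// attributes derived from the given function result operand.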
StatusOr<std::unique_ptr<NodeDef>> Exporter::GetReturnNode(
    mlir::FuncOp function, Value operand, unsigned index,
    llvm::StringRef name) {
  auto node_def = absl::make_unique<NodeDef>();
  if (!name.empty())
    node_def->set_name(name.str());
  else
    node_def->set_name(
        std::string(op_to_name_.GetUniqueName(function.getName().str())));

  node_def->set_op(FunctionLibraryDefinition::kRetOp);
  DataType dtype;
  TF_RETURN_IF_ERROR(ConvertToDataType(
      operand.getType().cast<mlir::TensorType>().getElementType(), &dtype));
  AttrValue type_attr;
  type_attr.set_type(dtype);
  (*node_def->mutable_attr())["T"] = type_attr;
  AttrValue index_attr;
  index_attr.set_i(index);
  (*node_def->mutable_attr())["index"] = index_attr;

  if (auto device_attr =
          function.getResultAttrOfType<mlir::StringAttr>(index, kDeviceAttr))
    *node_def->mutable_device() = device_attr.getValue().str();

  llvm::ArrayRef<mlir::NamedAttribute> func_res_i_attrs =
      function.getResultAttrs(index);
  absl::flat_hash_set<absl::string_view> attrs_to_ignore = {kDeviceAttr};
  TF_RETURN_IF_ERROR(ConvertAttributes(func_res_i_attrs, attrs_to_ignore,
                                       /*remove_ref_type=*/false,
                                       node_def->mutable_attr()));

  return node_def;
}

Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node,
                                     unsigned dst_index) {
  if (auto input_result = src.dyn_cast<mlir::OpResult>()) {
    auto* input_inst = GetIslandInnerOpOrSelf(input_result.getOwner());
    // Replaces the input node with the NextIteration sink if it is a
    // NextIteration source.
    if (auto next_iter_source =
            llvm::dyn_cast<mlir::tf_executor::NextIterationSourceOp>(
                input_inst))
      input_inst = next_iter_source.GetSink();

    auto node_it = nodes_.find(input_inst);
    TF_RET_CHECK(node_it != nodes_.end())
        << "Use of OpResult encountered before def!";
    if (input_result.getType().isa<mlir::tf_executor::ControlType>()) {
      graph_->AddControlEdge(node_it->second, dst_node);
    } else {
      graph_->AddEdge(node_it->second, input_result.getResultNumber(), dst_node,
                      dst_index);
    }
    return Status::OK();
  }

  auto input_arg = src.cast<BlockArgument>();
  auto input_node_it = args_.find(input_arg);
  TF_RET_CHECK(input_node_it != args_.end())
      << "Use of BlockArgument encountered before def!";
  // For arguments, there is only one result output, so the index is always 0.
  graph_->AddEdge(input_node_it->second, 0, dst_node, dst_index);
  return Status::OK();
}

Status Exporter::AddEdge(Operation* inst) {
  // For tf_executor.fetch, add only its data edges. Control edges are captured
  // later.
  if (auto fetch = llvm::dyn_cast<mlir::tf_executor::FetchOp>(inst)) {
    for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) {
      Value operand = operand_and_idx.value();
      if (operand.getType().isa<mlir::tf_executor::ControlType>()) break;

      auto* dst_node = returns_[fetch][operand_and_idx.index()];
      TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(operand, dst_node, 0));
    }

    return Status::OK();
  }

  // For tf_executor.NextIteration.Sink, skip its token operand and add data and
  // control edges with their index offset by 1.
  if (auto next_iter_sink =
          llvm::dyn_cast<mlir::tf_executor::NextIterationSinkOp>(inst)) {
    auto* dst_node = nodes_[inst];
    TF_RETURN_IF_ERROR(
        AddEdgeBetweenNodes(next_iter_sink.input(), dst_node, 0));
    for (auto control_and_idx : llvm::enumerate(next_iter_sink.controlInputs()))
      TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(control_and_idx.value(), dst_node,
                                             control_and_idx.index() + 1));

    return Status::OK();
  }

  // For tf_executor.NextIteration.Source, op can be skipped as it is assumed
  // there are no operands.
  if (llvm::isa<mlir::tf_executor::NextIterationSourceOp>(inst)) {
    assert(inst->getNumOperands() == 0);
    return Status::OK();
  }

  Operation* op = GetIslandInnerOpOrSelf(inst);
  auto* dst_node = nodes_[op];
  int operand_offset = 0;
  // For tf_executor.island, add data edges from its wrapped op before control
  // edges.
  if (auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(inst)) {
    for (auto operand_and_idx : llvm::enumerate(op->getOperands()))
      TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(operand_and_idx.value(), dst_node,
                                             operand_and_idx.index()));

    operand_offset = op->getNumOperands();
  }

  // For all other ops (including tf_executor.island), add remaining edges.
  for (auto operand_and_idx : llvm::enumerate(inst->getOperands()))
    TF_RETURN_IF_ERROR(
        AddEdgeBetweenNodes(operand_and_idx.value(), dst_node,
                            operand_and_idx.index() + operand_offset));

  return Status::OK();
}

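// Converts an op to a NodeDef and adds the resulting node to the graph.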
Status Exporter::AddInstructionNode(Operation* inst) {
  std::unique_ptr<NodeDef> node_def;
  auto name = op_to_name_.GetUniqueName(inst);
  // Convert registered TF ops to NodeDef. Only registered ops are handled to
  // ensure that PopulateDerivedAttrs adds the correct attributes.
  TF_ASSIGN_OR_RETURN(node_def,
                      ConvertTFDialectOpToNodeDef(
                          inst, name, /*ignore_unregistered_attrs=*/false));

  Status status;
  Node* node = graph_->AddNode(*node_def, &status);
  TF_RETURN_IF_ERROR(status);
  DCHECK(node != nullptr);
  nodes_[inst] = node;
  return Status::OK();
}

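// Returns true if the block argument belongs to the entry function, i.e. the
// function named "main".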
bool IsEntryFunctionArg(BlockArgument arg) {
  return arg.getParentRegion()->getParentOfType<mlir::FuncOp>().getName() ==
         "main";
}

// Creates an argument node from a block argument. If a name is supplied, that
// name will be used instead of generating a unique name.
Status Exporter::AddArgumentNode(BlockArgument arg, unsigned index,
                                 llvm::StringRef name) {
  TF_ASSIGN_OR_RETURN(auto node_def, GetArgumentNode(arg, index, name));
  Status status;
  Node* node = graph_->AddNode(*node_def, &status);
  TF_RETURN_IF_ERROR(status);
  args_[arg] = node;
  return Status::OK();
}

// Creates return nodes per operand of a FetchOp. If names is supplied, those
// names will be used per node in order instead of generating a unique name.
Status Exporter::AddFetchNode(mlir::FuncOp function,
                              mlir::tf_executor::FetchOp fetch,
                              llvm::ArrayRef<llvm::StringRef> names) {
  Status status;
  auto& return_nodes = returns_[fetch];
  for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) {
    if (operand_and_idx.value().getType().isa<mlir::tf_executor::ControlType>())
      break;

    TF_ASSIGN_OR_RETURN(
        auto node_def,
        GetReturnNode(function, operand_and_idx.value(),
                      operand_and_idx.index(),
                      names.empty() ? "" : names[operand_and_idx.index()]));
    Node* node = graph_->AddNode(*node_def, &status);
    TF_RETURN_IF_ERROR(status);
    return_nodes.push_back(node);
  }
  return Status::OK();
}

// Collects control ret Nodes based on tf_executor.graph's associated
// tf_executor.fetch control inputs.
Status Exporter::GetControlRetNodes(
    mlir::tf_executor::FetchOp fetch,
    absl::flat_hash_set<Node*>* control_ret_nodes) {
  for (Value fetch_operand : fetch.getOperands()) {
    if (fetch_operand.getType().isa<mlir::tf_executor::ControlType>()) {
      Operation* defining_op =
          GetIslandInnerOpOrSelf(fetch_operand.getDefiningOp());
      auto node_it = nodes_.find(defining_op);
      TF_RET_CHECK(node_it != nodes_.end());
      control_ret_nodes->insert(node_it->second);
    }
  }
  return Status::OK();
}

StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
    const GraphExportConfig& configs, const Dialect* tf_dialect,
    mlir::FuncOp function, FunctionDefLibrary* flib,
    absl::flat_hash_set<Node*>* control_ret_nodes) {
  mlir::Block& block = function.front();

  // Extract input & output names if set.
  llvm::SmallVector<llvm::StringRef, 2> input_names;
  llvm::SmallVector<llvm::StringRef, 2> output_names;
  auto dict_attr =
      function->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
  if (dict_attr) {
    TF_RET_CHECK(dict_attr.get("inputs").isa<mlir::StringAttr>())
        << "inputs missing in entry function attribute";
    TF_RET_CHECK(dict_attr.get("outputs").isa<mlir::StringAttr>())
        << "outputs missing in entry function attribute";
    dict_attr.get("inputs").cast<mlir::StringAttr>().getValue().split(
        input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
    dict_attr.get("outputs").cast<mlir::StringAttr>().getValue().split(
        output_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
  }

  auto graph = absl::make_unique<Graph>(OpRegistry::Global());

  // Extract version info.
  VersionDef versions;
  auto module = function->getParentOfType<mlir::ModuleOp>();
  if (mlir::succeeded(ExtractTfVersions(module, &versions))) {
    graph->set_versions(versions);
  }

  // We have to add the function library here so that a custom operation, which
  // is defined in the function library, can be added to the graph.
  TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
  Exporter exporter(graph.get(), tf_dialect);

  auto graph_op = llvm::cast<mlir::tf_executor::GraphOp>(block.front());

  // Set input and output names and increment the use counter for them to help
  // generate unique names.
  if (!output_names.empty()) {
    const int num_data_results = graph_op.getNumResults();
    const int64 output_names_size = output_names.size();
    TF_RET_CHECK(output_names_size == num_data_results)
        << "output names (" << output_names.size()
        << ") != terminator operands (" << num_data_results << ")";
    llvm::DenseMap<Operation*, llvm::StringRef> output_op_to_name;
    llvm::StringMap<Operation*> name_to_op;
    for (const auto& it : llvm::enumerate(graph_op.GetFetch().getOperands())) {
      // Skip control rets.
      const int64 index = it.index();
      if (index >= num_data_results) break;
      // TODO(jpienaar): If there is a result index specified, ensure only one
      // and that it matches the result index of the op.
      std::string name(output_names[index]);
      auto tensor_id = ParseTensorName(name);
      std::string tensor_id_node(tensor_id.node());
      assert(!tensor_id_node.empty() && "expected non-empty name");
      mlir::LegalizeNodeName(tensor_id_node);

      // Ensure the name does not get reused.
      (void)exporter.op_to_name_.GetUniqueName(tensor_id_node);
    }
  }

  if (!input_names.empty()) {
    TF_RET_CHECK(input_names.size() == block.getNumArguments());
    for (const auto& it : llvm::enumerate(function.getArguments())) {
      // TODO(lyandy): Update when changing feed/fetch import.
      std::string name(input_names[it.index()]);
      assert(!name.empty() && "expected non-empty name");
      mlir::LegalizeNodeName(name);
      auto tensor_id = ParseTensorName(name);
      TF_RET_CHECK(tensor_id.index() == 0)
          << "input port designation not supported";
      // Only assign the argument's user the input name if the main graph did
      // not have its _Arg nodes lifted into the function's arguments.
      // Ensure the name does not get reused.
      (void)exporter.op_to_name_.GetUniqueName(name);
    }
  }

  // Adds nodes for basic block (function) arguments.
  for (auto it : llvm::enumerate(block.getArguments())) {
    int index = it.index();
    auto arg = it.value();
    mlir::Type type = arg.getType();
    if (!type.isa<mlir::TensorType>()) {
      return errors::InvalidArgument(
          "FuncOps arguments must have tensor types. Found ",
          mlir::debugString(type), " in function ", function.getName().str());
    }

    TF_RETURN_IF_ERROR(exporter.AddArgumentNode(
        arg, index, !input_names.empty() ? input_names[index] : ""));
  }

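  // Converts a function referenced by name, if it exists in the module, and
  // merges the updated function library into the graph.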
  auto convert_called_function = [&](llvm::StringRef name) {
    auto func =
        function->getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
            name);
    if (func != nullptr) {
      TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, func, flib));
      TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
    }
    return Status::OK();
  };

  // Adds nodes for operations.
  for (Operation& inst : graph_op.GetBody()) {
    for (auto type : inst.getResultTypes())
      if (!type.isa<mlir::TensorType, mlir::tf_executor::ControlType,
                    mlir::tf_executor::TokenType>())
        return errors::InvalidArgument(
            "Values must be of tensor type, TensorFlow control type, or "
            "TensorFlow token type. Found ",
            mlir::debugString(type));

    if (llvm::isa<mlir::tf_executor::NextIterationSourceOp>(inst)) {
      // Skip tf_executor.NextIteration.Source as the associated
      // tf_executor.NextIteration.Sink will be used instead.
      continue;
    } else if (auto fetch = llvm::dyn_cast<mlir::tf_executor::FetchOp>(inst)) {
      TF_RETURN_IF_ERROR(exporter.AddFetchNode(function, fetch, output_names));
    } else if (auto island =
                   llvm::dyn_cast<mlir::tf_executor::IslandOp>(inst)) {
      Operation& inner_op = island.GetBody().front();
      auto op_name = GetTensorFlowOpName(inner_op.getName().getStringRef());
      if (op_name.ok()) {
        // If it is a TF Control dialect specific op, look up the custom
        // operation in the module, convert that first, and then add it to the
        // function definition library.
        // TODO(prakalps): If two functions have cyclic dependence, this will
        // introduce an infinite loop.
        TF_RETURN_IF_ERROR(convert_called_function(op_name.ValueOrDie().str()));
      }

      if (IsLegacyCallInstruction(&inner_op)) {
        TF_RETURN_IF_ERROR(convert_called_function(
            inner_op.getAttrOfType<mlir::SymbolRefAttr>("f")
                .getLeafReference()));
      }

      TF_RETURN_IF_ERROR(exporter.AddInstructionNode(&inner_op));
    } else {
      TF_RETURN_IF_ERROR(exporter.AddInstructionNode(&inst));
    }
  }
  // Adds edges between the argument, operation and return nodes.
  for (Operation& inst : graph_op.GetBody()) {
    TF_RETURN_IF_ERROR(exporter.AddEdge(&inst));
  }
  // Fixes the edges between the inserted nodes and the special "_SOURCE" and
  // "_SINK" nodes.
  FixupSourceAndSinkEdges(graph.get());

  TF_RETURN_IF_ERROR(
      exporter.GetControlRetNodes(graph_op.GetFetch(), control_ret_nodes));

  return graph;
}

Status Exporter::ConvertLibFunction(const GraphExportConfig& configs,
                                    const Dialect* tf_dialect,
                                    mlir::FuncOp function,
                                    FunctionDefLibrary* flib) {
  // First look for the function in the current function library. If found,
  // nothing needs to be done.
  OpRegistry empty_registry;
  FunctionLibraryDefinition flib_def(&empty_registry, *flib);
  auto function_name = function.getName().str();
  if (flib_def.Find(function_name)) return Status::OK();

  // TODO(fengliuai): use a small flib_def to reduce overhead
  absl::flat_hash_set<Node*> control_ret_nodes;
  TF_ASSIGN_OR_RETURN(auto sub_graph,
                      Exporter::Convert(configs, tf_dialect, function, flib,
                                        &control_ret_nodes));
  const auto control_ret = [&](const Node* n) -> absl::optional<string> {
    return control_ret_nodes.contains(n)
               ? absl::make_optional<string>(n->name())
               : absl::nullopt;
  };
  FunctionDef func_def;
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*sub_graph, function_name, control_ret, &func_def));

  // The node defs in the FunctionDef might contain debug info which was added
  // by the GraphToFunctionDef method. We should remove it if we don't want to
  // export debug info, to avoid failing the roundtrip test.
  if (!configs.export_debug_info) {
    for (auto& node_def : *func_def.mutable_node_def()) {
      node_def.clear_experimental_debug_info();
    }
  }

  // Checks for a gradient attribute. If present, converts the gradient
  // function and populates the GradientDef.
  auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
  if (auto attr =
          function->getAttrOfType<mlir::FlatSymbolRefAttr>(grad_string)) {
    auto grad_func =
        function->getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
            attr.getValue());
    TF_RETURN_IF_ERROR(
        ConvertLibFunction(configs, tf_dialect, grad_func, flib));
    GradientDef grad;
    grad.set_function_name(function_name);
    grad.set_gradient_func(grad_func.getName().str());
    *flib->add_gradient() = grad;
  }

  auto stateful_string = mlir::TF::TensorFlowDialect::GetStatefulAttrName();
  if (auto attr = function->getAttrOfType<mlir::UnitAttr>(stateful_string)) {
    func_def.mutable_signature()->set_is_stateful(true);
  }

  // Ignore the gradient and is_stateful attributes on the function as they
  // have been handled above.
  absl::flat_hash_set<absl::string_view> attrs_to_ignore = {
      grad_string.data(), stateful_string.data()};
  llvm::SmallVector<mlir::NamedAttribute, 8> funcAttrs(
      function->getDialectAttrs());
  TF_RETURN_IF_ERROR(ConvertAttributes(funcAttrs, attrs_to_ignore,
                                       /*remove_ref_type=*/false,
                                       func_def.mutable_attr()));

  for (int i = 0, e = function.getNumArguments(); i < e; ++i) {
    if (auto resource_arg_unique_id_attr =
            function.getArgAttrOfType<mlir::IntegerAttr>(
                i, kResourceArgUniqueIdAttr)) {
      (*func_def.mutable_resource_arg_unique_id())[i] =
          resource_arg_unique_id_attr.getInt();
    }

    llvm::ArrayRef<mlir::NamedAttribute> func_arg_i_attrs =
        function.getArgAttrs(i);
    if (func_arg_i_attrs.empty()) continue;
    absl::flat_hash_set<absl::string_view> attrs_to_ignore = {
        kDeviceAttr, kResourceArgUniqueIdAttr};
    FunctionDef::ArgAttrs func_def_arg_i_attrs;
    TF_RETURN_IF_ERROR(ConvertAttributes(func_arg_i_attrs, attrs_to_ignore,
                                         /*remove_ref_type=*/false,
                                         func_def_arg_i_attrs.mutable_attr()));
    if (func_def_arg_i_attrs.attr().empty()) continue;
    (*func_def.mutable_arg_attr())[i] = std::move(func_def_arg_i_attrs);
  }

  (*flib->add_function()) = std::move(func_def);
  return Status::OK();
}

Status Exporter::Convert(mlir::ModuleOp module,
                         const GraphExportConfig& configs,
                         std::unique_ptr<Graph>* graph,
                         FunctionLibraryDefinition* flib_def,
                         absl::flat_hash_set<Node*>* control_ret_nodes) {
  mlir::Identifier entry_func_id =
      mlir::Identifier::get("main", module.getContext());
  absl::optional<mlir::FuncOp> entry_func;
  FunctionDefLibrary flib;
  auto tf_dialect = module.getContext()->getLoadedDialect("tf");
  for (auto function : module.getOps<mlir::FuncOp>()) {
    if (function.isExternal())
      return errors::FailedPrecondition("External functions not supported");

    if (function.getName() == entry_func_id) {
      entry_func.emplace(function);
    } else {
      TF_RETURN_IF_ERROR(
          ConvertLibFunction(configs, tf_dialect, function, &flib));
    }
  }

  if (!entry_func.has_value())
    return errors::FailedPrecondition("entry function `main` must be present");

  // Updates the graph and the function library definition.
  TF_ASSIGN_OR_RETURN(
      *graph, Exporter::Convert(configs, tf_dialect, entry_func.value(), &flib,
                                control_ret_nodes));
  for (auto& func_def : flib.function()) {
    TF_RETURN_IF_ERROR(flib_def->AddFunctionDef(func_def));
  }
  for (auto& grad_def : flib.gradient()) {
    TF_RETURN_IF_ERROR(flib_def->AddGradientDef(grad_def));
  }
  return Status::OK();
}
}  // namespace

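// Verifies that the module only contains single-graph, single-op-island
// functions and then converts it to a Graph, function library, and set of
// control ret nodes.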
Status ConvertMlirToGraph(mlir::ModuleOp module,
                          const GraphExportConfig& configs,
                          std::unique_ptr<Graph>* graph,
                          FunctionLibraryDefinition* flib_def,
                          absl::flat_hash_set<Node*>* control_ret_nodes) {
  TF_RETURN_IF_ERROR(HasSingleGraphSingleOpIslandsFunctions(module));
  return Exporter::Convert(module, configs, graph, flib_def, control_ret_nodes);
}

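// Overload of ConvertMlirToGraph that discards the collected control ret
// nodes.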
Status ConvertMlirToGraph(mlir::ModuleOp module,
                          const GraphExportConfig& configs,
                          std::unique_ptr<Graph>* graph,
                          FunctionLibraryDefinition* flib_def) {
  absl::flat_hash_set<Node*> control_ret_nodes;
  return ConvertMlirToGraph(module, configs, graph, flib_def,
                            &control_ret_nodes);
}

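// Converts the module to a GraphDef. The function library, shape attributes,
// and debug info are stripped according to `configs`.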
StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
    mlir::ModuleOp module, const GraphExportConfig& configs) {
  FunctionLibraryDefinition flib_def(OpRegistry::Global(),
                                     FunctionDefLibrary());
  auto graph = absl::make_unique<Graph>(flib_def);
  TF_RETURN_IF_ERROR(ConvertMlirToGraph(module, configs, &graph, &flib_def));
  auto graphdef = absl::make_unique<GraphDef>();
  graph->ToGraphDef(graphdef.get());
  if (!configs.export_library) graphdef->clear_library();
  if (!configs.export_shapes) {
    for (auto& node_def : *graphdef->mutable_node()) {
      node_def.mutable_attr()->erase("shape");
    }
  }
  if (!configs.export_debug_info) {
    for (auto& node_def : *graphdef->mutable_node()) {
      node_def.clear_experimental_debug_info();
    }
  }
  return graphdef;
}

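// Converts a single MLIR function into a function library and copies the
// FunctionDef matching the function's name into `function_def`.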
stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef(
    mlir::FuncOp func, const GraphExportConfig& configs,
    FunctionDef* function_def) {
  Dialect* tf_dialect = func.getContext()->getLoadedDialect("tf");
  FunctionDefLibrary flib;
  TF_RETURN_IF_ERROR(
      Exporter::ConvertLibFunction(configs, tf_dialect, func, &flib));
  for (auto& func_def : flib.function()) {
    if (func_def.signature().name() == func.getName()) {
      *function_def = func_def;
      return Status::OK();
    }
  }
  return errors::InvalidArgument(
      "Function couldn't be found in the FunctionDefLibrary after converting "
      "from MLIR");
}

}  // namespace tensorflow