1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
17
18 #include <utility>
19
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/container/inlined_vector.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/string_view.h"
25 #include "absl/types/optional.h"
26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/StringRef.h"
30 #include "llvm/Support/Casting.h"
31 #include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
32 #include "mlir/IR/Attributes.h" // TF:llvm-project
33 #include "mlir/IR/Builders.h" // TF:llvm-project
34 #include "mlir/IR/Function.h" // TF:llvm-project
35 #include "mlir/IR/Identifier.h" // TF:llvm-project
36 #include "mlir/IR/Location.h" // TF:llvm-project
37 #include "mlir/IR/Module.h" // TF:llvm-project
38 #include "mlir/IR/Operation.h" // TF:llvm-project
39 #include "mlir/IR/Types.h" // TF:llvm-project
40 #include "mlir/Pass/Pass.h" // TF:llvm-project
41 #include "mlir/Pass/PassManager.h" // TF:llvm-project
42 #include "mlir/Support/DebugStringHelper.h" // TF:llvm-project
43 #include "mlir/Support/LogicalResult.h" // TF:llvm-project
44 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
45 #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
46 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
48 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
49 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
50 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
51 #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
52 #include "tensorflow/compiler/xla/status_macros.h"
53 #include "tensorflow/core/framework/graph.pb.h"
54 #include "tensorflow/core/framework/graph_to_functiondef.h"
55 #include "tensorflow/core/framework/node_def.pb.h"
56 #include "tensorflow/core/framework/node_def_util.h"
57 #include "tensorflow/core/framework/op.h"
58 #include "tensorflow/core/framework/types.pb.h"
59 #include "tensorflow/core/framework/versions.pb.h"
60 #include "tensorflow/core/graph/algorithm.h"
61 #include "tensorflow/core/graph/graph.h"
62 #include "tensorflow/core/graph/tensor_id.h"
63 #include "tensorflow/core/lib/core/errors.h"
64 #include "tensorflow/core/lib/core/status.h"
65
66 namespace tensorflow {
67 using llvm::dyn_cast;
68 using llvm::isa;
69 using mlir::BlockArgument;
70 using mlir::Dialect;
71 using mlir::Operation;
72 using mlir::OperationState;
73 using mlir::Value;
74 using stream_executor::port::StatusOr;
75
76 namespace {
77
// Prefix for the FailedPrecondition messages emitted when a function does not
// satisfy the "single tf_executor.graph with single-op islands" invariant.
constexpr char kInvalidExecutorGraphMsg[] =
    "Functions must be of a single Graph with single op Islands: ";
80
// Returns true if `c` may appear in a TensorFlow Node name.
//
// A node name's first character may only be a letter, digit, dot or
// underscore; subsequent characters may additionally be '/' or '-'.
bool IsLegalChar(char c, bool first_char) {
  // <cctype> classification functions have undefined behavior for negative
  // arguments other than EOF; cast so bytes >= 0x80 (when char is signed) are
  // handled safely.
  const unsigned char uc = static_cast<unsigned char>(c);
  if (isalpha(uc)) return true;
  if (isdigit(uc)) return true;
  if (c == '.') return true;
  if (c == '_') return true;

  // First character of a node name can only be a letter, digit, dot or
  // underscore.
  if (first_char) return false;

  if (c == '/') return true;
  if (c == '-') return true;

  return false;
}
96
97 // Convert characters in name that are considered illegal in TensorFlow Node
98 // name to '.'.
LegalizeNodeName(llvm::StringRef name)99 std::string LegalizeNodeName(llvm::StringRef name) {
100 assert(!name.empty() && "expected non-empty name");
101
102 std::string legalized_name;
103 bool first = true;
104 for (auto c : name) {
105 if (IsLegalChar(c, first)) {
106 legalized_name += c;
107 } else {
108 legalized_name += '.';
109 }
110 first = false;
111 }
112
113 return legalized_name;
114 }
115
116 // OpOrArgLocNameMapper that legalizes the returned name.
117 class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper {
118 private:
GetName(OpOrVal op_or_val)119 std::string GetName(OpOrVal op_or_val) override {
120 return LegalizeNodeName(OpOrArgLocNameMapper::GetName(op_or_val));
121 }
122 };
123
124 // Checks functions in module are of single tf_executor.graph and each
125 // tf_executor.island in tf_executor.graph only has a single op.
// Checks functions in module are of single tf_executor.graph and each
// tf_executor.island in tf_executor.graph only has a single op.
Status HasSingleGraphSingleOpIslandsFunctions(mlir::ModuleOp module) {
  Status status = Status::OK();
  // Walk all functions in the module; the first violation found interrupts
  // the walk with `status` set to the corresponding error.
  module.walk([&](mlir::FuncOp function) {
    if (function.getBlocks().size() != 1) {
      status = errors::FailedPrecondition(
          kInvalidExecutorGraphMsg,
          "only single block functions are supported.");
      return mlir::WalkResult::interrupt();
    }

    // Ops of the sole block, excluding its terminator.
    auto block = function.front().without_terminator();
    auto graph = llvm::dyn_cast<mlir::tf_executor::GraphOp>(block.begin());
    if (!graph) {
      status = errors::FailedPrecondition(
          kInvalidExecutorGraphMsg,
          "first op in function is not a tf_executor.graph.");
      return mlir::WalkResult::interrupt();
    }

    // The tf_executor.graph must be the only (non-terminator) op in the block.
    if (!has_single_element(block)) {
      status = errors::FailedPrecondition(
          kInvalidExecutorGraphMsg,
          "function does not only contain a single tf_executor.graph.");
      return mlir::WalkResult::interrupt();
    }

    // Every island inside the graph must perfectly wrap exactly one op.
    for (Operation& op : graph.GetBody()) {
      auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(op);
      if (!island) continue;

      if (!island.WrapsSingleOp()) {
        status = errors::FailedPrecondition(
            kInvalidExecutorGraphMsg,
            "tf_executor.island must perfectly wrap a single op.");
        return mlir::WalkResult::interrupt();
      }
    }

    return mlir::WalkResult::advance();
  });

  return status;
}
169
170 // Finds first inner op if `op` is a tf_executor.island. Otherwise `op` is
171 // returned.
GetIslandInnerOpOrSelf(mlir::Operation * op)172 Operation* GetIslandInnerOpOrSelf(mlir::Operation* op) {
173 auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(op);
174 if (island) return &island.GetBody().front();
175 return op;
176 }
177
178 // Stateful helper class to export a function into a Graph.
179 class Exporter {
180 public:
181 // Converts the given Module to a Graph. The given module should only contain
182 // one entry function, which is identified by name "main". This entry function
183 // is converted to the base of the graph graph. The rest of the functions are
184 // converted to the library functions in that graph.
185 static Status Convert(mlir::ModuleOp module, const GraphExportConfig& configs,
186 std::unique_ptr<Graph>* graph,
187 FunctionLibraryDefinition* flib_def,
188 absl::flat_hash_set<Node*>* control_ret_nodes);
189
190 // Converts a given FuncOp to a FunctionDef and adds it to the function
191 // definition library
192 static Status ConvertLibFunction(const GraphExportConfig& configs,
193 const Dialect* tf_dialect,
194 mlir::FuncOp function,
195 FunctionDefLibrary* flib);
196 // Converts the given FuncOp to a Graph. The arguments and returns of
197 // function are added to the graph with special op names kArgOp and kRetOp.
198 // Later on, this graph can be converted a function definition and added to
199 // another graph.
200 static StatusOr<std::unique_ptr<Graph>> Convert(
201 const GraphExportConfig& configs, const Dialect* tf_dialect,
202 mlir::FuncOp function, FunctionDefLibrary* flib,
203 absl::flat_hash_set<Node*>* control_ret_nodes);
204
205 private:
Exporter(Graph * graph,const Dialect * tf_dialect)206 explicit Exporter(Graph* graph, const Dialect* tf_dialect)
207 : graph_(graph), tf_dialect_(tf_dialect) {}
208
209 Status AddArgumentNode(BlockArgument arg, unsigned index,
210 llvm::StringRef name);
211 Status AddFetchNode(mlir::FuncOp function, mlir::tf_executor::FetchOp fetch,
212 llvm::ArrayRef<llvm::StringRef> names);
213 Status AddInstructionNode(Operation* inst);
214 Status AddEdge(Operation* inst);
215
216 StatusOr<std::unique_ptr<NodeDef>> GetArgumentNode(BlockArgument arg,
217 unsigned index,
218 llvm::StringRef name);
219 StatusOr<std::unique_ptr<NodeDef>> GetReturnNode(mlir::FuncOp function,
220 Value operand,
221 unsigned index,
222 llvm::StringRef name);
223 Status GetControlRetNodes(mlir::tf_executor::FetchOp fetch,
224 absl::flat_hash_set<Node*>* control_ret_nodes);
225 // Adds one edge between src_node and dst_node. If it is not a control edge,
226 // an index is used to find out the right operand of the dst_node.
227 Status AddEdgeBetweenNodes(Value src, Node* dst_node, unsigned dst_index);
228
229 Graph* graph_;
230 LegalizedOpOrValLocNameMapper op_to_name_;
231 absl::flat_hash_map<Operation*, Node*> nodes_;
232 llvm::DenseMap<BlockArgument, Node*> args_;
233 // One single return operation can return multiple results, and each of them
234 // will be converted to one node in the graph.
235 typedef absl::InlinedVector<Node*, 4> NodeVector;
236 absl::flat_hash_map<Operation*, NodeVector> returns_;
237 const mlir::Dialect* tf_dialect_;
238 };
239
// Builds the NodeDef for a _Arg node representing block argument `arg` at
// position `index` of its enclosing function.
StatusOr<std::unique_ptr<NodeDef>> Exporter::GetArgumentNode(
    BlockArgument arg, unsigned index, llvm::StringRef name) {
  auto func = arg.getParentRegion()->getParentOfType<mlir::FuncOp>();

  auto node_def = absl::make_unique<NodeDef>();
  // Use the supplied name if present; otherwise derive a unique name from the
  // enclosing function's name.
  if (!name.empty())
    node_def->set_name(name.str());
  else
    node_def->set_name(
        std::string(op_to_name_.GetUniqueName(func.getName().str())));

  node_def->set_op(FunctionLibraryDefinition::kArgOp);

  // "T": element type of the (tensor-typed) argument.
  DataType dtype;
  TF_RETURN_IF_ERROR(ConvertToDataType(
      arg.getType().cast<mlir::TensorType>().getElementType(), &dtype));
  AttrValue type_attr;
  type_attr.set_type(dtype);
  (*node_def->mutable_attr())["T"] = type_attr;

  // "index": position of this argument in the function signature.
  AttrValue index_attr;
  index_attr.set_i(index);
  (*node_def->mutable_attr())["index"] = index_attr;

  // Propagate an explicit device assignment if the argument carries one.
  if (auto device_attr =
          func.getArgAttrOfType<mlir::StringAttr>(index, "tf.device")) {
    *node_def->mutable_device() = device_attr.getValue().str();
  }

  // Forward the "tf.resource_arg_unique_id" arg attribute, if present, as the
  // "_resource_arg_unique_id" node attribute.
  if (auto resource_arg_unique_id_attr =
          func.getArgAttrOfType<mlir::IntegerAttr>(
              index, "tf.resource_arg_unique_id")) {
    AttrValue unique_id_attr;
    unique_id_attr.set_i(resource_arg_unique_id_attr.getInt());
    (*node_def->mutable_attr())["_resource_arg_unique_id"] = unique_id_attr;
  }

  return node_def;
}
279
GetReturnNode(mlir::FuncOp function,Value operand,unsigned index,llvm::StringRef name)280 StatusOr<std::unique_ptr<NodeDef>> Exporter::GetReturnNode(
281 mlir::FuncOp function, Value operand, unsigned index,
282 llvm::StringRef name) {
283 auto node_def = absl::make_unique<NodeDef>();
284 if (!name.empty())
285 node_def->set_name(name.str());
286 else
287 node_def->set_name(
288 std::string(op_to_name_.GetUniqueName(function.getName().str())));
289
290 node_def->set_op(FunctionLibraryDefinition::kRetOp);
291 DataType dtype;
292 TF_RETURN_IF_ERROR(ConvertToDataType(
293 operand.getType().cast<mlir::TensorType>().getElementType(), &dtype));
294 AttrValue type_attr;
295 type_attr.set_type(dtype);
296 (*node_def->mutable_attr())["T"] = type_attr;
297 AttrValue index_attr;
298 index_attr.set_i(index);
299 (*node_def->mutable_attr())["index"] = index_attr;
300 return node_def;
301 }
302
AddEdgeBetweenNodes(Value src,Node * dst_node,unsigned dst_index)303 Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node,
304 unsigned dst_index) {
305 if (auto input_result = src.dyn_cast<mlir::OpResult>()) {
306 auto* input_inst = GetIslandInnerOpOrSelf(input_result.getOwner());
307 // Replaces the input node with NextIteration sink if it is a NextIteration
308 // source.
309 if (auto next_iter_source =
310 llvm::dyn_cast<mlir::tf_executor::NextIterationSourceOp>(
311 input_inst))
312 input_inst = next_iter_source.GetSink();
313
314 auto node_it = nodes_.find(input_inst);
315 TF_RET_CHECK(node_it != nodes_.end())
316 << "Use of OpResult encountered before def!";
317 if (input_result.getType().isa<mlir::tf_executor::ControlType>()) {
318 graph_->AddControlEdge(node_it->second, dst_node);
319 } else {
320 graph_->AddEdge(node_it->second, input_result.getResultNumber(), dst_node,
321 dst_index);
322 }
323 return Status::OK();
324 }
325
326 auto input_arg = src.cast<BlockArgument>();
327 auto input_node_it = args_.find(input_arg);
328 TF_RET_CHECK(input_node_it != args_.end())
329 << "Use of BlockArgument encounted before def!";
330 // For argument, there is only one result output, so the index is always 0.
331 graph_->AddEdge(input_node_it->second, 0, dst_node, dst_index);
332 return Status::OK();
333 }
334
// Adds all incoming edges for the node(s) created from `inst`, dispatching on
// the executor op kind.
Status Exporter::AddEdge(Operation* inst) {
  // For tf_executor.fetch, add only its data edges. Control edges are captured
  // later.
  if (auto fetch = llvm::dyn_cast<mlir::tf_executor::FetchOp>(inst)) {
    for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) {
      Value operand = operand_and_idx.value();
      // NOTE(review): the break assumes control operands follow all data
      // operands in a fetch's operand list — confirm against the
      // tf_executor.fetch op definition.
      if (operand.getType().isa<mlir::tf_executor::ControlType>()) break;

      auto* dst_node = returns_[fetch][operand_and_idx.index()];
      TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(operand, dst_node, 0));
    }

    return Status::OK();
  }

  // For tf_executor.NextIteration.Sink, skip its token operand and add data and
  // control edges with their index offset by 1.
  if (auto next_iter_sink =
          llvm::dyn_cast<mlir::tf_executor::NextIterationSinkOp>(inst)) {
    auto* dst_node = nodes_[inst];
    TF_RETURN_IF_ERROR(
        AddEdgeBetweenNodes(next_iter_sink.input(), dst_node, 0));
    for (auto control_and_idx : llvm::enumerate(next_iter_sink.controlInputs()))
      TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(control_and_idx.value(), dst_node,
                                             control_and_idx.index() + 1));

    return Status::OK();
  }

  // For tf_executor.NextIteration.Source, op can be skipped as it is assumed
  // there are no operands.
  if (llvm::isa<mlir::tf_executor::NextIterationSourceOp>(inst)) {
    assert(inst->getNumOperands() == 0);
    return Status::OK();
  }

  // The graph node for an island is keyed by its wrapped inner op.
  Operation* op = GetIslandInnerOpOrSelf(inst);
  auto* dst_node = nodes_[op];
  int operand_offset = 0;
  // For tf_executor.island, add data edges from its wrapped op before control
  // edges.
  if (auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(inst)) {
    for (auto operand_and_idx : llvm::enumerate(op->getOperands()))
      TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(operand_and_idx.value(), dst_node,
                                             operand_and_idx.index()));

    // Control operands of the island itself start after the inner op's
    // data operands.
    operand_offset = op->getNumOperands();
  }

  // For all other ops (including tf_executor.island), add remaining edges.
  for (auto operand_and_idx : llvm::enumerate(inst->getOperands()))
    TF_RETURN_IF_ERROR(
        AddEdgeBetweenNodes(operand_and_idx.value(), dst_node,
                            operand_and_idx.index() + operand_offset));

  return Status::OK();
}
392
AddInstructionNode(Operation * inst)393 Status Exporter::AddInstructionNode(Operation* inst) {
394 std::unique_ptr<NodeDef> node_def;
395 auto name = op_to_name_.GetUniqueName(inst);
396 // Convert registered TF ops to NodeDef. Only registered ops are handled to
397 // ensure that PopulateDerivedAttrs adds the correct attributes.
398 TF_ASSIGN_OR_RETURN(node_def,
399 ConvertTFDialectOpToNodeDef(
400 inst, name, /*ignore_unregistered_attrs=*/false));
401
402 Status status;
403 Node* node = graph_->AddNode(*node_def, &status);
404 TF_RETURN_IF_ERROR(status);
405 DCHECK(node != nullptr);
406 nodes_[inst] = node;
407 return Status::OK();
408 }
409
IsEntryFunctionArg(BlockArgument arg)410 bool IsEntryFunctionArg(BlockArgument arg) {
411 return arg.getParentRegion()->getParentOfType<mlir::FuncOp>().getName() ==
412 "main";
413 }
414
415 // Creates argument nodes from Block argument. If a name is supplied, that
416 // name will be used instead of generating a unique name.
// Creates argument nodes from Block argument. If a name is supplied, that
// name will be used instead of generating a unique name.
Status Exporter::AddArgumentNode(BlockArgument arg, unsigned index,
                                 llvm::StringRef name) {
  // Non-entry-function arguments (or explicitly named ones) become plain
  // _Arg nodes.
  if (!IsEntryFunctionArg(arg) || !name.empty()) {
    TF_ASSIGN_OR_RETURN(auto node_def, GetArgumentNode(arg, index, name));
    Status status;
    Node* node = graph_->AddNode(*node_def, &status);
    TF_RETURN_IF_ERROR(status);
    args_[arg] = node;
    return status;
  }

  // If it is an argument from the "main" function, it has only one user, which
  // is an input node. We recover the original input node and skip adding the
  // argument node. The new input node will be handled as normal in the
  // following steps.
  if (!arg.hasOneUse()) {
    return errors::FailedPrecondition(
        "Arg in 'main' should only have one user.");
  }
  auto* input = *arg.user_begin();
  auto* parent = input->getParentOp();
  auto island = llvm::dyn_cast_or_null<mlir::tf_executor::IslandOp>(parent);
  if (!island)
    return errors::FailedPrecondition(
        "User of arg in 'main' must be in an inner op of a "
        "tf_executor.island.");

  if (!island.control().use_empty())
    return errors::FailedPrecondition(
        "tf_executor.island of user of arg in 'main' must have no control "
        "output users.");

  // Drop a trailing ".input" suffix (if present) from the user op's name to
  // recover the original op name.
  auto input_name = input->getName().getStringRef();
  input_name.consume_back(".input");

  // Rebuild the user op in place, reusing its attributes and result types but
  // dropping the block-argument operand.
  mlir::OpBuilder builder(island.getContext());
  builder.setInsertionPointToStart(&island.GetBody());
  auto loc = mlir::NameLoc::get(
      builder.getIdentifier(op_to_name_.GetUniqueName(input)),
      builder.getContext());
  OperationState state(loc, input_name.str());
  state.attributes.append(input->getAttrs().begin(), input->getAttrs().end());
  for (auto op : input->getOperands()) {
    // Skip the argument in the new operation.
    if (op.isa<BlockArgument>()) continue;
    state.operands.push_back(op);
  }
  state.types.append(input->getResultTypes().begin(),
                     input->getResultTypes().end());
  auto* inst = builder.createOperation(state);
  // If it is one of the specified input names, then the new instruction should
  // have the same name.
  op_to_name_.InitOpName(inst, op_to_name_.GetUniqueName(input));
  // Redirect all uses of the old op's results to the replacement before
  // erasing it.
  for (int index : llvm::seq<int>(0, input->getNumResults())) {
    input->getResult(index).replaceAllUsesWith(inst->getResult(index));
  }
  input->dropAllReferences();
  input->erase();
  return Status::OK();
}
477
478 // Creates return nodes per operand of a FetchOp. If names is supplied, those
479 // names will be used per node in order instead of generating a unique name.
// Creates return nodes per operand of a FetchOp. If names is supplied, those
// names will be used per node in order instead of generating a unique name.
Status Exporter::AddFetchNode(mlir::FuncOp function,
                              mlir::tf_executor::FetchOp fetch,
                              llvm::ArrayRef<llvm::StringRef> names) {
  Status status;
  auto& return_nodes = returns_[fetch];
  for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) {
    // Stop at the first control operand; control fetches are handled
    // separately by GetControlRetNodes. NOTE(review): assumes control operands
    // follow all data operands — confirm against the tf_executor.fetch op
    // definition.
    if (operand_and_idx.value().getType().isa<mlir::tf_executor::ControlType>())
      break;

    TF_ASSIGN_OR_RETURN(
        auto node_def,
        GetReturnNode(function, operand_and_idx.value(),
                      operand_and_idx.index(),
                      names.empty() ? "" : names[operand_and_idx.index()]));
    Node* node = graph_->AddNode(*node_def, &status);
    TF_RETURN_IF_ERROR(status);
    return_nodes.push_back(node);
  }
  return Status::OK();
}
500
501 // Collects control ret Nodes based on tf_executor.graph's associated
502 // tf_executor.fetch control inputs.
GetControlRetNodes(mlir::tf_executor::FetchOp fetch,absl::flat_hash_set<Node * > * control_ret_nodes)503 Status Exporter::GetControlRetNodes(
504 mlir::tf_executor::FetchOp fetch,
505 absl::flat_hash_set<Node*>* control_ret_nodes) {
506 for (Value fetch_operand : fetch.getOperands()) {
507 if (fetch_operand.getType().isa<mlir::tf_executor::ControlType>()) {
508 Operation* defining_op =
509 GetIslandInnerOpOrSelf(fetch_operand.getDefiningOp());
510 auto node_it = nodes_.find(defining_op);
511 TF_RET_CHECK(node_it != nodes_.end());
512 control_ret_nodes->insert(node_it->second);
513 }
514 }
515 return Status::OK();
516 }
517
// Converts `function` into a new Graph. Library functions it references are
// converted via ConvertLibFunction and accumulated into `flib`; nodes fetched
// through control operands are collected into `control_ret_nodes`.
StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
    const GraphExportConfig& configs, const Dialect* tf_dialect,
    mlir::FuncOp function, FunctionDefLibrary* flib,
    absl::flat_hash_set<Node*>* control_ret_nodes) {
  mlir::Block& block = function.front();

  // Determine if _Arg and _Retval nodes should use input and output names.
  bool graph_as_function = false;

  // Extract input & output names if set.
  llvm::SmallVector<llvm::StringRef, 2> input_names;
  llvm::SmallVector<llvm::StringRef, 2> output_names;
  auto dict_attr =
      function.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
  if (dict_attr) {
    TF_RET_CHECK(dict_attr.get("inputs").isa<mlir::StringAttr>())
        << "inputs missing in entry function attribute";
    TF_RET_CHECK(dict_attr.get("outputs").isa<mlir::StringAttr>())
        << "outputs missing in entry function attribute";
    // The attribute values are comma-separated name lists; empty entries are
    // dropped.
    dict_attr.get("inputs").cast<mlir::StringAttr>().getValue().split(
        input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
    dict_attr.get("outputs").cast<mlir::StringAttr>().getValue().split(
        output_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
    graph_as_function = configs.graph_as_function;
  }

  auto graph = absl::make_unique<Graph>(OpRegistry::Global());

  // Extract version info.
  auto version_attr = function.getParentOfType<mlir::ModuleOp>()
                          .getAttrOfType<mlir::DictionaryAttr>("tf.versions");
  if (version_attr) {
    VersionDef versions;
    versions.set_producer(
        version_attr.get("producer").cast<mlir::IntegerAttr>().getInt());
    versions.set_min_consumer(
        version_attr.get("min_consumer").cast<mlir::IntegerAttr>().getInt());
    for (auto bad_consumer :
         version_attr.get("bad_consumers").cast<mlir::ArrayAttr>()) {
      versions.mutable_bad_consumers()->Add(
          bad_consumer.cast<mlir::IntegerAttr>().getInt());
    }
    graph->set_versions(versions);
  }

  // We have to add the function library here, so a custom operation, which is
  // defined in the function library can be added to the graph.
  TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
  Exporter exporter(graph.get(), tf_dialect);

  auto graph_op = llvm::cast<mlir::tf_executor::GraphOp>(block.front());

  // Set input and output names and increment the use counter for them to help
  // generate unique names.
  if (!output_names.empty()) {
    const int num_data_results = graph_op.getNumResults();
    TF_RET_CHECK(output_names.size() == num_data_results)
        << "output names (" << output_names.size()
        << ") != terminator operands (" << num_data_results << ")";
    llvm::DenseMap<Operation*, llvm::StringRef> output_op_to_name;
    llvm::StringMap<Operation*> name_to_op;
    for (auto it : llvm::enumerate(graph_op.GetFetch().getOperands())) {
      // Skip control rets.
      if (it.index() >= num_data_results) break;
      // If there is a result index specified, ensure only one and that it
      // matches the result index of the op.
      auto result = it.value().cast<mlir::OpResult>();
      std::string orig_name(output_names[it.index()]);
      auto tensor_id = ParseTensorName(orig_name);
      auto name = LegalizeNodeName(
          llvm::StringRef(tensor_id.node().data(), tensor_id.node().size()));

      if (graph_as_function) {
        // Ensure name does not get reused.
        (void)exporter.op_to_name_.GetUniqueName(name);
        continue;
      }

      TF_RET_CHECK(result.getResultNumber() == tensor_id.index());
      Operation* defining_op = GetIslandInnerOpOrSelf(result.getDefiningOp());
      if (output_op_to_name.insert({defining_op, name}).second) {
        TF_RET_CHECK(name_to_op.insert({name, defining_op}).second)
            << "multiple operations associated with the same name";
        exporter.op_to_name_.InitOpName(defining_op, name);
      } else {
        TF_RET_CHECK(output_op_to_name[defining_op] == name)
            << "associating multiple names with the same op not supported";
      }
    }
  }

  if (!input_names.empty()) {
    TF_RET_CHECK(input_names.size() == block.getNumArguments());
    for (auto it : llvm::enumerate(function.getArguments())) {
      // TODO(lyandy): Update when changing feed/fetch import.
      std::string orig_name(input_names[it.index()]);
      std::string name = LegalizeNodeName(orig_name);
      auto tensor_id = ParseTensorName(name);
      TF_RET_CHECK(tensor_id.index() == 0)
          << "input port designation not supported";
      // Only assign user of argument the input name if the main graph did not
      // have its _Arg nodes lifted into the functions arguments.
      if (graph_as_function) {
        // Ensure name does not get reused.
        (void)exporter.op_to_name_.GetUniqueName(name);
      } else {
        Operation* defining_op =
            GetIslandInnerOpOrSelf(*it.value().user_begin());
        exporter.op_to_name_.InitOpName(defining_op, name);
      }
    }
  }

  // Adds nodes for basic block (function) arguments.
  for (auto it : llvm::enumerate(block.getArguments())) {
    int index = it.index();
    auto arg = it.value();
    mlir::Type type = arg.getType();
    if (!type.isa<mlir::TensorType>()) {
      return errors::InvalidArgument(
          "FuncOps arguments must have tensor types. Found ",
          mlir::debugString(type), " in function ", function.getName().str());
    }

    TF_RETURN_IF_ERROR(exporter.AddArgumentNode(
        arg, index,
        graph_as_function && !input_names.empty() ? input_names[index] : ""));
  }

  // Converts the named function (if found in the module) into the library and
  // merges the accumulated library into the graph.
  auto convert_called_function = [&](llvm::StringRef name) {
    auto func =
        function.getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
            name);
    if (func != nullptr) {
      TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, func, flib));
      TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
    }
    return Status::OK();
  };

  // Adds nodes for operations.
  for (Operation& inst : graph_op.GetBody()) {
    for (auto type : inst.getResultTypes())
      if (!type.isa<mlir::TensorType>() &&
          !type.isa<mlir::tf_executor::ControlType>() &&
          !type.isa<mlir::tf_executor::TokenType>())
        return errors::InvalidArgument(
            "Values must be of tensor type, TensorFlow control type, or "
            "TensorFlow token type. Found ",
            mlir::debugString(type));

    if (llvm::isa<mlir::tf_executor::NextIterationSourceOp>(inst)) {
      // Skip tf_executor.NextIteration.Source as associated
      // tf_executor.NextIteration.Sink will be used instead.
      continue;
    } else if (auto fetch = llvm::dyn_cast<mlir::tf_executor::FetchOp>(inst)) {
      TF_RETURN_IF_ERROR(exporter.AddFetchNode(
          function, fetch,
          graph_as_function ? output_names
                            : llvm::ArrayRef<llvm::StringRef>()));
    } else if (auto island =
                   llvm::dyn_cast<mlir::tf_executor::IslandOp>(inst)) {
      Operation& inner_op = island.GetBody().front();
      auto op_name = GetTensorFlowOpName(inner_op.getName().getStringRef());
      if (op_name.ok()) {
        // If it is TF Control dialect specific op, look up custom operation
        // in the module and first convert that, then add it to function
        // definition library
        // TODO(prakalps): If two functions have cyclic dependence, this will
        // introduce an infinite loop.
        TF_RETURN_IF_ERROR(convert_called_function(op_name.ValueOrDie().str()));
      }

      if (IsLegacyCallInstruction(&inner_op)) {
        TF_RETURN_IF_ERROR(convert_called_function(
            inner_op.getAttrOfType<mlir::SymbolRefAttr>("f")
                .getLeafReference()));
      }

      TF_RETURN_IF_ERROR(exporter.AddInstructionNode(&inner_op));
    } else {
      TF_RETURN_IF_ERROR(exporter.AddInstructionNode(&inst));
    }
  }
  // Adds edges between the argument, operation and return nodes.
  for (Operation& inst : graph_op.GetBody()) {
    TF_RETURN_IF_ERROR(exporter.AddEdge(&inst));
  }
  // Fixes the edges between the inserted nodes and special "_SOURCE" and
  // "_SINK".
  FixupSourceAndSinkEdges(graph.get());

  TF_RETURN_IF_ERROR(
      exporter.GetControlRetNodes(graph_op.GetFetch(), control_ret_nodes));

  return graph;
}
715
// Converts `function` to a FunctionDef (plus any gradient it references) and
// appends the result to `flib`. No-op if the function is already present.
Status Exporter::ConvertLibFunction(const GraphExportConfig& configs,
                                    const Dialect* tf_dialect,
                                    mlir::FuncOp function,
                                    FunctionDefLibrary* flib) {
  // First look for the function in the current function library. If found,
  // nothing needs to be done.
  OpRegistry empty_registry;
  FunctionLibraryDefinition flib_def(&empty_registry, *flib);
  auto function_name = function.getName().str();
  if (flib_def.Find(function_name)) return Status::OK();

  // TODO(fengliuai): use a small flib_def to reduce overhead
  absl::flat_hash_set<Node*> control_ret_nodes;
  TF_ASSIGN_OR_RETURN(auto sub_graph,
                      Exporter::Convert(configs, tf_dialect, function, flib,
                                        &control_ret_nodes));
  // Maps a node to its name iff it was collected as a control ret above.
  const auto control_ret = [&](const Node* n) -> absl::optional<string> {
    return control_ret_nodes.contains(n)
               ? absl::make_optional<string>(n->name())
               : absl::nullopt;
  };
  FunctionDef func_def;
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*sub_graph, function_name, control_ret, &func_def));

  // The node defs in FunctionDef might contain debug info which was added
  // by the GraphToFunctionDef method. We should remove it if we don't want
  // to export them to avoid failing the roundtrip test.
  if (!configs.export_debug_info) {
    for (auto& node_def : *func_def.mutable_node_def()) {
      node_def.clear_experimental_debug_info();
    }
  }

  // Checks for gradient attribute. If present converts the gradient function
  // and populates the GradientDef.
  auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
  if (auto attr =
          function.getAttrOfType<mlir::FlatSymbolRefAttr>(grad_string)) {
    auto grad_func =
        function.getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
            attr.getValue());
    TF_RETURN_IF_ERROR(
        ConvertLibFunction(configs, tf_dialect, grad_func, flib));
    GradientDef grad;
    grad.set_function_name(function_name);
    grad.set_gradient_func(grad_func.getName().str());
    *flib->add_gradient() = grad;
  }

  // A unit "stateful" attribute marks the function signature stateful.
  auto stateful_string = mlir::TF::TensorFlowDialect::GetStatefulAttrName();
  if (auto attr = function.getAttrOfType<mlir::UnitAttr>(stateful_string)) {
    func_def.mutable_signature()->set_is_stateful(true);
  }
  // Forward per-argument resource unique IDs into the FunctionDef.
  for (int64 i = 0; i < function.getNumArguments(); ++i) {
    if (auto resource_arg_unique_id_attr =
            function.getArgAttrOfType<mlir::IntegerAttr>(
                i, "tf.resource_arg_unique_id")) {
      (*func_def.mutable_resource_arg_unique_id())[i] =
          resource_arg_unique_id_attr.getInt();
    }
  }

  // Ignore the gradient and is_stateful attribute on the function as they have
  // been handled above.
  absl::flat_hash_set<absl::string_view> attrs_to_ignore = {
      grad_string.data(), stateful_string.data()};
  llvm::SmallVector<mlir::NamedAttribute, 8> funcAttrs(
      function.getDialectAttrs());
  TF_RETURN_IF_ERROR(
      ConvertAttributes(funcAttrs, attrs_to_ignore, func_def.mutable_attr()));
  (*flib->add_function()) = func_def;
  return Status::OK();
}
790
// Converts the whole module: "main" becomes the top-level graph and every
// other function is converted into the function library definition.
Status Exporter::Convert(mlir::ModuleOp module,
                         const GraphExportConfig& configs,
                         std::unique_ptr<Graph>* graph,
                         FunctionLibraryDefinition* flib_def,
                         absl::flat_hash_set<Node*>* control_ret_nodes) {
  mlir::Identifier entry_func_id =
      mlir::Identifier::get("main", module.getContext());
  absl::optional<mlir::FuncOp> entry_func;
  FunctionDefLibrary flib;
  auto tf_dialect = module.getContext()->getRegisteredDialect("tf");
  for (auto function : module.getOps<mlir::FuncOp>()) {
    if (function.isExternal())
      return errors::FailedPrecondition("External functions not supported");

    if (function.getName() == entry_func_id) {
      entry_func.emplace(function);
    } else {
      TF_RETURN_IF_ERROR(
          ConvertLibFunction(configs, tf_dialect, function, &flib));
    }
  }

  if (!entry_func.has_value())
    return errors::FailedPrecondition("entry function `main` must be present");

  // Updates the graph and the function library definition.
  TF_ASSIGN_OR_RETURN(
      *graph, Exporter::Convert(configs, tf_dialect, entry_func.value(), &flib,
                                control_ret_nodes));
  // Copy the accumulated FunctionDefs and GradientDefs into the caller's
  // library definition.
  for (auto& func_def : flib.function()) {
    TF_RETURN_IF_ERROR(flib_def->AddFunctionDef(func_def));
  }
  for (auto& grad_def : flib.gradient()) {
    TF_RETURN_IF_ERROR(flib_def->AddGradientDef(grad_def));
  }
  return Status::OK();
}
828 } // namespace
829
// Public entry point: validates the module's executor-dialect structure, then
// converts it to a Graph, populating `flib_def` and `control_ret_nodes`.
Status ConvertMlirToGraph(mlir::ModuleOp module,
                          const GraphExportConfig& configs,
                          std::unique_ptr<Graph>* graph,
                          FunctionLibraryDefinition* flib_def,
                          absl::flat_hash_set<Node*>* control_ret_nodes) {
  TF_RETURN_IF_ERROR(HasSingleGraphSingleOpIslandsFunctions(module));
  return Exporter::Convert(module, configs, graph, flib_def, control_ret_nodes);
}
838
// Convenience overload for callers that do not need the control ret nodes;
// they are collected into a local set and discarded.
Status ConvertMlirToGraph(mlir::ModuleOp module,
                          const GraphExportConfig& configs,
                          std::unique_ptr<Graph>* graph,
                          FunctionLibraryDefinition* flib_def) {
  absl::flat_hash_set<Node*> control_ret_nodes;
  return ConvertMlirToGraph(module, configs, graph, flib_def,
                            &control_ret_nodes);
}
847
ConvertMlirToGraphdef(mlir::ModuleOp module,const GraphExportConfig & configs)848 StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
849 mlir::ModuleOp module, const GraphExportConfig& configs) {
850 FunctionLibraryDefinition flib_def(OpRegistry::Global(),
851 FunctionDefLibrary());
852 auto graph = absl::make_unique<Graph>(flib_def);
853 TF_RETURN_IF_ERROR(ConvertMlirToGraph(module, configs, &graph, &flib_def));
854 auto graphdef = absl::make_unique<GraphDef>();
855 graph->ToGraphDef(graphdef.get());
856 if (!configs.export_library) graphdef->clear_library();
857 if (!configs.export_shapes) {
858 for (auto& node_def : *graphdef->mutable_node()) {
859 node_def.mutable_attr()->erase("shape");
860 }
861 }
862 if (!configs.export_debug_info) {
863 for (auto& node_def : *graphdef->mutable_node()) {
864 node_def.clear_experimental_debug_info();
865 }
866 }
867 return graphdef;
868 }
869
870 } // namespace tensorflow
871