/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"

#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Identifier.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h"
#include "tensorflow/compiler/mlir/utils/name_utils.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
using llvm::dyn_cast;
using llvm::isa;
using mlir::BlockArgument;
using mlir::Dialect;
using mlir::FuncOp;
using mlir::Operation;
using mlir::SymbolTable;
using mlir::Value;
using stream_executor::port::StatusOr;

namespace {

constexpr char kDeviceAttr[] = "tf.device";
constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
constexpr char kEntryFuncAttr[] = "tf.entry_function";
constexpr char kAliasingAttr[] = "tf.aliasing_output";

// OpOrArgLocNameMapper that legalizes the returned name.
class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper {
 private:
  std::string GetName(OpOrVal op_or_val) override {
    std::string name = OpOrArgLocNameMapper::GetName(op_or_val);
    assert(!name.empty() && "expected non-empty name");
    mlir::LegalizeNodeName(name);
    return name;
  }
};

// Finds first inner op if `op` is a tf_executor.island. Otherwise `op` is
// returned.
Operation* GetIslandInnerOpOrSelf(mlir::Operation* op) {
  auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(op);
  if (island) return &island.GetBody().front();
  return op;
}

// Stateful helper class to export a function into a Graph.
class Exporter {
 public:
  // Converts the given Module to a Graph. The given module should only contain
  // one entry function, which is identified by name "main". This entry
  // function is converted to the base of the graph. The rest of the functions
  // are converted to the library functions in that graph.
  static Status Convert(mlir::ModuleOp module, const GraphExportConfig& configs,
                        std::unique_ptr<Graph>* graph,
                        FunctionLibraryDefinition* flib_def,
                        absl::flat_hash_set<Node*>* control_ret_nodes);

  // Converts a given FuncOp to a FunctionDef and adds it to the function
  // definition library.
  static Status ConvertLibFunction(
      const GraphExportConfig& configs, const Dialect* tf_dialect,
      const SymbolTable& symbol_table, FuncOp function,
      FunctionDefLibrary* flib, llvm::SmallDenseSet<FuncOp>& visited_functions);

  // Converts the given FuncOp to a Graph. The arguments and returns of the
  // function are added to the graph with special op names kArgOp and kRetOp.
  // Later on, this graph can be converted to a function definition and added
  // to another graph.
  static StatusOr<std::unique_ptr<Graph>> Convert(
      const GraphExportConfig& configs, const Dialect* tf_dialect,
      const SymbolTable& symbol_table, FuncOp function,
      FunctionDefLibrary* flib, llvm::SmallDenseSet<FuncOp>& visited_functions,
      absl::flat_hash_set<Node*>* control_ret_nodes);

 private:
  explicit Exporter(Graph* graph, const Dialect* tf_dialect)
      : graph_(graph), tf_dialect_(tf_dialect) {}

  Status AddArgumentNode(BlockArgument arg, unsigned index,
                         llvm::StringRef name);
  Status AddFetchNode(FuncOp function, mlir::tf_executor::FetchOp fetch,
                      llvm::ArrayRef<llvm::StringRef> names);
  Status AddInstructionNode(Operation* inst);
  Status AddEdge(Operation* inst);

  StatusOr<std::unique_ptr<NodeDef>> GetArgumentNode(BlockArgument arg,
                                                     unsigned index,
                                                     llvm::StringRef name);
  StatusOr<std::unique_ptr<NodeDef>> GetReturnNode(FuncOp function,
                                                   Value operand,
                                                   unsigned index,
                                                   llvm::StringRef name);
  Status GetControlRetNodes(mlir::tf_executor::FetchOp fetch,
                            absl::flat_hash_set<Node*>* control_ret_nodes);
  // Adds one edge between src_node and dst_node. If it is not a control edge,
  // an index is used to find the right operand of the dst_node.
  Status AddEdgeBetweenNodes(Value src, Node* dst_node, unsigned dst_index);

  Graph* graph_;
  LegalizedOpOrValLocNameMapper op_to_name_;
  absl::flat_hash_map<Operation*, Node*> nodes_;
  llvm::DenseMap<BlockArgument, Node*> args_;
  // A single return operation can return multiple results, and each of them
  // will be converted to one node in the graph.
  typedef absl::InlinedVector<Node*, 4> NodeVector;
  absl::flat_hash_map<Operation*, NodeVector> returns_;
  const mlir::Dialect* tf_dialect_;
};

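// Builds an _Arg NodeDef for block argument `arg` at position `index`,
// propagating the element type, output shapes, resource subtypes, device
// placement, and remaining argument attributes from the function signature.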
StatusOr<std::unique_ptr<NodeDef>> Exporter::GetArgumentNode(
    BlockArgument arg, unsigned index, llvm::StringRef name) {
  auto func = arg.getParentRegion()->getParentOfType<FuncOp>();

  auto node_def = absl::make_unique<NodeDef>();
  if (!name.empty())
    node_def->set_name(name.str());
  else
    node_def->set_name(
        std::string(op_to_name_.GetUniqueName(func.getName().str())));

  node_def->set_op(FunctionLibraryDefinition::kArgOp);

  mlir::TensorType arg_type = arg.getType().cast<mlir::TensorType>();
  if (auto resource_type =
          arg_type.getElementType().dyn_cast<mlir::TF::ResourceType>()) {
    llvm::ArrayRef<mlir::TensorType> subtypes = resource_type.getSubtypes();
    if (!subtypes.empty()) {
      AttrValue handle_dtypes_attr;
      AttrValue handle_shapes_attr;
      for (mlir::TensorType subtype : subtypes) {
        DataType dtype;
        TF_RETURN_IF_ERROR(ConvertToDataType(subtype.getElementType(), &dtype));
        handle_dtypes_attr.mutable_list()->add_type(dtype);

        SetTensorShapeProto(subtype,
                            handle_shapes_attr.mutable_list()->add_shape());
      }

      (*node_def->mutable_attr())["_handle_dtypes"] = handle_dtypes_attr;
      (*node_def->mutable_attr())["_handle_shapes"] = handle_shapes_attr;
    }
  }

  TF_RETURN_IF_ERROR(
      SetShapeAttribute("_output_shapes", arg_type, node_def->mutable_attr()));

  DataType dtype;
  TF_RETURN_IF_ERROR(ConvertToDataType(arg_type.getElementType(), &dtype));
  AttrValue type_attr;
  type_attr.set_type(dtype);
  (*node_def->mutable_attr())["T"] = type_attr;

  AttrValue index_attr;
  index_attr.set_i(index);
  (*node_def->mutable_attr())["index"] = index_attr;

  if (auto device_attr =
          func.getArgAttrOfType<mlir::StringAttr>(index, kDeviceAttr))
    *node_def->mutable_device() = device_attr.getValue().str();

  llvm::ArrayRef<mlir::NamedAttribute> func_arg_i_attrs =
      func.getArgAttrs(index);
  absl::flat_hash_set<absl::string_view> attrs_to_ignore = {kDeviceAttr,
                                                            kAliasingAttr};
  TF_RETURN_IF_ERROR(ConvertAttributes(func_arg_i_attrs, attrs_to_ignore,
                                       /*remove_ref_type=*/false,
                                       node_def->mutable_attr()));

  return node_def;
}

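// Builds a _Retval NodeDef for result `index` of `function`, propagating the
// element type, device placement, and remaining result attributes.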
StatusOr<std::unique_ptr<NodeDef>> Exporter::GetReturnNode(
    FuncOp function, Value operand, unsigned index, llvm::StringRef name) {
  auto node_def = absl::make_unique<NodeDef>();
  if (!name.empty())
    node_def->set_name(name.str());
  else
    node_def->set_name(
        std::string(op_to_name_.GetUniqueName(function.getName().str())));

  node_def->set_op(FunctionLibraryDefinition::kRetOp);
  DataType dtype;
  TF_RETURN_IF_ERROR(ConvertToDataType(
      operand.getType().cast<mlir::TensorType>().getElementType(), &dtype));
  AttrValue type_attr;
  type_attr.set_type(dtype);
  (*node_def->mutable_attr())["T"] = type_attr;
  AttrValue index_attr;
  index_attr.set_i(index);
  (*node_def->mutable_attr())["index"] = index_attr;

  if (auto device_attr =
          function.getResultAttrOfType<mlir::StringAttr>(index, kDeviceAttr))
    *node_def->mutable_device() = device_attr.getValue().str();

  llvm::ArrayRef<mlir::NamedAttribute> func_res_i_attrs =
      function.getResultAttrs(index);
  absl::flat_hash_set<absl::string_view> attrs_to_ignore = {kDeviceAttr};
  TF_RETURN_IF_ERROR(ConvertAttributes(func_res_i_attrs, attrs_to_ignore,
                                       /*remove_ref_type=*/false,
                                       node_def->mutable_attr()));

  return node_def;
}

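// Adds a data or control edge from the node that produces `src` to input
// `dst_index` of `dst_node`. A NextIteration source is rerouted to its paired
// sink, and block arguments always connect from output 0 of their _Arg node.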
Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node,
                                     unsigned dst_index) {
  if (auto input_result = src.dyn_cast<mlir::OpResult>()) {
    auto* input_inst = GetIslandInnerOpOrSelf(input_result.getOwner());
    // Replaces the input node with the NextIteration sink if it is a
    // NextIteration source.
    if (auto next_iter_source =
            llvm::dyn_cast<mlir::tf_executor::NextIterationSourceOp>(
                input_inst))
      input_inst = next_iter_source.GetSink();

    auto node_it = nodes_.find(input_inst);
    TF_RET_CHECK(node_it != nodes_.end())
        << "Use of OpResult encountered before def!";
    if (input_result.getType().isa<mlir::tf_executor::ControlType>()) {
      graph_->AddControlEdge(node_it->second, dst_node);
    } else {
      graph_->AddEdge(node_it->second, input_result.getResultNumber(), dst_node,
                      dst_index);
    }
    return Status::OK();
  }

  auto input_arg = src.cast<BlockArgument>();
  auto input_node_it = args_.find(input_arg);
  TF_RET_CHECK(input_node_it != args_.end())
      << "Use of BlockArgument encountered before def!";
  // For arguments, there is only one result output, so the index is always 0.
  graph_->AddEdge(input_node_it->second, 0, dst_node, dst_index);
  return Status::OK();
}

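// Adds all incoming edges for the node corresponding to `inst`, accounting for
// the special operand layouts of tf_executor fetch, NextIteration.Sink,
// NextIteration.Source, and island operations.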
Status Exporter::AddEdge(Operation* inst) {
  // For tf_executor.fetch, add only its data edges. Control edges are captured
  // later.
  if (auto fetch = llvm::dyn_cast<mlir::tf_executor::FetchOp>(inst)) {
    for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) {
      Value operand = operand_and_idx.value();
      if (operand.getType().isa<mlir::tf_executor::ControlType>()) break;

      auto* dst_node = returns_[fetch][operand_and_idx.index()];
      TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(operand, dst_node, 0));
    }

    return Status::OK();
  }

  // For tf_executor.NextIteration.Sink, skip its token operand and add data
  // and control edges with their index offset by 1.
  if (auto next_iter_sink =
          llvm::dyn_cast<mlir::tf_executor::NextIterationSinkOp>(inst)) {
    auto* dst_node = nodes_[inst];
    TF_RETURN_IF_ERROR(
        AddEdgeBetweenNodes(next_iter_sink.input(), dst_node, 0));
    for (auto control_and_idx : llvm::enumerate(next_iter_sink.controlInputs()))
      TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(control_and_idx.value(), dst_node,
                                             control_and_idx.index() + 1));

    return Status::OK();
  }

  // For tf_executor.NextIteration.Source, the op can be skipped as it is
  // assumed to have no operands.
  if (llvm::isa<mlir::tf_executor::NextIterationSourceOp>(inst)) {
    assert(inst->getNumOperands() == 0);
    return Status::OK();
  }

  Operation* op = GetIslandInnerOpOrSelf(inst);
  auto* dst_node = nodes_[op];
  int operand_offset = 0;
  // For tf_executor.island, add data edges from its wrapped op before control
  // edges.
  if (auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(inst)) {
    for (auto operand_and_idx : llvm::enumerate(op->getOperands()))
      TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(operand_and_idx.value(), dst_node,
                                             operand_and_idx.index()));

    operand_offset = op->getNumOperands();
  }

  // For all other ops (including tf_executor.island), add remaining edges.
  for (auto operand_and_idx : llvm::enumerate(inst->getOperands()))
    TF_RETURN_IF_ERROR(
        AddEdgeBetweenNodes(operand_and_idx.value(), dst_node,
                            operand_and_idx.index() + operand_offset));

  return Status::OK();
}

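// Converts `inst` to a NodeDef under a legalized unique name and adds the
// resulting node to the graph.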
Status Exporter::AddInstructionNode(Operation* inst) {
  std::unique_ptr<NodeDef> node_def;
  auto name = op_to_name_.GetUniqueName(inst);
  // Convert registered TF ops to NodeDef. Only registered ops are handled to
  // ensure that PopulateDerivedAttrs adds the correct attributes.
  TF_ASSIGN_OR_RETURN(node_def,
                      ConvertTFDialectOpToNodeDef(
                          inst, name, /*ignore_unregistered_attrs=*/false));

  Status status;
  Node* node = graph_->AddNode(*node_def, &status);
  TF_RETURN_IF_ERROR(status);
  DCHECK(node != nullptr);
  nodes_[inst] = node;
  return Status::OK();
}

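// Returns true if the block argument belongs to the entry function `main`.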
bool IsEntryFunctionArg(BlockArgument arg) {
  return arg.getParentRegion()->getParentOfType<FuncOp>().getName() == "main";
}

// Creates an argument node from a block argument. If a name is supplied, that
// name will be used instead of generating a unique name.
Status Exporter::AddArgumentNode(BlockArgument arg, unsigned index,
                                 llvm::StringRef name) {
  TF_ASSIGN_OR_RETURN(auto node_def, GetArgumentNode(arg, index, name));
  Status status;
  Node* node = graph_->AddNode(*node_def, &status);
  TF_RETURN_IF_ERROR(status);
  args_[arg] = node;
  return Status::OK();
}

// Creates return nodes per operand of a FetchOp. If names is supplied, those
// names will be used per node in order instead of generating a unique name.
Status Exporter::AddFetchNode(FuncOp function, mlir::tf_executor::FetchOp fetch,
                              llvm::ArrayRef<llvm::StringRef> names) {
  Status status;
  auto& return_nodes = returns_[fetch];
  for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) {
    if (operand_and_idx.value().getType().isa<mlir::tf_executor::ControlType>())
      break;

    TF_ASSIGN_OR_RETURN(
        auto node_def,
        GetReturnNode(function, operand_and_idx.value(),
                      operand_and_idx.index(),
                      names.empty() ? "" : names[operand_and_idx.index()]));
    Node* node = graph_->AddNode(*node_def, &status);
    TF_RETURN_IF_ERROR(status);
    return_nodes.push_back(node);
  }
  return Status::OK();
}

// Collects control ret Nodes based on tf_executor.graph's associated
// tf_executor.fetch control inputs.
Status Exporter::GetControlRetNodes(
    mlir::tf_executor::FetchOp fetch,
    absl::flat_hash_set<Node*>* control_ret_nodes) {
  for (Value fetch_operand : fetch.getOperands()) {
    if (fetch_operand.getType().isa<mlir::tf_executor::ControlType>()) {
      Operation* defining_op =
          GetIslandInnerOpOrSelf(fetch_operand.getDefiningOp());
      auto node_it = nodes_.find(defining_op);
      TF_RET_CHECK(node_it != nodes_.end());
      control_ret_nodes->insert(node_it->second);
    }
  }
  return Status::OK();
}

StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
    const GraphExportConfig& configs, const Dialect* tf_dialect,
    const SymbolTable& symbol_table, FuncOp function, FunctionDefLibrary* flib,
    llvm::SmallDenseSet<FuncOp>& visited_functions,
    absl::flat_hash_set<Node*>* control_ret_nodes) {
  mlir::Block& block = function.front();

  // Extract input & output names if set.
  llvm::SmallVector<llvm::StringRef, 2> input_names;
  llvm::SmallVector<llvm::StringRef, 2> output_names;
  auto dict_attr =
      function->getAttrOfType<mlir::DictionaryAttr>(kEntryFuncAttr);
  if (dict_attr) {
    TF_RET_CHECK(dict_attr.get("inputs").isa<mlir::StringAttr>())
        << "inputs missing in entry function attribute";
    TF_RET_CHECK(dict_attr.get("outputs").isa<mlir::StringAttr>())
        << "outputs missing in entry function attribute";
    dict_attr.get("inputs").cast<mlir::StringAttr>().getValue().split(
        input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
    dict_attr.get("outputs").cast<mlir::StringAttr>().getValue().split(
        output_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
  }

  auto graph = absl::make_unique<Graph>(OpRegistry::Global());

  // Extract version info.
  VersionDef versions;
  auto module = function->getParentOfType<mlir::ModuleOp>();
  if (mlir::succeeded(ExtractTfVersions(module, &versions))) {
    graph->set_versions(versions);
  }

  Exporter exporter(graph.get(), tf_dialect);

  auto graph_op = llvm::cast<mlir::tf_executor::GraphOp>(block.front());

  // Set input and output names and increment the use counter for them to help
  // generate unique names.
  if (!output_names.empty()) {
    const int num_data_results = graph_op.getNumResults();
    const int64_t output_names_size = output_names.size();
    TF_RET_CHECK(output_names_size == num_data_results)
        << "output names (" << output_names.size()
        << ") != terminator operands (" << num_data_results << ")";
    llvm::DenseMap<Operation*, llvm::StringRef> output_op_to_name;
    llvm::StringMap<Operation*> name_to_op;
    for (const auto& it : llvm::enumerate(graph_op.GetFetch().getOperands())) {
      // Skip control rets.
      const int64_t index = it.index();
      if (index >= num_data_results) break;
      // TODO(jpienaar): If there is a result index specified, ensure there is
      // only one and that it matches the result index of the op.
      std::string name(output_names[index]);
      auto tensor_id = ParseTensorName(name);
      std::string tensor_id_node(tensor_id.node());
      assert(!tensor_id_node.empty() && "expected non-empty name");
      mlir::LegalizeNodeName(tensor_id_node);

      // Ensure the name does not get reused.
      (void)exporter.op_to_name_.GetUniqueName(tensor_id_node);
    }
  }

  if (!input_names.empty()) {
    TF_RET_CHECK(input_names.size() == block.getNumArguments());
    for (const auto& it : llvm::enumerate(function.getArguments())) {
      // TODO(lyandy): Update when changing feed/fetch import.
      std::string name(input_names[it.index()]);
      assert(!name.empty() && "expected non-empty name");
      mlir::LegalizeNodeName(name);
      auto tensor_id = ParseTensorName(name);
      TF_RET_CHECK(tensor_id.index() == 0)
          << "input port designation not supported";
      // Only assign the argument's user the input name if the main graph did
      // not have its _Arg nodes lifted into the function's arguments.
      // Ensure the name does not get reused.
      (void)exporter.op_to_name_.GetUniqueName(name);
    }
  }

  // Adds nodes for basic block (function) arguments.
  for (auto it : llvm::enumerate(block.getArguments())) {
    int index = it.index();
    auto arg = it.value();
    mlir::Type type = arg.getType();
    if (!type.isa<mlir::TensorType>()) {
      return errors::InvalidArgument(
          "FuncOps arguments must have tensor types. Found ",
          mlir::debugString(type), " in function ", function.getName().str());
    }

    TF_RETURN_IF_ERROR(exporter.AddArgumentNode(
        arg, index, !input_names.empty() ? input_names[index] : ""));
  }

  auto convert_called_function = [&](llvm::StringRef name) {
    auto func = symbol_table.lookup<FuncOp>(name);
    if (func != nullptr) {
      TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, symbol_table,
                                            func, flib, visited_functions));
      // TODO(prakalps): Optimize to only add the requested function to the
      // graph library rather than all the functions exported so far.
      TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
    }
    return Status::OK();
  };

  // Adds nodes for operations.
  for (Operation& inst : graph_op.GetBody()) {
    for (auto type : inst.getResultTypes())
      if (!type.isa<mlir::TensorType, mlir::tf_executor::ControlType,
                    mlir::tf_executor::TokenType>())
        return errors::InvalidArgument(
            "Values must be of tensor type, TensorFlow control type, or "
            "TensorFlow token type. Found ",
            mlir::debugString(type));

    if (llvm::isa<mlir::tf_executor::NextIterationSourceOp>(inst)) {
      // Skip tf_executor.NextIteration.Source as the associated
      // tf_executor.NextIteration.Sink will be used instead.
      continue;
    } else if (auto fetch = llvm::dyn_cast<mlir::tf_executor::FetchOp>(inst)) {
      TF_RETURN_IF_ERROR(exporter.AddFetchNode(function, fetch, output_names));
    } else if (auto island =
                   llvm::dyn_cast<mlir::tf_executor::IslandOp>(inst)) {
      Operation& inner_op = island.GetBody().front();
      auto op_name = GetTensorFlowOpName(inner_op.getName().getStringRef());
      if (op_name.ok()) {
        // If it is a TF Control dialect specific op, look up the custom
        // operation in the module, convert it first, and then add it to the
        // function definition library.
        // TODO(prakalps): If two functions have a cyclic dependence, this will
        // introduce an infinite loop.
        TF_RETURN_IF_ERROR(convert_called_function(op_name.ValueOrDie().str()));
      }

      if (IsLegacyCallInstruction(&inner_op)) {
        TF_RETURN_IF_ERROR(convert_called_function(
            inner_op.getAttrOfType<mlir::SymbolRefAttr>("f")
                .getLeafReference()));
      }

      TF_RETURN_IF_ERROR(exporter.AddInstructionNode(&inner_op));
    } else {
      TF_RETURN_IF_ERROR(exporter.AddInstructionNode(&inst));
    }
  }
  // Adds edges between the argument, operation and return nodes.
  for (Operation& inst : graph_op.GetBody()) {
    TF_RETURN_IF_ERROR(exporter.AddEdge(&inst));
  }
  // Fixes the edges between the inserted nodes and the special "_SOURCE" and
  // "_SINK" nodes.
  FixupSourceAndSinkEdges(graph.get());

  TF_RETURN_IF_ERROR(
      exporter.GetControlRetNodes(graph_op.GetFetch(), control_ret_nodes));

  return graph;
}

Status Exporter::ConvertLibFunction(
    const GraphExportConfig& configs, const Dialect* tf_dialect,
    const SymbolTable& symbol_table, FuncOp function, FunctionDefLibrary* flib,
    llvm::SmallDenseSet<FuncOp>& visited_functions) {
  // Return early if the function has already been exported.
  bool is_new_function = visited_functions.insert(function).second;
  if (!is_new_function) return Status::OK();

  auto function_name = function.getName().str();

  // TODO(fengliuai): use a small flib_def to reduce overhead
  absl::flat_hash_set<Node*> control_ret_nodes;
  TF_ASSIGN_OR_RETURN(
      auto sub_graph,
      Exporter::Convert(configs, tf_dialect, symbol_table, function, flib,
                        visited_functions, &control_ret_nodes));
  const auto control_ret = [&](const Node* n) -> absl::optional<string> {
    return control_ret_nodes.contains(n)
               ? absl::make_optional<string>(n->name())
               : absl::nullopt;
  };
  FunctionDef func_def;
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*sub_graph, function_name, control_ret, &func_def));

  // The node defs in the FunctionDef might contain debug info added by
  // GraphToFunctionDef. Remove it when debug info export is disabled to avoid
  // failing the roundtrip test.
  if (!configs.export_debug_info) {
    for (auto& node_def : *func_def.mutable_node_def()) {
      node_def.clear_experimental_debug_info();
    }
  }

  // Checks for the gradient attribute. If present, converts the gradient
  // function and populates the GradientDef.
  auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
  if (auto attr =
          function->getAttrOfType<mlir::FlatSymbolRefAttr>(grad_string)) {
    auto grad_func = symbol_table.lookup<FuncOp>(attr.getValue());
    TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, symbol_table,
                                          grad_func, flib, visited_functions));
    GradientDef grad;
    grad.set_function_name(function_name);
    grad.set_gradient_func(grad_func.getName().str());
    *flib->add_gradient() = grad;
  }

  auto stateful_string = mlir::TF::TensorFlowDialect::GetStatefulAttrName();
  if (auto attr = function->getAttrOfType<mlir::UnitAttr>(stateful_string)) {
    func_def.mutable_signature()->set_is_stateful(true);
  }

  // Ignore the gradient and is_stateful attributes on the function as they
  // have been handled above. Ignore the entry func attribute as it is an MLIR
  // metadata attribute and is not required in the function definition.
  absl::flat_hash_set<absl::string_view> attrs_to_ignore = {
      grad_string.data(), stateful_string.data(), kEntryFuncAttr};
  llvm::SmallVector<mlir::NamedAttribute, 8> funcAttrs(
      function->getDialectAttrs());
  TF_RETURN_IF_ERROR(ConvertAttributes(funcAttrs, attrs_to_ignore,
                                       /*remove_ref_type=*/false,
                                       func_def.mutable_attr()));

  for (int i = 0, e = function.getNumArguments(); i < e; ++i) {
    if (auto resource_arg_unique_id_attr =
            function.getArgAttrOfType<mlir::IntegerAttr>(
                i, kResourceArgUniqueIdAttr)) {
      (*func_def.mutable_resource_arg_unique_id())[i] =
          resource_arg_unique_id_attr.getInt();
    }

    llvm::ArrayRef<mlir::NamedAttribute> func_arg_i_attrs =
        function.getArgAttrs(i);
    if (func_arg_i_attrs.empty()) continue;
    absl::flat_hash_set<absl::string_view> attrs_to_ignore = {
        kDeviceAttr, kResourceArgUniqueIdAttr};
    FunctionDef::ArgAttrs func_def_arg_i_attrs;
    TF_RETURN_IF_ERROR(ConvertAttributes(func_arg_i_attrs, attrs_to_ignore,
                                         /*remove_ref_type=*/false,
                                         func_def_arg_i_attrs.mutable_attr()));
    if (func_def_arg_i_attrs.attr().empty()) continue;
    (*func_def.mutable_arg_attr())[i] = std::move(func_def_arg_i_attrs);
  }

  (*flib->add_function()) = std::move(func_def);
  return Status::OK();
}

Status Exporter::Convert(mlir::ModuleOp module,
                         const GraphExportConfig& configs,
                         std::unique_ptr<Graph>* graph,
                         FunctionLibraryDefinition* flib_def,
                         absl::flat_hash_set<Node*>* control_ret_nodes) {
  mlir::Identifier entry_func_id =
      mlir::Identifier::get("main", module.getContext());
  absl::optional<FuncOp> entry_func;
  FunctionDefLibrary flib;
  llvm::SmallDenseSet<FuncOp> visited_functions;
  auto tf_dialect = module.getContext()->getLoadedDialect("tf");
  // Construct a SymbolTable to enable cheap function lookups. The cost
  // of constructing the table is offset by the number of queries.
  SymbolTable symbol_table(module);
  for (auto function : module.getOps<FuncOp>()) {
    if (function.isExternal())
      return errors::FailedPrecondition("External functions not supported");

    if (function.getName() == entry_func_id &&
        !configs.export_entry_func_to_flib) {
      entry_func.emplace(function);
    } else {
      TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, symbol_table,
                                            function, &flib,
                                            visited_functions));
    }
  }

  if (!configs.export_entry_func_to_flib) {
    if (!entry_func.has_value())
      return errors::FailedPrecondition(
          "entry function `main` must be present");

    // Updates the graph and the function library definition.
    TF_ASSIGN_OR_RETURN(
        *graph,
        Exporter::Convert(configs, tf_dialect, symbol_table, entry_func.value(),
                          &flib, visited_functions, control_ret_nodes));
    // Add FunctionDefs and GradientDefs of MLIR functions to the graph's
    // function library. If duplicate FunctionDefs already exist (which can
    // happen if the exporter had already added some FunctionDefs to the
    // library to support legacy calls), they are ignored.
    TF_RETURN_IF_ERROR(graph->get()->AddFunctionLibrary(flib));
  }

  for (auto& func_def : flib.function()) {
    TF_RETURN_IF_ERROR(flib_def->AddFunctionDef(func_def));
  }
  for (auto& grad_def : flib.gradient()) {
    TF_RETURN_IF_ERROR(flib_def->AddGradientDef(grad_def));
  }
  return Status::OK();
}
}  // namespace

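// Converts an MLIR module to a Graph, populating `flib_def` with the exported
// function library and `control_ret_nodes` with the entry function's control
// rets, after verifying the module is suitable for export.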
Status ConvertMlirToGraph(mlir::ModuleOp module,
                          const GraphExportConfig& configs,
                          std::unique_ptr<Graph>* graph,
                          FunctionLibraryDefinition* flib_def,
                          absl::flat_hash_set<Node*>* control_ret_nodes) {
  mlir::StatusScopedDiagnosticHandler sh(module.getContext());
  if (failed(VerifyExportSuitable(module))) return sh.ConsumeStatus();
  return sh.Combine(
      Exporter::Convert(module, configs, graph, flib_def, control_ret_nodes));
}

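// Overload that converts an MLIR module to a Graph and discards the collected
// control ret nodes.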
Status ConvertMlirToGraph(mlir::ModuleOp module,
                          const GraphExportConfig& configs,
                          std::unique_ptr<Graph>* graph,
                          FunctionLibraryDefinition* flib_def) {
  absl::flat_hash_set<Node*> control_ret_nodes;
  return ConvertMlirToGraph(module, configs, graph, flib_def,
                            &control_ret_nodes);
}

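// Converts an MLIR module to a GraphDef, optionally dropping the function
// library, shape attributes, and debug info according to `configs`.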
StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
    mlir::ModuleOp module, const GraphExportConfig& configs) {
  FunctionLibraryDefinition flib_def(OpRegistry::Global(),
                                     FunctionDefLibrary());
  std::unique_ptr<Graph> graph;
  TF_RETURN_IF_ERROR(ConvertMlirToGraph(module, configs, &graph, &flib_def));

  // If the entry function is exported to flib, then no graph is constructed.
  // Construct one in that case.
  if (configs.export_entry_func_to_flib) {
    graph = std::make_unique<Graph>(OpRegistry::Global());
    // TODO(hinsu): Avoid Proto -> Memory -> Proto conversion here.
    FunctionDefLibrary flib = flib_def.ToProto();
    TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(flib));
  }

  auto graphdef = absl::make_unique<GraphDef>();
  graph->ToGraphDef(graphdef.get());
  if (!configs.export_library) graphdef->clear_library();
  if (!configs.export_shapes) {
    for (auto& node_def : *graphdef->mutable_node()) {
      node_def.mutable_attr()->erase("shape");
    }
  }
  if (!configs.export_debug_info) {
    for (auto& node_def : *graphdef->mutable_node()) {
      node_def.clear_experimental_debug_info();
    }
  }
  return graphdef;
}

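// Converts a single MLIR function (plus any functions it references, such as
// an associated gradient) to a FunctionDef and returns the FunctionDef that
// matches `func` by name.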
stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef(
    FuncOp func, const GraphExportConfig& configs, FunctionDef* function_def) {
  Dialect* tf_dialect = func.getContext()->getLoadedDialect("tf");
  FunctionDefLibrary flib;
  llvm::SmallDenseSet<FuncOp> visited_functions;
  // Construct a SymbolTable to enable cheap function lookups. The cost
  // of constructing the table is offset by the number of queries. Even
  // though this converts only one function in theory, that function may
  // have an associated gradient, which would result in a lookup. This
  // could be made lazy if we find this to be too broad.
  SymbolTable symbol_table(func->getParentOfType<mlir::ModuleOp>());
  TF_RETURN_IF_ERROR(Exporter::ConvertLibFunction(
      configs, tf_dialect, symbol_table, func, &flib, visited_functions));
  for (auto& func_def : flib.function()) {
    if (func_def.signature().name() == func.getName()) {
      *function_def = func_def;
      return Status::OK();
    }
  }
  return errors::InvalidArgument(
      "Function couldn't be found in the FunctionDefLibrary after converting "
      "from MLIR");
}

}  // namespace tensorflow