/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tfrt/utils/graph_partition.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/common_runtime/partitioning_utils.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/grappler/utils.h"

namespace tensorflow {
namespace tfrt_stub {

namespace {

// An auxiliary struct to record input/output information.
struct NodeInfo {
  Node* node = nullptr;
  DataType data_type;
  int index = -1;
  Node* node_copy = nullptr;
};

// An auxiliary struct for constructing a StatefulPartitionedCallOp enclosing an
// IdentityN node.
struct CallNodeInputInfo {
  int index = -1;
  DataType data_type;
  Node* input_node = nullptr;
  int input_node_index = -1;

  Node* arg_node = nullptr;
  Node* ret_node = nullptr;
};

struct OutputNodeInfo {
  absl::flat_hash_map<std::string, NodeInfo> output_nodes;
  absl::optional<std::pair<std::string, NodeInfo>> auxiliary_output_node;
};

// Prepares the `subgraph` for the conversion to a function by adding
// _Arg/_Retval nodes for input/output nodes respectively, and records
// input/output info for the following processing.
// TODO(b/217581711): Consider using another GraphToFunctionDef() helper which
// does not require _Arg and _Retval nodes.
Status PrepareSubgraphForFunctionConversion(
    const std::vector<std::string>& inputs,
    const std::vector<std::string>& outputs, const Device* host_device,
    const std::string& func_name,
    absl::flat_hash_map<std::string, NodeInfo>& input_nodes,
    absl::flat_hash_map<std::string, NodeInfo>& output_nodes,
    absl::optional<std::pair<std::string, NodeInfo>>& auxiliary_output_node,
    Graph* subgraph, Graph* graph) {
  std::unordered_map<std::string, Node*> name_to_node_map =
      subgraph->BuildNodeNameIndex();

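  // Running counters used to assign the "index" attribute of the _Arg and
  // _Retval nodes created below.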
  int input_index = 0, output_index = 0;

  // For each input node in this subgraph, replace it with an _Arg node.
  for (const auto& input : inputs) {
    int position = -1;
    std::string node_name = grappler::ParseNodeName(input, &position);
    if (position != 0) {
      return errors::Unimplemented(
          "Support for input node with multiple output tensors is not "
          "implemented.");
    }
    if (name_to_node_map.count(node_name) == 0) continue;

    Node* node = name_to_node_map.at(node_name);
    NodeInfo node_info;
    node_info.node = node;
    node_info.data_type = node->output_type(position);
    node_info.index = input_index++;
    // Copy the input node which will be removed from the subgraph below.
    // The copied node will be used in the top-level graph.
    node_info.node_copy = graph->CopyNode(node);

    input_nodes.emplace(node->name(), node_info);

    // Create an _Arg node to replace the input node.
    TF_ASSIGN_OR_RETURN(
        Node * arg_node,
        NodeBuilder(absl::StrCat("arg_", node_info.index, "/", node->name()),
                    "_Arg")
            .Attr("index", node_info.index)
            .Attr("T", node_info.data_type)
            .Finalize(subgraph));

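    // The input node is expected to be a feed with no incoming edges; rewire
    // all of its outgoing edges to the new _Arg node before removing it.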
    CHECK_EQ(node->num_inputs(), 0);
    std::vector<const Edge*> out_edges(node->out_edges().begin(),
                                       node->out_edges().end());
    for (const Edge* edge : out_edges) {
      if (edge->IsControlEdge()) {
        subgraph->AddControlEdge(arg_node, edge->dst());
      } else {
        TF_RETURN_IF_ERROR(
            subgraph->UpdateEdge(arg_node, 0, edge->dst(), edge->dst_input()));
      }
    }
    subgraph->RemoveNode(node);
  }

  // For each output node in this subgraph, connect it to a _Retval node.
  for (const auto& output : outputs) {
    int position = -1;
    std::string node_name = grappler::ParseNodeName(output, &position);
    if (position != 0) {
      return errors::Unimplemented(
          "Support for output node with multiple output tensors is not "
          "implemented.");
    }
    if (name_to_node_map.count(node_name) == 0) continue;

    Node* node = name_to_node_map.at(node_name);
    NodeInfo node_info;
    node_info.node = node;
    node_info.data_type = node->output_type(position);
    node_info.index = output_index++;

    output_nodes.emplace(node->name(), node_info);

    // Create a _Retval node, and attach it to the original output node.
    TF_ASSIGN_OR_RETURN(
        Node * ret_node,
        NodeBuilder(absl::StrCat("ret_", node_info.index, "/", node->name()),
                    "_Retval")
            .Attr("index", node_info.index)
            .Attr("T", node_info.data_type)
            .Input(NodeBuilder::NodeOut(node->name(), position,
                                        node_info.data_type))
            .Finalize(subgraph));
    // Rename the output node, as there will be a node in the top level with
    // the same name.
    node->set_name(node->name() + "/partition_renamed");

    subgraph->AddEdge(node, 0, ret_node, 0);
  }

  // If there is no output for this partition, create an auxiliary output, so
  // that we can generate a data dependency from the PartitionedCallOp (the
  // one we are going to create to wrap this partition) to a downstream
  // stateful node. This helps to preserve the stateless PartitionedCallOp in
  // the subsequent MLIR lowering passes; otherwise, it will be pruned if there
  // is only a control dependency between PartitionedCallOp and another op
  // node, because PartitionedCallOp is stateless and the control dependency
  // will get lost during MLIR lowering with current side effect analysis
  // (b/232026253).
  if (output_nodes.empty()) {
    // Create a const node.
    const DataType data_type = DT_INT32;
    TensorShape const_shape;
    Tensor const_tensor(data_type, const_shape);
    const_tensor.flat<int>()(0) = 0;
    TF_ASSIGN_OR_RETURN(
        Node * const_node,
        NodeBuilder(absl::StrCat("const/unused/", func_name), "Const")
            .AssignedDevice(host_device->name())
            .Attr("dtype", data_type)
            .Attr("value", const_tensor)
            .Finalize(subgraph));

    NodeInfo node_info;
    node_info.node = const_node;
    node_info.data_type = data_type;
    node_info.index = output_index++;
    auxiliary_output_node.emplace(const_node->name(), node_info);

    // Create a _Retval node, and attach it to the const node created above.
    TF_ASSIGN_OR_RETURN(
        Node * ret_node,
        NodeBuilder(
            absl::StrCat("ret_", node_info.index, "/", const_node->name()),
            "_Retval")
            .Attr("index", node_info.index)
            .Attr("T", data_type)
            .Input(NodeBuilder::NodeOut(const_node->name(), 0, data_type))
            .Finalize(subgraph));

    subgraph->AddEdge(const_node, 0, ret_node, 0);
  }
  return OkStatus();
}

// Converts the subgraph to a function, and builds a PartitionedCallOp
// to invoke the function.
StatusOr<Node*> BuildPartitionedCallOp(
    const std::string& func_name, const Device* host_device,
    const std::string& device,
    const absl::flat_hash_map<std::string, NodeInfo>& input_nodes,
    const absl::flat_hash_map<std::string, NodeInfo>& output_nodes,
    const absl::optional<std::pair<std::string, NodeInfo>>&
        auxiliary_output_node,
    const std::vector<std::string>& control_outputs, Graph* subgraph,
    Graph* graph) {
  // Build the call node.
  std::string call_node_name = absl::StrCat("partitioned_call/", func_name);
  NodeBuilder call_builder(call_node_name, "PartitionedCall");
  call_builder.AssignedDevice(host_device->name());
  call_builder.Attr(tensorflow::kNoInlineAttr, true);

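  // Order the "Tin" dtypes by the _Arg index assigned in
  // PrepareSubgraphForFunctionConversion(), since flat_hash_map iteration
  // order is unspecified.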
  std::vector<DataType> input_dtypes(input_nodes.size());
  for (const auto& input_node : input_nodes) {
    input_dtypes[input_node.second.index] = input_node.second.data_type;
  }
  call_builder.Attr("Tin", input_dtypes);

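  // Exactly one of the two holds: either the partition has real outputs, or
  // an auxiliary output was created for it.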
  CHECK(auxiliary_output_node ? output_nodes.empty() : !output_nodes.empty());
  std::vector<DataType> output_dtypes(
      auxiliary_output_node ? 1 : output_nodes.size());
  if (auxiliary_output_node) {
    CHECK_EQ(auxiliary_output_node->second.index, 0);
    output_dtypes[auxiliary_output_node->second.index] =
        auxiliary_output_node->second.data_type;
  } else {
    for (const auto& output_node : output_nodes) {
      output_dtypes[output_node.second.index] = output_node.second.data_type;
    }
  }
  call_builder.Attr("Tout", output_dtypes);

  std::vector<NodeBuilder::NodeOut> call_node_inputs(input_nodes.size());
  for (const auto& input_node : input_nodes) {
    call_node_inputs[input_node.second.index] =
        NodeBuilder::NodeOut(input_node.second.node_copy, 0);
  }
  call_builder.Input(call_node_inputs);

  NameAttrList f;
  f.set_name(func_name);
  call_builder.Attr("f", f);
  TF_ASSIGN_OR_RETURN(Node * call_node, call_builder.Finalize(graph));

  // Convert the subgraph to a function.
  absl::flat_hash_set<std::string> control_ret_names(control_outputs.begin(),
                                                     control_outputs.end());
  // Graph partitioning adds send ops as new end nodes; they must run for the
  // graph execution to complete, so mark them as control returns.
  for (const Node* node : subgraph->op_nodes()) {
    if (node->IsSend()) {
      control_ret_names.insert(node->name());
    }
  }
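  // Predicate passed to GraphToFunctionDef() that decides which nodes become
  // control returns of the generated function.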
  auto control_ret_node_names =
      [&control_ret_names](const Node* node) -> absl::optional<std::string> {
    if (control_ret_names.contains(node->name())) {
      return node->name();
    }
    return absl::nullopt;
  };

  FunctionDef new_fdef;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(*subgraph, func_name,
                                        control_ret_node_names, &new_fdef));
  // Set the `_noinline` attribute for the function to make sure it does not
  // get inlined, and its corresponding function in TF MLIR does not get inlined
  // in lowering passes as well.
  (*new_fdef.mutable_attr())[tensorflow::kNoInlineAttr].set_b(true);
  (*new_fdef.mutable_attr())["device"].set_s(device);
  TF_RETURN_IF_ERROR(graph->mutable_flib_def()->AddFunctionDef(new_fdef));

  return call_node;
}

// Builds a StatefulPartitionedCallOp, and connects all PartitionedCallOps to
// it. This StatefulPartitionedCallOp behaves as a stateful IdentityN.
StatusOr<Node*> BuildStatefulPartitionedCallOp(
    absl::flat_hash_map<std::string, CallNodeInputInfo>& call_node_input_info,
    const absl::flat_hash_map<std::string, Node*>& all_partitioned_call_ops,
    const std::string& stateful_call_func_name, const Device* host_device,
    Graph* graph) {
  std::string call_node_name =
      absl::StrCat("stateful_partitioned_call/", stateful_call_func_name);
  NodeBuilder call_builder(call_node_name, "StatefulPartitionedCall");
  call_builder.Attr(tensorflow::kNoInlineAttr, true);
  call_builder.AssignedDevice(host_device->name());

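  // Every output collected from the PartitionedCallOps becomes one input of
  // this stateful call, which simply forwards it.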
  int num_output_nodes = call_node_input_info.size();
  std::vector<DataType> input_dtypes(num_output_nodes);
  for (const auto& node_info : call_node_input_info) {
    CHECK(node_info.second.index < num_output_nodes);
    input_dtypes[node_info.second.index] = node_info.second.data_type;
  }
  call_builder.Attr("Tin", input_dtypes);
  // Outputs are the same as inputs.
  call_builder.Attr("Tout", input_dtypes);

  std::vector<NodeBuilder::NodeOut> call_node_inputs(num_output_nodes);
  for (const auto& node_info : call_node_input_info) {
    call_node_inputs[node_info.second.index] = NodeBuilder::NodeOut(
        node_info.second.input_node, node_info.second.input_node_index);
  }
  call_builder.Input(call_node_inputs);

  NameAttrList f;
  f.set_name(stateful_call_func_name);
  call_builder.Attr("f", f);
  TF_ASSIGN_OR_RETURN(Node * stateful_call_node, call_builder.Finalize(graph));

  // Construct a graph that only contains an IdentityN node, and convert the
  // graph to a function.
  auto id_graph = std::make_unique<Graph>(graph->flib_def().default_registry());

  std::vector<NodeBuilder::NodeOut> output_tensors(num_output_nodes);

  // Create an _Arg node for each input.
  for (auto& node_info : call_node_input_info) {
    TF_ASSIGN_OR_RETURN(node_info.second.arg_node,
                        NodeBuilder(absl::StrCat("arg_", node_info.second.index,
                                                 "/", stateful_call_func_name),
                                    "_Arg")
                            .Attr("index", node_info.second.index)
                            .Attr("T", node_info.second.data_type)
                            .Finalize(id_graph.get()));

    output_tensors[node_info.second.index] =
        NodeBuilder::NodeOut(node_info.second.arg_node, 0);
  }

  // Create the IdentityN node.
  TF_ASSIGN_OR_RETURN(
      Node * identity_node,
      NodeBuilder(absl::StrCat("identityN", "/", stateful_call_func_name),
                  "IdentityN")
          .AssignedDevice(host_device->name())
          .Input(output_tensors)
          .Finalize(id_graph.get()));

  // Create a _Retval node for each output.
  for (auto& node_info : call_node_input_info) {
    TF_ASSIGN_OR_RETURN(
        node_info.second.ret_node,
        NodeBuilder(absl::StrCat("ret_", node_info.second.index, "/",
                                 stateful_call_func_name),
                    "_Retval")
            .Attr("index", node_info.second.index)
            .Attr("T", node_info.second.data_type)
            .Input(NodeBuilder::NodeOut(identity_node, node_info.second.index))
            .Finalize(id_graph.get()));

    id_graph->AddEdge(identity_node, node_info.second.index,
                      node_info.second.ret_node, 0);
  }

  // Convert the id_graph to a function.
  FunctionDef id_fdef;
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*id_graph, stateful_call_func_name, &id_fdef));
  (*id_fdef.mutable_attr())[tensorflow::kNoInlineAttr].set_b(true);
  TF_RETURN_IF_ERROR(graph->mutable_flib_def()->AddFunctionDef(id_fdef));

  return stateful_call_node;
}

// Returns true if nodes in the `graph` are assigned to multiple devices.
bool HasMultipleDevices(const Graph* graph) {
  bool has_multiple_devices = false;
  absl::optional<std::string> location;
  for (const Node* node : graph->op_nodes()) {
    if (location) {
      if (*location != node->assigned_device_name()) {
        has_multiple_devices = true;
        break;
      }
    } else {
      location = node->assigned_device_name();
    }
  }
  return has_multiple_devices;
}

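// Turns a device name into a string that is safe to embed in node and
// function names.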
std::string GetNameFromDevice(const std::string& device) {
  std::string ret = device;
  for (int i = 0; i < ret.size(); ++i) {
    // Replace ':', as it is not allowed in node names.
    if (ret[i] == ':') ret[i] = '_';
  }
  return ret;
}

}  // namespace

// This function performs the following steps:
// 1. Partition the graph and insert send/recv ops on the edges across devices.
// 2. For each partition, convert the subgraph to a function and invoke the
//    function by a PartitionedCallOp, so that these functions can be executed
//    asynchronously.
// 3. Connect all PartitionedCallOps to a StatefulPartitionedCallOp to make
//    sure PartitionedCallOps are not pruned in the subsequent MLIR lowering
//    passes.
// 4. Create output nodes and control output nodes to match the original
//    graph's nodes.
StatusOr<std::unique_ptr<Graph>> InsertTransferOps(
    const std::string& graph_func_name, const DeviceSet& device_set,
    const Device* host_device, const std::vector<std::string>& inputs,
    const std::vector<std::string>& outputs,
    const std::vector<std::string>& control_outputs,
    std::unique_ptr<Graph> graph) {
  // Skip transfer op insertion if the graph nodes are not assigned to multiple
  // devices.
  if (!HasMultipleDevices(graph.get())) {
    return graph;
  }

  // Step 1: Partition the graph and insert send/recv ops on the edges across
  // devices.
  auto new_graph = std::make_unique<Graph>(graph->flib_def());
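  // Snapshot the function library before `graph` is consumed below, so it can
  // be added to each partition's subgraph.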
  FunctionDefLibrary flib = graph->flib_def().ToProto();

  std::unordered_map<string, std::unique_ptr<Graph>> partitions;
  TF_RETURN_IF_ERROR(
      PartitionFunctionGraph(device_set, std::move(graph), &partitions));

  // Step 2: For each partition, convert the subgraph to a function and invoke
  // the function by PartitionedCallOp from the top-level graph.

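  // Both maps are keyed by the device name of the partition; the std::map
  // keeps iteration over the output infos in a deterministic order.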
  absl::flat_hash_map<std::string, Node*> all_partitioned_call_ops;
  std::map<std::string, OutputNodeInfo> device_to_output_info_map;

  for (auto& partition : partitions) {
    const string& device = partition.first;
    VLOG(1) << "Process the partition on device: " << device;

    Graph* subgraph = partition.second.get();
    TF_RETURN_IF_ERROR(subgraph->AddFunctionLibrary(flib));

    FunctionNameGenerator name_generator(
        &new_graph->flib_def(), absl::StrCat(graph_func_name, "-partition-",
                                             GetNameFromDevice(device)));
    std::string func_name = name_generator.GetName();

    absl::flat_hash_map<std::string, NodeInfo> input_nodes;

    OutputNodeInfo& output_node_info = device_to_output_info_map[device];
    absl::flat_hash_map<std::string, NodeInfo>& output_nodes =
        output_node_info.output_nodes;
    absl::optional<std::pair<std::string, NodeInfo>>& auxiliary_output_node =
        output_node_info.auxiliary_output_node;

    // Add _Arg and _Retval nodes to the subgraph to prepare for converting it
    // to a function. Meanwhile, record input/output infos for the following
    // processing.
    TF_RETURN_IF_ERROR(PrepareSubgraphForFunctionConversion(
        inputs, outputs, host_device, func_name, input_nodes, output_nodes,
        auxiliary_output_node, subgraph, new_graph.get()));

    // Convert the subgraph to a function, and build a PartitionedCallOp to
    // invoke the function.
    TF_ASSIGN_OR_RETURN(
        Node * call_node,
        BuildPartitionedCallOp(func_name, host_device, device, input_nodes,
                               output_nodes, auxiliary_output_node,
                               control_outputs, subgraph, new_graph.get()));
    all_partitioned_call_ops[device] = call_node;
  }

  // Step 3: Create a StatefulPartitionedCallOp, and connect all
  // PartitionedCallOps to it. The StatefulPartitionedCallOp behaves as a
  // stateful IdentityN. This helps to preserve the PartitionedCallOps
  // (stateless) in the TF MLIR lowering passes; otherwise, without a stateful
  // consumer, PartitionedCallOps will be pruned, as control output info of
  // the graph gets lost during TF MLIR lowering (b/232026253).

  // Collect all outputs from all partitions, and update their indices to be
  // used for constructing StatefulPartitionedCallOp.
  int input_index = 0;
  absl::flat_hash_map<std::string, CallNodeInputInfo> call_node_input_info;
  auto get_call_node_input_info = [&](const std::string& device,
                                      const std::string& node_name,
                                      const NodeInfo& node_info) {
    CHECK(!call_node_input_info.contains(node_name));
    CallNodeInputInfo& info = call_node_input_info[node_name];
    info.index = input_index++;
    info.data_type = node_info.data_type;
    info.input_node = all_partitioned_call_ops.at(device);
    info.input_node_index = node_info.index;
  };
  for (const auto& entry : device_to_output_info_map) {
    const std::string& device = entry.first;
    const OutputNodeInfo& output_info = entry.second;
    for (const auto& node_info : output_info.output_nodes) {
      get_call_node_input_info(device, node_info.first, node_info.second);
    }
    if (output_info.auxiliary_output_node) {
      get_call_node_input_info(device, output_info.auxiliary_output_node->first,
                               output_info.auxiliary_output_node->second);
    }
  }

  FunctionNameGenerator name_generator(
      &new_graph->flib_def(),
      absl::StrCat(graph_func_name, "/output_aggregator"));
  std::string stateful_call_func_name = name_generator.GetName();
  TF_ASSIGN_OR_RETURN(
      Node * stateful_call_node,
      BuildStatefulPartitionedCallOp(
          call_node_input_info, all_partitioned_call_ops,
          stateful_call_func_name, host_device, new_graph.get()));

  // Step 4: Create output nodes and control output nodes corresponding to the
  // original graph's nodes.

  // For each original output, construct a corresponding Identity node with
  // the same name.
  for (const auto& node_info : call_node_input_info) {
    TF_RETURN_IF_ERROR(NodeBuilder(node_info.first, "Identity")
                           .Input(NodeBuilder::NodeOut(stateful_call_node,
                                                       node_info.second.index))
                           .Attr("T", node_info.second.data_type)
                           .AssignedDevice(host_device->name())
                           .Finalize(new_graph.get(), nullptr));
  }

  // For each original control output, construct a corresponding Identity node
  // with the same name.
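  // Every partition contributes at least one output (an auxiliary one is
  // created when a partition has none), so the stateful call node is
  // guaranteed to have an output to serve as the data edge below.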
  CHECK_GT(stateful_call_node->num_outputs(), 0);
  for (const auto& control_output : control_outputs) {
    TF_RETURN_IF_ERROR(NodeBuilder(control_output, "Identity")
                           .Input(NodeBuilder::NodeOut(stateful_call_node, 0))
                           .Attr("T", stateful_call_node->output_type(0))
                           .AssignedDevice(host_device->name())
                           .Finalize(new_graph.get(), nullptr));
  }

  return new_graph;
}

}  // namespace tfrt_stub
}  // namespace tensorflow