1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tfrt/utils/graph_partition.h"
16
17 #include <algorithm>
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <unordered_map>
22 #include <utility>
23
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "tensorflow/core/common_runtime/graph_constructor.h"
27 #include "tensorflow/core/common_runtime/inline_function_utils.h"
28 #include "tensorflow/core/common_runtime/partitioning_utils.h"
29 #include "tensorflow/core/framework/function.h"
30 #include "tensorflow/core/framework/graph_to_functiondef.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/graph/graph.h"
33 #include "tensorflow/core/graph/graph_partition.h"
34 #include "tensorflow/core/graph/node_builder.h"
35 #include "tensorflow/core/grappler/utils.h"
36
37 namespace tensorflow {
38 namespace tfrt_stub {
39
40 namespace {
41
// An auxiliary struct to record input/output information.
struct NodeInfo {
  Node* node = nullptr;       // The input/output node inside the subgraph.
  DataType data_type;         // Dtype of the node's tensor (its output 0).
  int index = -1;             // Position in the function's _Arg/_Retval list.
  Node* node_copy = nullptr;  // Copy of an input node placed in the top-level
                              // graph (unset for output nodes).
};
49
// An auxiliary struct for constructing a StatefulPartitionedCallOp enclosing an
// IdentityN node.
struct CallNodeInputInfo {
  int index = -1;              // Input position on the StatefulPartitionedCallOp.
  DataType data_type;          // Dtype of the tensor being forwarded.
  Node* input_node = nullptr;  // The PartitionedCallOp producing this tensor.
  int input_node_index = -1;   // Output slot on `input_node`.

  // _Arg/_Retval nodes created inside the IdentityN function-body graph.
  Node* arg_node = nullptr;
  Node* ret_node = nullptr;
};
61
// Records the outputs of one partition: the real output nodes keyed by name,
// plus an optional auxiliary output (an unused Const) that is created only
// when the partition has no data output of its own.
struct OutputNodeInfo {
  absl::flat_hash_map<std::string, NodeInfo> output_nodes;
  absl::optional<std::pair<std::string, NodeInfo>> auxiliary_output_node;
};
66
// Prepares the `subgraph` for the conversion to a function by adding
// _Arg/_Retval nodes for input/output nodes respectively, and records
// input/output info for the following processing.
//
// Args:
//   inputs/outputs: tensor names of the original graph's feeds/fetches; only
//     those found in this `subgraph` are processed (others are skipped).
//   host_device: device used to place the auxiliary Const output, if one is
//     created below.
//   func_name: name of the function this subgraph will become; used only for
//     naming the auxiliary Const node.
//   input_nodes/output_nodes: out-params recording, per original node name,
//     the assigned _Arg/_Retval index, dtype, and (for inputs) the copy made
//     in the top-level `graph`.
//   auxiliary_output_node: out-param, set only when the partition has no real
//     output.
//   subgraph: the partition, rewritten in place.
//   graph: the top-level graph into which input nodes are copied.
// TODO(b/217581711): Consider to use another GraphToFunctionDef() helper which
// does not require _Arg and _Retval nodes.
Status PrepareSubgraphForFunctionConversion(
    const std::vector<std::string>& inputs,
    const std::vector<std::string>& outputs, const Device* host_device,
    const std::string& func_name,
    absl::flat_hash_map<std::string, NodeInfo>& input_nodes,
    absl::flat_hash_map<std::string, NodeInfo>& output_nodes,
    absl::optional<std::pair<std::string, NodeInfo>>& auxiliary_output_node,
    Graph* subgraph, Graph* graph) {
  std::unordered_map<std::string, Node*> name_to_node_map =
      subgraph->BuildNodeNameIndex();

  // _Arg/_Retval indices are assigned in the iteration order of
  // `inputs`/`outputs`, counting only the nodes present in this subgraph.
  int input_index = 0, output_index = 0;

  // For each input node in this subgraph, replace it with an _Arg node.
  for (const auto& input : inputs) {
    int position = -1;
    std::string node_name = grappler::ParseNodeName(input, &position);
    if (position != 0) {
      return errors::Unimplemented(
          "Support for input node with multiple output tensors is not "
          "implemented.");
    }
    // The input may belong to a different partition.
    if (name_to_node_map.count(node_name) == 0) continue;

    Node* node = name_to_node_map.at(node_name);
    NodeInfo node_info;
    node_info.node = node;
    node_info.data_type = node->output_type(position);
    node_info.index = input_index++;
    // Copy the input node which will be removed from the subgraph below.
    // The copied node will be used in the top-level graph.
    node_info.node_copy = graph->CopyNode(node);

    input_nodes.emplace(node->name(), node_info);

    // Create an _Arg node to replace the input node.
    TF_ASSIGN_OR_RETURN(
        Node * arg_node,
        NodeBuilder(absl::StrCat("arg_", node_info.index, "/", node->name()),
                    "_Arg")
            .Attr("index", node_info.index)
            .Attr("T", node_info.data_type)
            .Finalize(subgraph));

    // Input nodes are feeds, so they must be sources in the subgraph.
    CHECK_EQ(node->num_inputs(), 0);
    // Snapshot the out edges first: rewiring below mutates the edge set.
    std::vector<const Edge*> out_edges(node->out_edges().begin(),
                                       node->out_edges().end());
    for (const Edge* edge : out_edges) {
      if (edge->IsControlEdge()) {
        subgraph->AddControlEdge(arg_node, edge->dst());
      } else {
        TF_RETURN_IF_ERROR(
            subgraph->UpdateEdge(arg_node, 0, edge->dst(), edge->dst_input()));
      }
    }
    subgraph->RemoveNode(node);
  }

  // For each output node in this subgraph, connect it to a _Retval node.
  for (const auto& output : outputs) {
    int position = -1;
    std::string node_name = grappler::ParseNodeName(output, &position);
    if (position != 0) {
      return errors::Unimplemented(
          "Support for output node with multiple output tensors is not "
          "implemented.");
    }
    // The output may belong to a different partition.
    if (name_to_node_map.count(node_name) == 0) continue;

    Node* node = name_to_node_map.at(node_name);
    NodeInfo node_info;
    node_info.node = node;
    node_info.data_type = node->output_type(position);
    node_info.index = output_index++;

    output_nodes.emplace(node->name(), node_info);

    // Create a _RetArg node, and append it to the original output node.
    TF_ASSIGN_OR_RETURN(
        Node * ret_node,
        NodeBuilder(absl::StrCat("ret_", node_info.index, "/", node->name()),
                    "_Retval")
            .Attr("index", node_info.index)
            .Attr("T", node_info.data_type)
            .Input(NodeBuilder::NodeOut(node->name(), position,
                                        node_info.data_type))
            .Finalize(subgraph));
    // Rename the output node, as there will be a node in the top level with
    // the same name.
    node->set_name(node->name() + "/partition_renamed");

    subgraph->AddEdge(node, 0, ret_node, 0);
  }

  // If there is no output for this partition, create an auxiliary output, so
  // that we can generate a data dependency from the PartitionedCallOp (the
  // one we are going to create to wrap this partition) to a downstream
  // stateful node. This helps to preserve the stateless PartitionedCallOp in
  // the subsequent MLIR lowering passes; otherwise, it will be pruned if there
  // is only a control dependency between PartitionedCallOp and another op
  // node, because PartitionedCallOp is stateless and the control dependency
  // will get lost during MLIR lowering with current side effect analysis
  // (b/232026253).
  if (output_nodes.empty()) {
    // Create a scalar int32 const node holding 0; the value itself is unused.
    const DataType data_type = DT_INT32;
    TensorShape const_shape;
    Tensor const_tensor(data_type, const_shape);
    const_tensor.flat<int>()(0) = 0;
    TF_ASSIGN_OR_RETURN(
        Node * const_node,
        NodeBuilder(absl::StrCat("const/unused/", func_name), "Const")
            .AssignedDevice(host_device->name())
            .Attr("dtype", data_type)
            .Attr("value", const_tensor)
            .Finalize(subgraph));

    NodeInfo node_info;
    node_info.node = const_node;
    node_info.data_type = data_type;
    node_info.index = output_index++;
    auxiliary_output_node.emplace(const_node->name(), node_info);

    // Create a _RetArg node, and append to the const node created above.
    TF_ASSIGN_OR_RETURN(
        Node * ret_node,
        NodeBuilder(
            absl::StrCat("ret_", node_info.index, "/", const_node->name()),
            "_Retval")
            .Attr("index", node_info.index)
            .Attr("T", data_type)
            .Input(NodeBuilder::NodeOut(const_node->name(), 0, data_type))
            .Finalize(subgraph));

    subgraph->AddEdge(const_node, 0, ret_node, 0);
  }
  return OkStatus();
}
210
// Converts the subgraph to a function, and builds a PartitionedCallOp
// to invoke the function.
//
// Args:
//   func_name: name under which `subgraph` is added to the function library.
//   host_device: device the PartitionedCallOp is assigned to.
//   device: the partition's device string, stored as a "device" attr on the
//     generated FunctionDef.
//   input_nodes/output_nodes/auxiliary_output_node: input/output metadata
//     previously recorded by PrepareSubgraphForFunctionConversion(); the
//     `index` fields define the call op's Tin/Tout ordering.
//   control_outputs: names of the original graph's control fetches; those
//     present in `subgraph` become control returns of the function.
//   subgraph: the partition to convert (already has _Arg/_Retval nodes).
//   graph: the top-level graph receiving the call node and the FunctionDef.
//
// Returns the created PartitionedCall node in `graph`.
StatusOr<Node*> BuildPartitionedCallOp(
    const std::string& func_name, const Device* host_device,
    const std::string& device,
    const absl::flat_hash_map<std::string, NodeInfo>& input_nodes,
    const absl::flat_hash_map<std::string, NodeInfo>& output_nodes,
    const absl::optional<std::pair<std::string, NodeInfo>>&
        auxiliary_output_node,
    const std::vector<std::string>& control_outputs, Graph* subgraph,
    Graph* graph) {
  // Build the call node.
  std::string call_node_name = absl::StrCat("partitioned_call/", func_name);
  NodeBuilder call_builder(call_node_name, "PartitionedCall");
  call_builder.AssignedDevice(host_device->name());
  call_builder.Attr(tensorflow::kNoInlineAttr, true);

  // Tin is ordered by the _Arg indices recorded earlier.
  std::vector<DataType> input_dtypes(input_nodes.size());
  for (const auto& input_node : input_nodes) {
    input_dtypes[input_node.second.index] = input_node.second.data_type;
  }
  call_builder.Attr("Tin", input_dtypes);

  // The auxiliary output exists iff the partition had no real outputs, so
  // exactly one of the two output sources must be non-empty.
  CHECK(auxiliary_output_node ? output_nodes.empty() : !output_nodes.empty());
  std::vector<DataType> output_dtypes(
      auxiliary_output_node ? 1 : output_nodes.size());
  if (auxiliary_output_node) {
    CHECK_EQ(auxiliary_output_node->second.index, 0);
    output_dtypes[auxiliary_output_node->second.index] =
        auxiliary_output_node->second.data_type;
  } else {
    for (const auto& output_node : output_nodes) {
      output_dtypes[output_node.second.index] = output_node.second.data_type;
    }
  }
  call_builder.Attr("Tout", output_dtypes);

  // Wire the call op's inputs to the input-node copies that were placed in
  // the top-level graph.
  std::vector<NodeBuilder::NodeOut> call_node_inputs(input_nodes.size());
  for (const auto& input_node : input_nodes) {
    call_node_inputs[input_node.second.index] =
        NodeBuilder::NodeOut(input_node.second.node_copy, 0);
  }
  call_builder.Input(call_node_inputs);

  NameAttrList f;
  f.set_name(func_name);
  call_builder.Attr("f", f);
  TF_ASSIGN_OR_RETURN(Node * call_node, call_builder.Finalize(graph));

  // Convert the subgraph to a function.
  absl::flat_hash_set<std::string> control_ret_names(control_outputs.begin(),
                                                     control_outputs.end());
  // After graph partition, there are send ops added as new end nodes.
  // The completion of the graph requires the send ops to be executed.
  for (const Node* node : subgraph->op_nodes()) {
    if (node->IsSend()) {
      control_ret_names.insert(node->name());
    }
  }
  // Marks a node as a control return of the function iff its name was
  // collected above.
  auto control_ret_node_names =
      [&control_ret_names](const Node* node) -> absl::optional<std::string> {
    if (control_ret_names.contains(node->name())) {
      return node->name();
    }
    return absl::nullopt;
  };

  FunctionDef new_fdef;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(*subgraph, func_name,
                                        control_ret_node_names, &new_fdef));
  // Set the `_noinline` attribute for the function to make sure it does not
  // get inlined, and its corresponding function in TF MLIR does not get inlined
  // in lowering passes as well.
  (*new_fdef.mutable_attr())[tensorflow::kNoInlineAttr].set_b(true);
  (*new_fdef.mutable_attr())["device"].set_s(device);
  TF_RETURN_IF_ERROR(graph->mutable_flib_def()->AddFunctionDef(new_fdef));

  return call_node;
}
290
// Builds a StatefulPartitionedCallOp, and connects all PartitionedCallOps to
// it. This StatefulPartitionedCallOp behaves as a stateful IdentityN.
//
// Args:
//   call_node_input_info: aggregated outputs from every partition; `index`
//     defines the stateful call op's input/output ordering. The arg_node and
//     ret_node fields are filled in here as a side effect.
//   all_partitioned_call_ops: the per-device PartitionedCall nodes (unused
//     directly; inputs are taken from `call_node_input_info`).
//   stateful_call_func_name: name for the IdentityN wrapper function.
//   host_device: device the IdentityN node is assigned to.
//   graph: the top-level graph receiving the call node and the FunctionDef.
//
// Returns the created StatefulPartitionedCall node in `graph`.
StatusOr<Node*> BuildStatefulPartitionedCallOp(
    absl::flat_hash_map<std::string, CallNodeInputInfo>& call_node_input_info,
    const absl::flat_hash_map<std::string, Node*>& all_partitioned_call_ops,
    const std::string& stateful_call_func_name, const Device* host_device,
    Graph* graph) {
  std::string call_node_name =
      absl::StrCat("stateful_partitioned_call/", stateful_call_func_name);
  NodeBuilder call_builder(call_node_name, "StatefulPartitionedCall");
  call_builder.Attr(tensorflow::kNoInlineAttr, true);
  call_builder.AssignedDevice(host_device->name());

  int num_output_nodes = call_node_input_info.size();
  std::vector<DataType> input_dtypes(num_output_nodes);
  for (const auto& node_info : call_node_input_info) {
    CHECK(node_info.second.index < num_output_nodes);
    input_dtypes[node_info.second.index] = node_info.second.data_type;
  }
  call_builder.Attr("Tin", input_dtypes);
  // Outputs are the same as inputs.
  call_builder.Attr("Tout", input_dtypes);

  // Each input of the stateful call op is an output of some
  // PartitionedCallOp.
  std::vector<NodeBuilder::NodeOut> call_node_inputs(num_output_nodes);
  for (const auto& node_info : call_node_input_info) {
    call_node_inputs[node_info.second.index] = NodeBuilder::NodeOut(
        node_info.second.input_node, node_info.second.input_node_index);
  }
  call_builder.Input(call_node_inputs);

  NameAttrList f;
  f.set_name(stateful_call_func_name);
  call_builder.Attr("f", f);
  TF_ASSIGN_OR_RETURN(Node * stateful_call_node, call_builder.Finalize(graph));

  // Construct a graph that only contains an IdentityN node, and convert the
  // graph to a function.
  auto id_graph = std::make_unique<Graph>(graph->flib_def().default_registry());

  std::vector<NodeBuilder::NodeOut> output_tensors(num_output_nodes);

  // Create an _Arg node for each input.
  for (auto& node_info : call_node_input_info) {
    TF_ASSIGN_OR_RETURN(node_info.second.arg_node,
                        NodeBuilder(absl::StrCat("arg_", node_info.second.index,
                                                 "/", stateful_call_func_name),
                                    "_Arg")
                            .Attr("index", node_info.second.index)
                            .Attr("T", node_info.second.data_type)
                            .Finalize(id_graph.get()));

    output_tensors[node_info.second.index] =
        NodeBuilder::NodeOut(node_info.second.arg_node, 0);
  }

  // Create the Identity Node.
  TF_ASSIGN_OR_RETURN(
      Node * identity_node,
      NodeBuilder(absl::StrCat("identityN", "/", stateful_call_func_name),
                  "IdentityN")
          .AssignedDevice(host_device->name())
          .Input(output_tensors)
          .Finalize(id_graph.get()));

  // Create a _Retval node for each output.
  for (auto& node_info : call_node_input_info) {
    TF_ASSIGN_OR_RETURN(
        node_info.second.ret_node,
        NodeBuilder(absl::StrCat("ret_", node_info.second.index, "/",
                                 stateful_call_func_name),
                    "_Retval")
            .Attr("index", node_info.second.index)
            .Attr("T", node_info.second.data_type)
            .Input(NodeBuilder::NodeOut(identity_node, node_info.second.index))
            .Finalize(id_graph.get()));

    id_graph->AddEdge(identity_node, node_info.second.index,
                      node_info.second.ret_node, 0);
  }

  // Convert the id_graph to a function.
  FunctionDef id_fdef;
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*id_graph, stateful_call_func_name, &id_fdef));
  (*id_fdef.mutable_attr())[tensorflow::kNoInlineAttr].set_b(true);
  TF_RETURN_IF_ERROR(graph->mutable_flib_def()->AddFunctionDef(id_fdef));

  return stateful_call_node;
}
380
381 // Returns true if nodes in the `graph` are assigned to multiple devices.
HasMultipleDevices(const Graph * graph)382 bool HasMultipleDevices(const Graph* graph) {
383 bool has_multiple_devices = false;
384 absl::optional<std::string> location;
385 for (const Node* node : graph->op_nodes()) {
386 if (location) {
387 if (*location != node->assigned_device_name()) {
388 has_multiple_devices = true;
389 break;
390 }
391 } else {
392 location = node->assigned_device_name();
393 }
394 }
395 return has_multiple_devices;
396 }
397
// Returns a copy of `device` that is usable inside a node name: every ':'
// character, which is not allowed in node names, is replaced with '_'.
std::string GetNameFromDevice(const std::string& device) {
  std::string name = device;
  std::replace(name.begin(), name.end(), ':', '_');
  return name;
}
406
407 } // namespace
408
409 // This function performs the following steps:
410 // 1. Partition the graph and insert send/recv ops on the edges across devices.
411 // 2. For each partition, convert the subgraph to a function and invoke the
412 // function by a PartitionedCallOp, so that these functions can be executed
413 // asynchronousely.
414 // 3. Connect all PartitionedCallOps to a StatefulPartitionedCallOps to make
415 // sure PartitionedCallOps are not pruned in the subsequent MLIR lowering
416 // passes.
417 // 4. Create output nodes and control output nodes to match the original graph's
418 // nodes.
InsertTransferOps(const std::string & graph_func_name,const DeviceSet & device_set,const Device * host_device,const std::vector<std::string> & inputs,const std::vector<std::string> & outputs,const std::vector<std::string> & control_outputs,std::unique_ptr<Graph> graph)419 StatusOr<std::unique_ptr<Graph>> InsertTransferOps(
420 const std::string& graph_func_name, const DeviceSet& device_set,
421 const Device* host_device, const std::vector<std::string>& inputs,
422 const std::vector<std::string>& outputs,
423 const std::vector<std::string>& control_outputs,
424 std::unique_ptr<Graph> graph) {
425 // Skip transfer op insertion if the graph nodes are not assigned to multiple
426 // devices.
427 if (!HasMultipleDevices(graph.get())) {
428 return graph;
429 }
430
431 // Step 1: Partition the graph and insert send/recv ops on the edges across
432 // devices.
433 auto new_graph = std::make_unique<Graph>(graph->flib_def());
434 FunctionDefLibrary flib = graph->flib_def().ToProto();
435
436 std::unordered_map<string, std::unique_ptr<Graph>> partitions;
437 TF_RETURN_IF_ERROR(
438 PartitionFunctionGraph(device_set, std::move(graph), &partitions));
439
440 // Step 2: For each partition, convert the subgraph to a function and invoke
441 // the function by PartitionedCallOp from the top-level graph.
442
443 absl::flat_hash_map<std::string, Node*> all_partitioned_call_ops;
444 std::map<std::string, OutputNodeInfo> device_to_output_info_map;
445
446 for (auto& partition : partitions) {
447 const string& device = partition.first;
448 VLOG(1) << "Process the partitioin on device: " << device;
449
450 Graph* subgraph = partition.second.get();
451 TF_RETURN_IF_ERROR(subgraph->AddFunctionLibrary(flib));
452
453 FunctionNameGenerator name_generator(
454 &new_graph->flib_def(), absl::StrCat(graph_func_name, "-partition-",
455 GetNameFromDevice(device)));
456 std::string func_name = name_generator.GetName();
457
458 absl::flat_hash_map<std::string, NodeInfo> input_nodes;
459
460 OutputNodeInfo& output_node_info = device_to_output_info_map[device];
461 absl::flat_hash_map<std::string, NodeInfo>& output_nodes =
462 output_node_info.output_nodes;
463 absl::optional<std::pair<std::string, NodeInfo>>& auxiliary_output_node =
464 output_node_info.auxiliary_output_node;
465
466 // Add _Arg and _Retval nodes to the subgraph to prepare for converting it
467 // to a function. Meanwhile, record input/output infos for the following
468 // processing.
469 TF_RETURN_IF_ERROR(PrepareSubgraphForFunctionConversion(
470 inputs, outputs, host_device, func_name, input_nodes, output_nodes,
471 auxiliary_output_node, subgraph, new_graph.get()));
472
473 // Convert the subgraph to a function, and build a PartitionedCallOp to
474 // invoke the function.
475 TF_ASSIGN_OR_RETURN(
476 Node * call_node,
477 BuildPartitionedCallOp(func_name, host_device, device, input_nodes,
478 output_nodes, auxiliary_output_node,
479 control_outputs, subgraph, new_graph.get()));
480 all_partitioned_call_ops[device] = call_node;
481 }
482
483 // Step 3: Create a StatefulPartitionedCallOp, and connect all
484 // PartitionedCallOps to it. The StatefulPartitionedCallOp behaves as a
485 // stateful IdentityN. This helps to preserve the PartitionedCallOps
486 // (stateless) in the TF MLIR lowering passes; otherwise, without a stateful
487 // consumer, PartitionedCallOps will be pruned, as control output info of
488 // the graph gets lost during TF MLIR lowering (b/232026253).
489
490 // Collect all outputs from all partitions, and update their indices to be
491 // used for constructing StatefulPartitionedCallOp.
492 int input_index = 0;
493 absl::flat_hash_map<std::string, CallNodeInputInfo> call_node_input_info;
494 auto get_call_node_input_info = [&](const std::string& device,
495 const std::string& node_name,
496 const NodeInfo& node_info) {
497 CHECK(!call_node_input_info.contains(node_name));
498 CallNodeInputInfo& info = call_node_input_info[node_name];
499 info.index = input_index++;
500 info.data_type = node_info.data_type;
501 info.input_node = all_partitioned_call_ops.at(device);
502 info.input_node_index = node_info.index;
503 };
504 for (const auto& entry : device_to_output_info_map) {
505 const std::string& device = entry.first;
506 const OutputNodeInfo& output_info = entry.second;
507 for (const auto& node_info : output_info.output_nodes) {
508 get_call_node_input_info(device, node_info.first, node_info.second);
509 }
510 if (output_info.auxiliary_output_node) {
511 get_call_node_input_info(device, output_info.auxiliary_output_node->first,
512 output_info.auxiliary_output_node->second);
513 }
514 }
515
516 FunctionNameGenerator name_generator(
517 &new_graph->flib_def(),
518 absl::StrCat(graph_func_name, "/output_aggregator"));
519 std::string stateful_call_func_name = name_generator.GetName();
520 TF_ASSIGN_OR_RETURN(
521 Node * stateful_call_node,
522 BuildStatefulPartitionedCallOp(
523 call_node_input_info, all_partitioned_call_ops,
524 stateful_call_func_name, host_device, new_graph.get()));
525
526 // Step 4: Create output nodes and control output nodes corresponding to the
527 // original graph's nodes.
528
529 // For each of the original output, construct a corresponding Identity node
530 // with the same name.
531 for (const auto& node_info : call_node_input_info) {
532 TF_RETURN_IF_ERROR(NodeBuilder(node_info.first, "Identity")
533 .Input(NodeBuilder::NodeOut(stateful_call_node,
534 node_info.second.index))
535 .Attr("T", node_info.second.data_type)
536 .AssignedDevice(host_device->name())
537 .Finalize(new_graph.get(), nullptr));
538 }
539
540 // For each of the original control output, construct a corresponding Identity
541 // node with the same name.
542 CHECK_GT(stateful_call_node->num_outputs(), 0);
543 for (const auto& control_output : control_outputs) {
544 TF_RETURN_IF_ERROR(NodeBuilder(control_output, "Identity")
545 .Input(NodeBuilder::NodeOut(stateful_call_node, 0))
546 .Attr("T", stateful_call_node->output_type(0))
547 .AssignedDevice(host_device->name())
548 .Finalize(new_graph.get(), nullptr));
549 }
550
551 return new_graph;
552 }
553
554 } // namespace tfrt_stub
555 } // namespace tensorflow
556