/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h"

#include <deque>
#include <map>
#include <unordered_map>
#include <unordered_set>

#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_set.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"

namespace tensorflow {
namespace tpu {

namespace {

constexpr char kDefaultShardingValue[] = "";
// Returns the edge connecting `src` to `dst` (matched by name), or nullptr
// if no such edge exists.
const Edge* FindEdgeConnecting(const Node* src, const Node* dst) {
  for (const Edge* e : src->out_edges()) {
    if (e->dst()->name() == dst->name()) return e;
  }
  return nullptr;
}

// Contains a TPUExecute node and the input edges of its DT_RESOURCE
// arguments that correspond to model weights.
struct ExecuteNodeInfo {
  Node* execute_node;
  std::vector<const Edge*> var_inputs;
};

// Returns whether `node` is an execute node, or an Identity or NextIteration
// node inside the loop all of whose data outputs lead to execute nodes.
bool IsExecuteNodeOrIdentityToExecuteNode(
    const Graph& graph, const std::unordered_set<Node*>& loop_nodes,  // NOLINT
    const absl::flat_hash_set<Node*>& execute_nodes, Node* node) {
  if (execute_nodes.find(node) != execute_nodes.end()) return true;
  if (loop_nodes.find(node) == loop_nodes.end()) return false;
  if (node->IsNextIteration()) return true;
  if (!node->IsIdentity()) return false;

  for (const Edge* e : node->out_edges()) {
    if (e->IsControlEdge()) continue;

    Node* dst = e->dst();
    if (!IsExecuteNodeOrIdentityToExecuteNode(graph, loop_nodes, execute_nodes,
                                              dst)) {
      return false;
    }
  }

  return true;
}

// From an input node to the TPUExecute op, finds the corresponding Enter node
// by traversing the following pattern of nodes:
//   Enter ----> (identity) ---->  While body input
// Returns nullptr if the Enter node is not found.
xla::StatusOr<Node*> FindEnterNodeFromTPUExecuteNodeInput(Node* input_node) {
  Node* node = input_node;
  while (node->IsIdentity()) {
    TF_RETURN_IF_ERROR(node->input_node(0, &node));
  }

  if (node->IsEnter()) {
    return node;
  }
  return nullptr;
}
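
// Note: returning nullptr rather than an error lets callers treat a missing
// Enter node as "unsupported for resharding" instead of failing the rewrite;
// see ExtractExecuteNodeInfo below.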

// Returns whether the resource variable fed through `enter_node` is used only
// by the TPUExecute nodes in `execute_nodes` (possibly through identities)
// inside the loop.
xla::StatusOr<bool> ResourceOnlyUsedForTPUExecuteInLoop(
    const Graph& graph, const std::unordered_set<Node*>& loop_nodes,  // NOLINT
    const Node* enter_node, const absl::flat_hash_set<Node*>& execute_nodes) {
  for (const Edge* output_edge : enter_node->out_edges()) {
    Node* output_node = output_edge->dst();
    if (output_edge->IsControlEdge() || output_node->IsExit()) continue;

    // If the output node is neither an execute node nor an identity leading
    // to one, the resource has some other use in the while loop body.
    if (!IsExecuteNodeOrIdentityToExecuteNode(graph, loop_nodes, execute_nodes,
                                              output_node)) {
      return false;
    }
  }
  return true;
}

// Given a TPUCompile node, finds all TPUExecute nodes that execute the
// compiled program, along with their model weight variable inputs.
// The TPUCompileMetadataProto of the TPUCompile node must be reset to
// `new_metadata` if new reshard ops are added.
Status ExtractExecuteNodeInfo(const Node* compile_node, const Graph& graph,
                              const std::unordered_set<Node*>& loop_nodes,  // NOLINT
                              std::vector<ExecuteNodeInfo>* execute_node_info,
                              TPUCompileMetadataProto* new_metadata) {
  string metadata_string;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(compile_node->attrs(), "metadata", &metadata_string));
  new_metadata->ParsePartialFromString(metadata_string);
  if (new_metadata->num_cores_per_replica() != 1) {
    // We do not support model parallelism yet.
    return Status::OK();
  }

  execute_node_info->clear();
  for (Node* node : compile_node->out_nodes()) {
    if (node->type_string() == "TPUExecute") {
      execute_node_info->push_back({node});
    }
  }
  if (execute_node_info->empty()) {
    return Status::OK();
  }
  TF_RET_CHECK(execute_node_info->size() == new_metadata->num_replicas())
      << "Number of replicas does not equal number of execute nodes: "
      << new_metadata->num_replicas() << " vs " << execute_node_info->size();
  DataTypeVector arg_types;
  TF_RETURN_IF_ERROR(GetNodeAttr((*execute_node_info)[0].execute_node->attrs(),
                                 "Targs", &arg_types));
  for (int64_t i = 0; i < arg_types.size(); ++i) {
    if (arg_types[i] != DT_RESOURCE) {
      continue;
    }
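    // Only arguments whose sharding mode is still TENTATIVE (or already
    // ALLOWED) are candidates for on-device resharding; the interpretation of
    // the enum values here is inferred from how they are rewritten below.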
    const auto sharding_config = new_metadata->args(i).enable_xla_sharding();
    if (sharding_config != TPUCompileMetadataProto::Arg::TENTATIVE &&
        sharding_config != TPUCompileMetadataProto::Arg::ALLOWED) {
      continue;
    }
    std::vector<const Edge*> edges(execute_node_info->size());
    bool is_supported = true;
    std::unordered_map<Node*, absl::flat_hash_set<Node*>>
        enter_to_execute_nodes;
    for (int64_t j = 0; j < edges.size(); ++j) {
      auto execute = (*execute_node_info)[j].execute_node;
      TF_RETURN_IF_ERROR(execute->input_edge(i, &edges[j]));
      TF_RET_CHECK(edges[j]->src()->output_type(edges[j]->src_output()) ==
                   arg_types[i])
          << "Execute op has an unexpected input type.";
      // Traverse backwards to find the Enter node from which the input is
      // passed. This makes sure that we are checking the usages of all
      // potential aliases of the input node as well.
      TF_ASSIGN_OR_RETURN(auto enter_node, FindEnterNodeFromTPUExecuteNodeInput(
                                               edges[j]->src()));
      if (enter_node == nullptr) {
        is_supported = false;
        enter_to_execute_nodes.clear();
        break;
      }
      enter_to_execute_nodes[enter_node].insert(edges[j]->dst());
    }

    for (const auto& it : enter_to_execute_nodes) {
      // The number of execute nodes should be either 1 (per-replica
      // variables) or num_replicas (distributed variables).
      if ((it.second.size() != 1) &&
          (it.second.size() != new_metadata->num_replicas())) {
        is_supported = false;
        break;
      }
      TF_ASSIGN_OR_RETURN(bool no_other_use,
                          ResourceOnlyUsedForTPUExecuteInLoop(
                              graph, loop_nodes, it.first, it.second));
      if (!no_other_use) {
        is_supported = false;
        break;
      }
    }

    // Add the variable input edges only when they are supported for all
    // execute nodes.
    if (is_supported) {
      for (int64_t j = 0; j < edges.size(); ++j) {
        (*execute_node_info)[j].var_inputs.push_back(edges[j]);
      }
      new_metadata->mutable_args(i)->set_enable_xla_sharding(
          TPUCompileMetadataProto::Arg::ALLOWED);
    }
  }

  int64_t total = 0;
  for (const auto& a : new_metadata->args()) {
    if (a.enable_xla_sharding() == TPUCompileMetadataProto::Arg::ALLOWED) {
      total++;
    }
  }
  TF_RET_CHECK(total == (*execute_node_info)[0].var_inputs.size())
      << " total " << total << " var_inputs "
      << (*execute_node_info)[0].var_inputs.size();
  if (total == 0) {
    // We don't need to process anything if no input is added.
    execute_node_info->clear();
  }
  return Status::OK();
}

bool IsTPUCompileOp(const Node& n) { return n.type_string() == "TPUCompile"; }

void FindTPUCompileNodes(
    const std::string* current_function_name,
    const AttrValueMap* current_function_attr,
    const std::unordered_map<string, WhileLoopFrame>& frames,
    std::vector<HostTrainingLoopInfo>* host_training_loops_info) {
  // Adds frames with no children (i.e., the innermost frames) to a worklist.
  std::deque<const WhileLoopFrame*> worklist;

  for (auto& frame : frames) {
    if (frame.second.num_children == 0) {
      worklist.push_back(&frame.second);
    }
  }

  // Check for TPUCompile nodes from the innermost while loop to the
  // outermost while loop.
  while (!worklist.empty()) {
    const WhileLoopFrame* frame = worklist.front();
    worklist.pop_front();

    for (const auto& n : frame->nodes) {
      if (!IsTPUCompileOp(*n)) continue;

      HostTrainingLoopInfo host_training_loop_info;
      host_training_loop_info.compile_node_name = n->name();
      host_training_loop_info.loop_cond_node_name = frame->loop_cond->name();
      host_training_loop_info.while_loop_name = frame->name;

      for (const auto& arg : frame->args) {
        LoopArgInfo arg_info;
        arg_info.enter_node_name = arg.enter->name();
        if (arg.exit) arg_info.exit_node_name = arg.exit->name();

        host_training_loop_info.loop_arguments.push_back(std::move(arg_info));
      }
      host_training_loop_info.loop_nodes = frame->nodes;

      if (current_function_name) {
        host_training_loop_info.encapsulating_function_name =
            *current_function_name;
      }
      if (current_function_attr) {
        host_training_loop_info.encapsulating_function_attrs =
            *current_function_attr;
      }

      host_training_loops_info->emplace_back(
          std::move(host_training_loop_info));
    }

    // If the parent has no remaining children, add it to the worklist.
    --frame->parent->num_children;
    if (frame->parent->num_children == 0) {
      worklist.push_back(frame->parent);
    }
  }
}

// From the while loop cond node, finds all loop exit nodes by traversing the
// following pattern of nodes:
//   LoopCond -----> Switch -----> Exit
std::vector<Node*> FindLoopExitNodes(const Node& loop_cond) {
  std::vector<Node*> loop_exit_nodes;
  for (const auto e_cond : loop_cond.out_edges()) {
    if (e_cond->IsControlEdge() || !e_cond->dst()->IsSwitch()) continue;
    auto switch_node = e_cond->dst();

    for (const auto e_switch : switch_node->out_edges()) {
      if (e_switch->IsControlEdge() || !e_switch->dst()->IsExit()) continue;

      loop_exit_nodes.push_back(e_switch->dst());
    }
  }
  return loop_exit_nodes;
}

// Returns or creates a node that is executed before each loop iteration
// in the while loop.
// TODO(mdan): Inject this node between the Enter and Merge nodes instead.
// See AddNoOpAfterLastIteration for an example.
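// Sketch of the wiring when a new node must be created (`:1` is the Switch's
// true output, taken while the loop keeps iterating):
//
//   Switch:1 ----> Identity("TPUVariableReshard/before_iteration/_<id>")
//
// Control successors of the returned node therefore run once per iteration.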
Status GetOrCreateBeforeEachIterationNode(const Node& loop_cond_node,
                                          Graph* graph, Node** node_out) {
  Node* loop_switch_node = nullptr;
  for (auto n : loop_cond_node.out_nodes()) {
    if (n->IsSwitch()) {
      loop_switch_node = n;
      break;
    }
  }
  TF_RET_CHECK(loop_switch_node != nullptr)
      << "Unable to find any switch nodes.";

  // If the while loop switch node already has an outgoing data edge on the
  // true branch of the switch op, then reuse that node.
  for (const auto out_edge : loop_switch_node->out_edges()) {
    if (out_edge->src_output() == 1) {
      *node_out = out_edge->dst();
      return Status::OK();
    }
  }

  // Create an Identity node that represents execution at every loop
  // iteration.
  NodeDef at_loop_iteration_nodedef;
  at_loop_iteration_nodedef.set_op("Identity");
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(loop_switch_node->def(), "T", &dtype));

  AddNodeAttr("T", dtype, &at_loop_iteration_nodedef);
  at_loop_iteration_nodedef.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/before_iteration", "/_", internal::GetNodeId())));

  Status status;
  Node* at_loop_iteration_node =
      graph->AddNode(at_loop_iteration_nodedef, &status);
  TF_RETURN_IF_ERROR(status);

  graph->AddEdge(loop_switch_node, 1, at_loop_iteration_node, 0);
  *node_out = at_loop_iteration_node;
  return Status::OK();
}

// Injects a NoOp node that is executed after the very last iteration
// of the while loop but before the while loop exit node.
// This node is positioned between the false output of all Switch nodes
// (meaning, it executes after the loop has ended all its iterations) and
// their corresponding Exit nodes (meaning, before the loop finally
// completes).
// See the white paper for details:
// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
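// Sketch of the wiring added below for each Switch node (`...->` denotes a
// control edge; `:0` is the Switch's false output, taken once the loop
// terminates):
//
//   Switch:0 ----> Identity(switch_exit) ...-> NoOp(after_last_iteration)
//   NoOp(after_last_iteration) ...-> Exit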
Status AddNoOpAfterLastIteration(const Node& loop_cond_node, Graph* graph,
                                 Node** node_out) {
  NodeDef after_last_iteration;
  after_last_iteration.set_op("NoOp");

  after_last_iteration.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/after_last_iteration", "/_", internal::GetNodeId())));

  Status status;
  Node* after_last_iteration_node =
      graph->AddNode(after_last_iteration, &status);
  TF_RETURN_IF_ERROR(status);

  for (auto switch_node : loop_cond_node.out_nodes()) {
    if (!switch_node->IsSwitch()) {
      continue;
    }

    NodeDef switch_exit;
    switch_exit.set_op("Identity");

    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(switch_node->def(), "T", &dtype));
    AddNodeAttr("T", dtype, &switch_exit);
    auto name = strings::StrCat("TPUVariableReshard/switch_exit", "/_",
                                internal::GetNodeId());
    switch_exit.set_name(graph->NewName(name));
    // Introducing identity nodes risks a device copy, which isn't guaranteed
    // to be available for all types. Hence the colocation constraint.
    AddNodeAttr(kColocationAttrName,
                std::vector<string>{
                    absl::StrCat(kColocationGroupPrefix, switch_node->name())},
                &switch_exit);

    Status status;
    Node* after_switch_node = graph->AddNode(switch_exit, &status);
    TF_RETURN_IF_ERROR(status);

    graph->AddEdge(switch_node, 0, after_switch_node, 0);
    graph->AddControlEdge(after_switch_node, after_last_iteration_node);

    for (const auto out_node : switch_node->out_nodes()) {
      if (out_node->IsExit()) {
        graph->AddControlEdge(after_last_iteration_node, out_node);
      }
    }
  }

  *node_out = after_last_iteration_node;
  return Status::OK();
}

}  // namespace

Status DetectHostTrainingLoop(
    const std::string* current_function_name,
    const AttrValueMap* current_function_attr,
    const FunctionLibraryDefinition* library, Graph* graph,
    FunctionLibraryRuntime* flr,
    std::vector<HostTrainingLoopInfo>* host_training_loops_info) {
  std::vector<AssociatedFunctionInfo> associated_function_list;
  for (const auto* n : graph->nodes()) {
    const auto associated_functions = GetAssociatedFunctions(*n, library);
    if (associated_functions.empty()) continue;

    associated_function_list.insert(associated_function_list.end(),
                                    associated_functions.begin(),
                                    associated_functions.end());
  }

  Status ret_status = Status::OK();
  for (const auto& function : associated_function_list) {
    if (function.type() != AssociatedFunctionInfo::kFunctionAttr) continue;

    // Convert the function to a Graph.
    FunctionLibraryRuntime::Handle handle;
    TF_RETURN_IF_ERROR(flr->Instantiate(function.func_name(),
                                        AttrSlice(&function.attrs()), &handle));
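    // Release the instantiated function handle when leaving this scope; a
    // failed ReleaseHandle is folded into `ret_status` instead of aborting
    // the pass.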
    auto cleanup_handle = gtl::MakeCleanup([&]() {
      auto s = flr->ReleaseHandle(handle);
      if (!s.ok()) {
        ret_status.Update(s);
      }
    });
    const FunctionBody* body = flr->GetFunctionBody(handle);
    Graph* function_graph = body->graph;
    TF_RETURN_IF_ERROR(DetectHostTrainingLoop(
        &function.func_name(), &function.attrs(), library, function_graph, flr,
        host_training_loops_info));
  }

  // BuildControlFlowInfo() requires that the graph's source node is connected
  // to all source nodes in the graph. Many graphs violate this invariant.
  // Therefore, add edges to source/sink nodes so that this invariant is
  // maintained.
  FixupSourceAndSinkEdges(graph);
  std::vector<ControlFlowInfo> cf_info;
  TF_RETURN_IF_ERROR(
      BuildControlFlowInfo(graph, &cf_info, /*unreachable_nodes=*/nullptr));

  std::unordered_map<string, WhileLoopFrame> frames;
  TF_RETURN_IF_ERROR(ExtractWhileLoopFrames(cf_info, graph, &frames));
  FindTPUCompileNodes(current_function_name, current_function_attr, frames,
                      host_training_loops_info);
  return ret_status;
}

Status AddReshardOp(Graph* graph, const HostTrainingLoopInfo& host_loop_info) {
  const auto& compile_node_name = host_loop_info.compile_node_name;
  const auto node_name_map = graph->BuildNodeNameIndex();
  const auto node_it = node_name_map.find(compile_node_name);
  TF_RET_CHECK(node_it != node_name_map.end())
      << "Unable to find compile node : " << compile_node_name;

  const auto compile_node = node_it->second;
  std::vector<ExecuteNodeInfo> execute_nodes_info;

  Status status;
  TPUCompileMetadataProto metadata;
  status =
      ExtractExecuteNodeInfo(compile_node, *graph, host_loop_info.loop_nodes,
                             &execute_nodes_info, &metadata);
  if (!status.ok()) {
    LOG(ERROR) << "Encountered error when trying to extract execute nodes, "
                  "skipping host loop optimization. Status: "
               << status.ToString();
    return Status::OK();
  }

  if (execute_nodes_info.empty()) {
    return Status::OK();
  }

  // Update the TPUCompileMetadata such that the sharding config of the
  // sharded resource variable inputs is set to ALLOWED instead of TENTATIVE.
  string new_metadata_string;
  metadata.SerializeToString(&new_metadata_string);
  compile_node->ClearAttr("metadata");
  compile_node->AddAttr("metadata", new_metadata_string);

  // Unsharding of the model weight variables must happen only at the very
  // last loop iteration. Therefore, add the while loop condition predicate
  // as an input to the sharding switch node. If the loop condition is true,
  // we do not unshard.
  const auto& cond_node_name = host_loop_info.loop_cond_node_name;
  auto loop_cond_node_it = node_name_map.find(cond_node_name);
  TF_RET_CHECK(loop_cond_node_it != node_name_map.end())
      << "Cannot find loop condition node : " << cond_node_name;
  auto* loop_condition_node = loop_cond_node_it->second;

  // In order to make sure that shard/unshard operations are invoked at the
  // start of every loop body and at the end of the last iteration of the
  // loop, respectively, create a pair of guiding nodes, which are guaranteed
  // to execute before each iteration and after all iterations, respectively.

  Node* after_last_iteration_node;
  TF_RETURN_IF_ERROR(AddNoOpAfterLastIteration(*loop_condition_node, graph,
                                               &after_last_iteration_node));

  Node* before_loop_iteration_node;
  TF_RETURN_IF_ERROR(GetOrCreateBeforeEachIterationNode(
      *loop_condition_node, graph, &before_loop_iteration_node));

  // Create a Const op that represents the default sharding value
  // (i.e. no-op sharding).
  NodeDef default_sharding;
  default_sharding.set_op("Const");
  default_sharding.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/default_shard_state", "/_", internal::GetNodeId())));
  AddNodeAttr("dtype", DT_STRING, &default_sharding);

  Tensor t(DT_STRING, {3});
  t.vec<tstring>()(0) = kDefaultShardingValue;
  t.vec<tstring>()(1) = kDefaultShardingValue;
  t.vec<tstring>()(2) = kDefaultShardingValue;
  t.AsProtoTensorContent(
      (*default_sharding.mutable_attr())["value"].mutable_tensor());
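  // Note: the format key fed to TPUReshardVariables is a 3-element string
  // vector; all-empty values (kDefaultShardingValue) appear to request the
  // default, unsharded format. The element count is taken from the tensor
  // built above.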

  Node* default_sharding_node = graph->AddNode(default_sharding, &status);
  TF_RETURN_IF_ERROR(status);
  // Add a control edge from the loop condition node to make sure that the
  // default_sharding_node is inside the while loop frame.
  graph->AddControlEdge(loop_condition_node, default_sharding_node);

  // Build a no-op node used to add control edges after unshard nodes.
  NodeDef after_unshard;
  after_unshard.set_op("NoOp");
  after_unshard.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/last_iteration", "/_", internal::GetNodeId())));
  auto after_unshard_node = graph->AddNode(after_unshard, &status);
  TF_RETURN_IF_ERROR(status);

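  // For each execute node, the loop below creates a shard/unshard pair of
  // TPUReshardVariables ops. Their input layout (mirrored by the edges added
  // below): inputs 0..N-1 are the model weight variables, input N is the
  // format key (the compilation key for shard, the default sharding for
  // unshard), and input N+1 is the sharding-state VarHandleOp. Control edges
  // pin the shard op before each TPUExecute and the unshard op after the
  // last iteration:
  //
  //   before_iteration ...-> reshard ...-> TPUExecute
  //   after_last_iteration ...-> unshard ...-> after_unshard ...-> Exit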
  for (const auto& info : execute_nodes_info) {
    auto execute_node = info.execute_node;
    // Create a Reshard op that optionally shards model weight variables
    // prior to program execution.
    NodeDef reshard_node_def;
    reshard_node_def.set_name(graph->NewName(strings::StrCat(
        "TPUVariableReshard/reshard", "/_", internal::GetNodeId())));
    reshard_node_def.set_op("TPUReshardVariables");
    AddNodeAttr("N", static_cast<int>(info.var_inputs.size()),
                &reshard_node_def);
    Node* reshard_op_node = graph->AddNode(reshard_node_def, &status);
    if (!status.ok()) return status;

    reshard_op_node->set_assigned_device_name(
        execute_node->assigned_device_name());

    // The Reshard op must execute at every loop iteration prior to the
    // TPUExecute node.
    graph->AddControlEdge(before_loop_iteration_node, reshard_op_node);
    graph->AddControlEdge(reshard_op_node, execute_node);

    for (int i = 0; i < info.var_inputs.size(); ++i) {
      const auto variable_edge = info.var_inputs[i];
      graph->AddEdge(variable_edge->src(), variable_edge->src_output(),
                     reshard_op_node, i);
    }

    const int new_key_input = info.var_inputs.size();
    // Add the program input edge from the compiler (i.e. the compilation
    // key).
    const auto compilation_key_edge =
        FindEdgeConnecting(compile_node, execute_node);
    graph->AddEdge(compile_node, compilation_key_edge->src_output(),
                   reshard_op_node, new_key_input);

    // Create a VarHandleOp to store the sharding state. The sharding state
    // holds the string compilation key that identifies whether the graph has
    // been re-compiled and the variables need to be sharded again.
    NodeDef var_handle_def;
    var_handle_def.set_op("VarHandleOp");
    var_handle_def.set_name(graph->NewName(strings::StrCat(
        "TPUVariableReshard/reshard_state", "/_", internal::GetNodeId())));
    AddNodeAttr("dtype", DT_STRING, &var_handle_def);
    AddNodeAttr("shape", TensorShape({}), &var_handle_def);
    Node* var_handle_node = graph->AddNode(var_handle_def, &status);
    if (!status.ok()) return status;

    // Add a control edge between the `var_handle_def` node and the while
    // loop condition so that `var_handle_def` is inside the same while loop
    // frame.
    // TODO(hongjunchoi): Consider adding the control edge from another
    // node--such as the input control node.
    graph->AddControlEdge(loop_condition_node, var_handle_node);

    // Connect a data edge between the var handle op and the reshard op.
    const int format_state_input = new_key_input + 1;
    graph->AddEdge(var_handle_node, 0, reshard_op_node, format_state_input);

    // Create a Reshard op that represents unsharding after TPUExecute.
    NodeDef unshard_node_def;
    unshard_node_def.set_name(graph->NewName(strings::StrCat(
        "TPUVariableReshard/unshard", "/_", internal::GetNodeId())));
    unshard_node_def.set_op("TPUReshardVariables");
    AddNodeAttr("N", static_cast<int>(info.var_inputs.size()),
                &unshard_node_def);
    Node* unshard_op_node = graph->AddNode(unshard_node_def, &status);
    TF_RETURN_IF_ERROR(status);

    unshard_op_node->set_assigned_device_name(
        execute_node->assigned_device_name());

    for (int i = 0; i < info.var_inputs.size(); ++i) {
      const auto variable_edge = info.var_inputs[i];
      // Connect the model weight resource variables to the unshard op. Since
      // the unshard op must only be invoked after the very last loop
      // iteration, for each while loop input we traverse backwards to find
      // the corresponding Enter node and feed its output to the unshard op;
      // ordering after the final iteration is enforced by the control edge
      // from `after_last_iteration_node` (built from the false outputs of
      // the loop's Switch nodes).
      TF_ASSIGN_OR_RETURN(
          Node * enter_node,
          FindEnterNodeFromTPUExecuteNodeInput(variable_edge->src()));
      graph->AddEdge(enter_node, 0, unshard_op_node, i);
    }

    // Add control dependencies before and after the unshard node via the
    // guiding control nodes.
    graph->AddControlEdge(after_last_iteration_node, unshard_op_node);
    graph->AddControlEdge(unshard_op_node, after_unshard_node);

    graph->AddEdge(default_sharding_node, 0, unshard_op_node, new_key_input);

    // Add a data edge from the sharding state var handle op to the unshard
    // op.
    graph->AddEdge(var_handle_node, 0, unshard_op_node, format_state_input);
  }
  // Add a control dependency from after_unshard_node to all exit nodes. This
  // makes sure that the unshard ops will be executed as long as any of the
  // exits are used.
  for (auto exit : FindLoopExitNodes(*loop_condition_node)) {
    graph->AddControlEdge(after_unshard_node, exit);
  }
  return Status::OK();
}

}  // namespace tpu
}  // namespace tensorflow