/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h"

#include <deque>
#include <map>
#include <unordered_map>

#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_set.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"

namespace tensorflow {
namespace tpu {

namespace {

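// Default sharding value used to populate the format key fed to the unshard
// TPUReshardVariables op below; empty strings select the default (unsharded)
// variable format.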
constexpr char kDefaultShardingValue[] = "";

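// Returns the first edge connecting `src` to `dst`, or nullptr if no such
// edge exists.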
const Edge* FindEdgeConnecting(const Node* src, const Node* dst) {
  for (const Edge* e : src->out_edges()) {
    if (e->dst()->name() == dst->name()) return e;
  }
  return nullptr;
}

// Contains a TPUExecute node and its DT_RESOURCE input nodes that
// correspond to model weights.
struct ExecuteNodeInfo {
  Node* execute_node;
  std::vector<const Edge*> var_inputs;
};

// Returns whether `node` is in `execute_nodes`, or is an Identity node on a
// path of the form (Identity) -> execute.
bool IsExecuteNodeOrIdentityToExecuteNode(
    const Graph& graph, const std::unordered_set<Node*>& loop_nodes,  // NOLINT
    const absl::flat_hash_set<Node*>& execute_nodes, Node* node) {
  if (execute_nodes.find(node) != execute_nodes.end()) return true;
  if (loop_nodes.find(node) == loop_nodes.end()) return false;
  if (node->IsNextIteration()) return true;
  if (!node->IsIdentity()) return false;

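  // `node` is an Identity node inside the loop; it is acceptable only if
  // every data output it feeds also (transitively) reaches an execute node.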
  for (const Edge* e : node->out_edges()) {
    if (e->IsControlEdge()) continue;

    Node* dst = e->dst();
    if (!IsExecuteNodeOrIdentityToExecuteNode(graph, loop_nodes, execute_nodes,
                                              dst)) {
      return false;
    }
  }

  return true;
}

// From an input node of a TPUExecute op, finds the corresponding Enter node
// by traversing the below pattern of nodes:
//   Enter ----> (identity) ---> While body input
// Returns nullptr if the Enter node is not found.
xla::StatusOr<Node*> FindEnterNodeFromTPUExecuteNodeInput(Node* input_node) {
  Node* node = input_node;
  while (node->IsIdentity()) {
    TF_RETURN_IF_ERROR(node->input_node(0, &node));
  }

  if (node->IsEnter()) {
    return node;
  }
  return nullptr;
}

// Returns whether the resource variable entering the loop through
// `enter_node` is used only by TPUExecute nodes (possibly through Identity
// nodes) within the loop.
xla::StatusOr<bool> ResourceOnlyUsedForTPUExecuteInLoop(
    const Graph& graph, const std::unordered_set<Node*>& loop_nodes,  // NOLINT
    const Node* enter_node, const absl::flat_hash_set<Node*>& execute_nodes) {
  for (const Edge* output_edge : enter_node->out_edges()) {
    Node* output_node = output_edge->dst();
    if (output_edge->IsControlEdge() || output_node->IsExit()) continue;

    // If the output node is not an execute node (or an Identity feeding one),
    // the resource has a use in the while loop body other than the TPUExecute
    // nodes.
    if (!IsExecuteNodeOrIdentityToExecuteNode(graph, loop_nodes, execute_nodes,
                                              output_node)) {
      return false;
    }
  }
  return true;
}

// Given a TPUCompile node, finds all TPUExecute nodes that execute the
// compiled program, along with their model weight variable inputs.
// The TPUCompileMetadataProto attr of the TPUCompile node must be reset to
// `new_metadata` if new reshard ops are added.
Status ExtractExecuteNodeInfo(const Node* compile_node, const Graph& graph,
                              const std::unordered_set<Node*>& loop_nodes,  // NOLINT
                              std::vector<ExecuteNodeInfo>* execute_node_info,
                              TPUCompileMetadataProto* new_metadata) {
  string metadata_string;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(compile_node->attrs(), "metadata", &metadata_string));
  TF_RET_CHECK(new_metadata->ParsePartialFromString(metadata_string))
      << "Failed to parse TPUCompileMetadataProto from the TPUCompile node.";
  if (new_metadata->num_cores_per_replica() != 1) {
    // We do not support model parallelism yet.
    return Status::OK();
  }

  execute_node_info->clear();
  for (Node* node : compile_node->out_nodes()) {
    if (node->type_string() == "TPUExecute") {
      execute_node_info->push_back({node});
    }
  }
  if (execute_node_info->empty()) {
    return Status::OK();
  }
  TF_RET_CHECK(execute_node_info->size() == new_metadata->num_replicas())
      << "Number of replicas does not equal number of execute nodes: "
      << new_metadata->num_replicas() << " vs " << execute_node_info->size();
  DataTypeVector arg_types;
  TF_RETURN_IF_ERROR(GetNodeAttr((*execute_node_info)[0].execute_node->attrs(),
                                 "Targs", &arg_types));
  for (int64 i = 0; i < arg_types.size(); ++i) {
    if (arg_types[i] != DT_RESOURCE) {
      continue;
    }
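    // Only resource args whose XLA sharding is TENTATIVE or already ALLOWED
    // are candidates for resharding; skip all other arguments.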
    const auto sharding_config = new_metadata->args(i).enable_xla_sharding();
    if (sharding_config != TPUCompileMetadataProto::Arg::TENTATIVE &&
        sharding_config != TPUCompileMetadataProto::Arg::ALLOWED) {
      continue;
    }
    std::vector<const Edge*> edges(execute_node_info->size());
    bool is_supported = true;
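    // Maps each Enter node (a resource variable entering the loop) to the
    // set of execute-node inputs it feeds, so exclusive usage can be
    // validated per variable.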
    std::unordered_map<Node*, absl::flat_hash_set<Node*>>
        enter_to_execute_nodes;
    for (int64 j = 0; j < edges.size(); ++j) {
      auto execute = (*execute_node_info)[j].execute_node;
      TF_RETURN_IF_ERROR(execute->input_edge(i, &edges[j]));
      TF_RET_CHECK(edges[j]->src()->output_type(edges[j]->src_output()) ==
                   arg_types[i])
          << "Execute op has an unexpected input type.";
      // Traverse backwards to find the Enter node from which the input is
      // passed. This makes sure that we are checking the usages of all
      // potential aliases of the input node as well.
      TF_ASSIGN_OR_RETURN(auto enter_node, FindEnterNodeFromTPUExecuteNodeInput(
                                               edges[j]->src()));
      if (enter_node == nullptr) {
        is_supported = false;
        enter_to_execute_nodes.clear();
        break;
      }
      enter_to_execute_nodes[enter_node].insert(edges[j]->dst());
    }

    for (const auto& it : enter_to_execute_nodes) {
      // The number of execute nodes should be either 1 (per-replica
      // variables) or num_replicas (distributed variables).
      if ((it.second.size() != 1) &&
          (it.second.size() != new_metadata->num_replicas())) {
        is_supported = false;
        break;
      }
      TF_ASSIGN_OR_RETURN(bool no_other_use,
                          ResourceOnlyUsedForTPUExecuteInLoop(
                              graph, loop_nodes, it.first, it.second));
      if (!no_other_use) {
        is_supported = false;
        break;
      }
    }

    // Add the variable input edges only when they are supported for all
    // execute nodes.
    if (is_supported) {
      for (int64 j = 0; j < edges.size(); ++j) {
        (*execute_node_info)[j].var_inputs.push_back(edges[j]);
      }
      new_metadata->mutable_args(i)->set_enable_xla_sharding(
          TPUCompileMetadataProto::Arg::ALLOWED);
    }
  }

  int64 total = 0;
  for (const auto& a : new_metadata->args()) {
    if (a.enable_xla_sharding() == TPUCompileMetadataProto::Arg::ALLOWED) {
      total++;
    }
  }
  TF_RET_CHECK(total == (*execute_node_info)[0].var_inputs.size())
      << " total " << total << " var_inputs "
      << (*execute_node_info)[0].var_inputs.size();
  if (total == 0) {
    // We don't need to process anything if no input is added.
    execute_node_info->clear();
  }
  return Status::OK();
}

bool IsTPUCompileOp(const Node& n) { return n.type_string() == "TPUCompile"; }

void FindTPUCompileNodes(
    const std::string* current_function_name,
    const AttrValueMap* current_function_attr,
    const std::unordered_map<string, WhileLoopFrame>& frames,
    std::vector<HostTrainingLoopInfo>* host_training_loops_info) {
  // Adds frames with no children (i.e., the innermost frames) to a worklist.
  std::deque<const WhileLoopFrame*> worklist;

  for (auto& frame : frames) {
    if (frame.second.num_children == 0) {
      worklist.push_back(&frame.second);
    }
  }

  // Check TPUCompile nodes from the innermost while loop to the outermost
  // while loop.
  while (!worklist.empty()) {
    const WhileLoopFrame* frame = worklist.front();
    worklist.pop_front();

    for (const auto& n : frame->nodes) {
      if (!IsTPUCompileOp(*n)) continue;

      HostTrainingLoopInfo host_training_loop_info;
      host_training_loop_info.compile_node_name = n->name();
      host_training_loop_info.loop_cond_node_name = frame->loop_cond->name();
      host_training_loop_info.while_loop_name = frame->name;

      for (const auto& arg : frame->args) {
        LoopArgInfo arg_info;
        arg_info.enter_node_name = arg.enter->name();
        if (arg.exit) arg_info.exit_node_name = arg.exit->name();

        host_training_loop_info.loop_arguments.push_back(std::move(arg_info));
      }
      host_training_loop_info.loop_nodes = frame->nodes;

      if (current_function_name) {
        host_training_loop_info.encapsulating_function_name =
            *current_function_name;
      }
      if (current_function_attr) {
        host_training_loop_info.encapsulating_function_attrs =
            *current_function_attr;
      }

      host_training_loops_info->emplace_back(
          std::move(host_training_loop_info));
    }

    // If the parent frame has no remaining children, add it to the worklist.
    --frame->parent->num_children;
    if (frame->parent->num_children == 0) {
      worklist.push_back(frame->parent);
    }
  }
}

// From the while loop cond node, finds all loop exit nodes by traversing
// the below pattern of nodes:
//   LoopCond -----> Switch -----> Exit
std::vector<Node*> FindLoopExitNodes(const Node& loop_cond) {
  std::vector<Node*> loop_exit_nodes;
  for (const auto e_cond : loop_cond.out_edges()) {
    if (e_cond->IsControlEdge() || !e_cond->dst()->IsSwitch()) continue;
    auto switch_node = e_cond->dst();

    for (const auto e_switch : switch_node->out_edges()) {
      if (e_switch->IsControlEdge() || !e_switch->dst()->IsExit()) continue;

      loop_exit_nodes.push_back(e_switch->dst());
    }
  }
  return loop_exit_nodes;
}

// Finds any one of the Switch nodes in the while loop by traversing the
// graph from the while loop condition node.
xla::StatusOr<Node*> GetLoopSwitchNode(const Node& loop_cond_node) {
  Node* loop_switch_node = nullptr;
  for (auto n : loop_cond_node.out_nodes()) {
    if (n->IsSwitch()) {
      loop_switch_node = n;
      break;
    }
  }

  TF_RET_CHECK(loop_switch_node != nullptr)
      << "Unable to find any switch nodes.";
  return loop_switch_node;
}

// Returns or creates a node that is executed before each loop iteration
// in the while loop.
Status GetOrCreateBeforeEachIterationNode(Graph* graph, Node* loop_switch_node,
                                          Node** node_out) {
  // If the while loop switch node already has an outgoing data edge on the
  // true branch of the switch op, then reuse that node.
  for (const auto out_edge : loop_switch_node->out_edges()) {
    if (out_edge->src_output() == 1) {
      *node_out = out_edge->dst();
      return Status::OK();
    }
  }

  // Create an Identity node that represents execution at every loop
  // iteration.
  NodeDef at_loop_iteration_nodedef;
  at_loop_iteration_nodedef.set_op("Identity");
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(loop_switch_node->def(), "T", &dtype));

  AddNodeAttr("T", dtype, &at_loop_iteration_nodedef);
  at_loop_iteration_nodedef.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/before_iteration", "/_", internal::GetNodeId())));

  Status status;
  Node* at_loop_iteration_node =
      graph->AddNode(at_loop_iteration_nodedef, &status);
  TF_RETURN_IF_ERROR(status);

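  // Output port 1 of a Switch node carries the true branch, which is taken
  // on every iteration while the loop condition holds.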
  graph->AddEdge(loop_switch_node, 1, at_loop_iteration_node, 0);
  *node_out = at_loop_iteration_node;
  return Status::OK();
}

// Injects a node that is executed after the very last iteration of the
// while loop but before the while loop exit node.
Status AddNoOpAfterLastIteration(Graph* graph, Node* loop_switch_node,
                                 Node** node_out) {
  // Find the Exit node reachable from the loop switch node.
  Node* exit_node = nullptr;
  for (const auto out_node : loop_switch_node->out_nodes()) {
    if (out_node->IsExit()) {
      exit_node = out_node;
      break;
    }
  }

  TF_RET_CHECK(exit_node != nullptr)
      << "Cannot find exit node connected to switch node :"
      << loop_switch_node->name();

  // Create an Identity node that represents execution at the end of the last
  // iteration of the while loop.
  NodeDef after_last_loop_iteration;
  after_last_loop_iteration.set_op("Identity");
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(loop_switch_node->def(), "T", &dtype));

  AddNodeAttr("T", dtype, &after_last_loop_iteration);
  after_last_loop_iteration.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/last_iteration", "/_", internal::GetNodeId())));

  Status status;
  Node* after_last_iteration_node =
      graph->AddNode(after_last_loop_iteration, &status);
  TF_RETURN_IF_ERROR(status);

  // The newly created node must be executed once after the last iteration of
  // the while loop and before the while loop exits.
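  // Output port 0 of the Switch carries the false branch, taken exactly once
  // when the loop condition becomes false.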
  graph->AddEdge(loop_switch_node, 0, after_last_iteration_node, 0);
  graph->AddControlEdge(after_last_iteration_node, exit_node);
  *node_out = after_last_iteration_node;
  return Status::OK();
}

}  // namespace

Status DetectHostTrainingLoop(
    const std::string* current_function_name,
    const AttrValueMap* current_function_attr,
    const FunctionLibraryDefinition* library, Graph* graph,
    FunctionLibraryRuntime* flr,
    std::vector<HostTrainingLoopInfo>* host_training_loops_info) {
  std::vector<AssociatedFunctionInfo> associated_function_list;
  for (const auto* n : graph->nodes()) {
    const auto associated_functions = GetAssociatedFunctions(*n, library);
    if (associated_functions.empty()) continue;

    associated_function_list.insert(associated_function_list.end(),
                                    associated_functions.begin(),
                                    associated_functions.end());
  }

  Status ret_status = Status::OK();
  for (const auto& function : associated_function_list) {
    if (function.type() != AssociatedFunctionInfo::kFunctionAttr) continue;

    // Convert the function to a Graph.
    FunctionLibraryRuntime::Handle handle;
    TF_RETURN_IF_ERROR(flr->Instantiate(function.func_name(),
                                        AttrSlice(&function.attrs()), &handle));
    auto cleanup_handle = gtl::MakeCleanup([&]() {
      auto s = flr->ReleaseHandle(handle);
      if (!s.ok()) {
        ret_status.Update(s);
      }
    });
    const FunctionBody* body = flr->GetFunctionBody(handle);
    Graph* function_graph = body->graph;
    TF_RETURN_IF_ERROR(DetectHostTrainingLoop(
        &function.func_name(), &function.attrs(), library, function_graph, flr,
        host_training_loops_info));
  }

  // BuildControlFlowInfo() requires that the graph's source node is connected
  // to all source nodes in the graph. Many graphs violate this invariant.
  // As such, add edges to source/sink nodes so that this invariant is kept.
  FixupSourceAndSinkEdges(graph);
  std::vector<ControlFlowInfo> cf_info;
  TF_RETURN_IF_ERROR(
      BuildControlFlowInfo(graph, &cf_info, /*unreachable_nodes=*/nullptr));

  std::unordered_map<string, WhileLoopFrame> frames;
  TF_RETURN_IF_ERROR(ExtractWhileLoopFrames(cf_info, graph, &frames));
  FindTPUCompileNodes(current_function_name, current_function_attr, frames,
                      host_training_loops_info);
  return ret_status;
}

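// AddReshardOp surrounds each TPUExecute node of the host training loop with
// TPUReshardVariables ops, roughly:
//
//   Switch:1 (true)  --> Identity --ctrl--> Reshard (shard)   --ctrl--> TPUExecute
//   Switch:0 (false) --> Identity --ctrl--> Reshard (unshard) --ctrl--> NoOp --ctrl--> Exit(s)
//
// so that model weights are sharded before each iteration's execution and
// unsharded once after the final iteration.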
Status AddReshardOp(Graph* graph, const HostTrainingLoopInfo& host_loop_info) {
  const auto& compile_node_name = host_loop_info.compile_node_name;
  const auto node_name_map = graph->BuildNodeNameIndex();
  const auto node_it = node_name_map.find(compile_node_name);
  TF_RET_CHECK(node_it != node_name_map.end())
      << "Unable to find compile node : " << compile_node_name;

  const auto compile_node = node_it->second;
  std::vector<ExecuteNodeInfo> execute_nodes_info;

  Status status;
  TPUCompileMetadataProto metadata;
  status =
      ExtractExecuteNodeInfo(compile_node, *graph, host_loop_info.loop_nodes,
                             &execute_nodes_info, &metadata);
  if (!status.ok()) {
    LOG(ERROR) << "Encountered error when trying to extract execute nodes, "
                  "skipping host loop optimization. Status: "
               << status.ToString();
    return Status::OK();
  }

  if (execute_nodes_info.empty()) {
    return Status::OK();
  }

  // Update the TPUCompileMetadata such that the sharding config of the
  // sharded resource variable inputs is set to ALLOWED instead of
  // TENTATIVE.
  string new_metadata_string;
  metadata.SerializeToString(&new_metadata_string);
  compile_node->ClearAttr("metadata");
  compile_node->AddAttr("metadata", new_metadata_string);

  // Unsharding of the model weight variables must happen only at the very
  // last loop iteration. As such, add the while loop condition predicate as
  // an input to the sharding switch node. If the loop condition is true, we
  // do not unshard.
  const auto& cond_node_name = host_loop_info.loop_cond_node_name;
  auto loop_cond_node_it = node_name_map.find(cond_node_name);
  TF_RET_CHECK(loop_cond_node_it != node_name_map.end())
      << "Cannot find loop condition node : " << cond_node_name;
  auto* loop_condition_node = loop_cond_node_it->second;

  // In order to make sure that shard/unshard operations are invoked
  // at the start of every loop body and at the end of the last iteration
  // of the loop, respectively, traverse the graph and find a switch node
  // of the host training loop.
  TF_ASSIGN_OR_RETURN(Node * switch_node,
                      GetLoopSwitchNode(*loop_condition_node));

  Node* after_last_iteration_node;
  TF_RETURN_IF_ERROR(AddNoOpAfterLastIteration(graph, switch_node,
                                               &after_last_iteration_node));

  Node* before_loop_iteration_node;
  TF_RETURN_IF_ERROR(GetOrCreateBeforeEachIterationNode(
      graph, switch_node, &before_loop_iteration_node));

  // Create a Const op that represents the default sharding value
  // (i.e. no-op sharding).
  NodeDef default_sharding;
  default_sharding.set_op("Const");
  default_sharding.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/default_shard_state", "/_", internal::GetNodeId())));
  AddNodeAttr("dtype", DT_STRING, &default_sharding);

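  // The default format key is a [3]-element string tensor, matching the
  // shape expected at the reshard op's key input (the same input that
  // otherwise receives the compilation key from TPUCompile); all-empty
  // strings denote the default, unsharded format.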
  Tensor t(DT_STRING, {3});
  t.vec<tstring>()(0) = kDefaultShardingValue;
  t.vec<tstring>()(1) = kDefaultShardingValue;
  t.vec<tstring>()(2) = kDefaultShardingValue;
  t.AsProtoTensorContent(
      (*default_sharding.mutable_attr())["value"].mutable_tensor());

  Node* default_sharding_node = graph->AddNode(default_sharding, &status);
  TF_RETURN_IF_ERROR(status);
  // Add a control edge from the loop condition node to make sure that
  // default_sharding_node is inside the while loop frame.
  graph->AddControlEdge(loop_condition_node, default_sharding_node);

  // Build a no-op node used to add control edges after unshard nodes.
  NodeDef after_unshard;
  after_unshard.set_op("NoOp");
  after_unshard.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/last_iteration", "/_", internal::GetNodeId())));
  auto after_unshard_node = graph->AddNode(after_unshard, &status);
  TF_RETURN_IF_ERROR(status);

  for (const auto& info : execute_nodes_info) {
    auto execute_node = info.execute_node;
    // Create a Reshard op that optionally shards model weight variables
    // prior to program execution.
    NodeDef reshard_node_def;
    reshard_node_def.set_name(graph->NewName(strings::StrCat(
        "TPUVariableReshard/reshard", "/_", internal::GetNodeId())));
    reshard_node_def.set_op("TPUReshardVariables");
    AddNodeAttr("N", static_cast<int>(info.var_inputs.size()),
                &reshard_node_def);
    Node* reshard_op_node = graph->AddNode(reshard_node_def, &status);
    if (!status.ok()) return status;

    reshard_op_node->set_assigned_device_name(
        execute_node->assigned_device_name());

    // The Reshard op must execute at every loop iteration prior to the
    // TPUExecute node.
    graph->AddControlEdge(before_loop_iteration_node, reshard_op_node);
    graph->AddControlEdge(reshard_op_node, execute_node);

    for (int i = 0; i < info.var_inputs.size(); ++i) {
      const auto variable_edge = info.var_inputs[i];
      graph->AddEdge(variable_edge->src(), variable_edge->src_output(),
                     reshard_op_node, i);
    }

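    // Inputs of TPUReshardVariables as wired here: the N resource variables
    // occupy inputs [0, N), the new format key follows at index N, and the
    // sharding-state variable handle comes last at index N + 1.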
    const int new_key_input = info.var_inputs.size();
    // Add the program input edge from the compiler (i.e. the compilation
    // key).
    const auto compilation_key_edge =
        FindEdgeConnecting(compile_node, execute_node);
    graph->AddEdge(compile_node, compilation_key_edge->src_output(),
                   reshard_op_node, new_key_input);

    // Create a VarHandleOp to store the sharding state. The sharding state
    // holds the string compilation key that identifies whether the graph is
    // re-compiled and the variables need to be sharded again.
    NodeDef var_handle_def;
    var_handle_def.set_op("VarHandleOp");
    var_handle_def.set_name(graph->NewName(strings::StrCat(
        "TPUVariableReshard/reshard_state", "/_", internal::GetNodeId())));
    AddNodeAttr("dtype", DT_STRING, &var_handle_def);
    AddNodeAttr("shape", TensorShape({}), &var_handle_def);
    Node* var_handle_node = graph->AddNode(var_handle_def, &status);
    if (!status.ok()) return status;

    // Add a control edge between the `var_handle_node` and the while loop
    // condition node so that `var_handle_node` is inside the same while loop
    // frame.
    // TODO(hongjunchoi): Consider adding control edge from another node--such
    // as input control node.
    graph->AddControlEdge(loop_condition_node, var_handle_node);

    // Connect a data edge between the var handle op and the reshard op.
    const int format_state_input = new_key_input + 1;
    graph->AddEdge(var_handle_node, 0, reshard_op_node, format_state_input);

    // Create a Reshard op that represents unsharding after TPUExecute.
    NodeDef unshard_node_def;
    unshard_node_def.set_name(graph->NewName(strings::StrCat(
        "TPUVariableReshard/unshard", "/_", internal::GetNodeId())));
    unshard_node_def.set_op("TPUReshardVariables");
    AddNodeAttr("N", static_cast<int>(info.var_inputs.size()),
                &unshard_node_def);
    Node* unshard_op_node = graph->AddNode(unshard_node_def, &status);
    TF_RETURN_IF_ERROR(status);

    unshard_op_node->set_assigned_device_name(
        execute_node->assigned_device_name());

    for (int i = 0; i < info.var_inputs.size(); ++i) {
      const auto variable_edge = info.var_inputs[i];
      // Connect the model weight resource variables to the unshard op. Since
      // the unshard op must only be invoked after the very last loop
      // iteration, for each while loop input we traverse backwards to find
      // the Enter node of the host training loop and connect it to the
      // unshard op; its execution is gated on the false branch of the switch
      // node via `after_last_iteration_node`.
      TF_ASSIGN_OR_RETURN(
          Node * enter_node,
          FindEnterNodeFromTPUExecuteNodeInput(variable_edge->src()));
      graph->AddEdge(enter_node, 0, unshard_op_node, i);
    }

    // Add control dependencies so that the unshard node runs after the last
    // iteration and before the after-unshard NoOp.
    graph->AddControlEdge(after_last_iteration_node, unshard_op_node);
    graph->AddControlEdge(unshard_op_node, after_unshard_node);

    graph->AddEdge(default_sharding_node, 0, unshard_op_node, new_key_input);

    // Add a data edge from the sharding state var handle op to the unshard
    // op.
    graph->AddEdge(var_handle_node, 0, unshard_op_node, format_state_input);
  }
  // Add a control dependency from after_unshard_node to all exit nodes. This
  // is to make sure that the unshard ops will be executed as long as any of
  // the exits are used.
  for (auto exit : FindLoopExitNodes(*loop_condition_node)) {
    graph->AddControlEdge(after_unshard_node, exit);
  }
  return Status::OK();
}

}  // namespace tpu
}  // namespace tensorflow