/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h"

#include <deque>
#include <map>
#include <unordered_map>
#include <unordered_set>

#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_set.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"

namespace tensorflow {
namespace tpu {

namespace {

constexpr char kDefaultShardingValue[] = "";

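// Returns the first out-edge of `src` whose destination is `dst`, or nullptr
// if no such edge exists.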
const Edge* FindEdgeConnecting(const Node* src, const Node* dst) {
  for (const Edge* e : src->out_edges()) {
    if (e->dst()->name() == dst->name()) return e;
  }
  return nullptr;
}

// Contains a TPUExecute node and its DT_RESOURCE input nodes that
// correspond to model weights.
struct ExecuteNodeInfo {
  Node* execute_node;
  std::vector<const Edge*> var_inputs;
};

// Returns whether `node` is in `execute_nodes`, or lies on an
// `(identity) -> ... -> execute` path to one of them.
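// Identity nodes are followed transitively through their data out-edges. A
// NextIteration node inside the loop is accepted as well, since the value it
// carries simply cycles back into the next iteration of the loop.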
bool IsExecuteNodeOrIdentityToExecuteNode(
    const Graph& graph, const std::unordered_set<Node*>& loop_nodes,  // NOLINT
    const absl::flat_hash_set<Node*>& execute_nodes, Node* node) {
  if (execute_nodes.find(node) != execute_nodes.end()) return true;
  if (loop_nodes.find(node) == loop_nodes.end()) return false;
  if (node->IsNextIteration()) return true;
  if (!node->IsIdentity()) return false;

  for (const Edge* e : node->out_edges()) {
    if (e->IsControlEdge()) continue;

    Node* dst_node = e->dst();
    if (!IsExecuteNodeOrIdentityToExecuteNode(graph, loop_nodes, execute_nodes,
                                              dst_node)) {
      return false;
    }
  }

  return true;
}

// Starting from an input node to a TPUExecute op, finds the corresponding
// Enter node by traversing the following pattern of nodes:
//   Enter ----> (identity) ---> While body input
// Returns nullptr if the Enter node is not found.
xla::StatusOr<Node*> FindEnterNodeFromTPUExecuteNodeInput(Node* input_node) {
  Node* node = input_node;
  while (node->IsIdentity()) {
    TF_RETURN_IF_ERROR(node->input_node(0, &node));
  }

  if (node->IsEnter()) {
    return node;
  }
  return nullptr;
}

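// Returns true if, within the loop body, the resource variable fed through
// `enter_node` is consumed only by the TPUExecute nodes in `execute_nodes`
// (possibly via Identity or NextIteration nodes). Control edges and loop
// Exit nodes are ignored.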
xla::StatusOr<bool> ResourceOnlyUsedForTPUExecuteInLoop(
    const Graph& graph, const std::unordered_set<Node*>& loop_nodes,  // NOLINT
    const Node* enter_node, const absl::flat_hash_set<Node*>& execute_nodes) {
  for (const Edge* output_edge : enter_node->out_edges()) {
    Node* output_node = output_edge->dst();
    if (output_edge->IsControlEdge() || output_node->IsExit()) continue;

    // If the output node is not an execute node or an identity leading to
    // one, then the resource has other uses inside the while loop body.
    if (!IsExecuteNodeOrIdentityToExecuteNode(graph, loop_nodes, execute_nodes,
                                              output_node)) {
      return false;
    }
  }
  return true;
}

// Given a TPUCompile node, finds all TPUExecute nodes that execute the
// compiled program, along with their model weight variable inputs.
// The TPUCompileMetadataProto of the TPUCompile node must be reset to
// `new_metadata` if new reshard ops are added.
Status ExtractExecuteNodeInfo(const Node* compile_node, const Graph& graph,
                              const std::unordered_set<Node*>& loop_nodes,  // NOLINT
                              std::vector<ExecuteNodeInfo>* execute_node_info,
                              TPUCompileMetadataProto* new_metadata) {
  string metadata_string;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(compile_node->attrs(), "metadata", &metadata_string));
  new_metadata->ParsePartialFromString(metadata_string);
  if (new_metadata->num_cores_per_replica() != 1) {
    // We do not support model parallelism yet.
    return Status::OK();
  }

  execute_node_info->clear();
  for (Node* node : compile_node->out_nodes()) {
    if (node->type_string() == "TPUExecute") {
      execute_node_info->push_back({node});
    }
  }
  if (execute_node_info->empty()) {
    return Status::OK();
  }
  TF_RET_CHECK(execute_node_info->size() == new_metadata->num_replicas())
      << "Number of replicas does not equal number of execute nodes: "
      << new_metadata->num_replicas() << " vs " << execute_node_info->size();
  DataTypeVector arg_types;
  TF_RETURN_IF_ERROR(GetNodeAttr((*execute_node_info)[0].execute_node->attrs(),
                                 "Targs", &arg_types));
  for (int64_t i = 0; i < arg_types.size(); ++i) {
    if (arg_types[i] != DT_RESOURCE) {
      continue;
    }
    const auto sharding_config = new_metadata->args(i).enable_xla_sharding();
    if (sharding_config != TPUCompileMetadataProto::Arg::TENTATIVE &&
        sharding_config != TPUCompileMetadataProto::Arg::ALLOWED) {
      continue;
    }
    std::vector<const Edge*> edges(execute_node_info->size());
    bool is_supported = true;
    std::unordered_map<Node*, absl::flat_hash_set<Node*>>
        enter_to_execute_nodes;
    for (int64_t j = 0; j < edges.size(); ++j) {
      auto execute = (*execute_node_info)[j].execute_node;
      TF_RETURN_IF_ERROR(execute->input_edge(i, &edges[j]));
      TF_RET_CHECK(edges[j]->src()->output_type(edges[j]->src_output()) ==
                   arg_types[i])
          << "Execute op has an unexpected input type.";
      // Traverse backwards to find the Enter node from which the input is
      // passed. This makes sure that we are checking the usages of all
      // potential aliases of the input node as well.
      TF_ASSIGN_OR_RETURN(auto enter_node, FindEnterNodeFromTPUExecuteNodeInput(
                                               edges[j]->src()));
      if (enter_node == nullptr) {
        is_supported = false;
        enter_to_execute_nodes.clear();
        break;
      }
      enter_to_execute_nodes[enter_node].insert(edges[j]->dst());
    }

    for (const auto& it : enter_to_execute_nodes) {
      // The number of execute nodes should be either 1 (per-replica
      // variables) or num_replicas (distributed variables).
      if ((it.second.size() != 1) &&
          (it.second.size() != new_metadata->num_replicas())) {
        is_supported = false;
        break;
      }
      TF_ASSIGN_OR_RETURN(bool no_other_use,
                          ResourceOnlyUsedForTPUExecuteInLoop(
                              graph, loop_nodes, it.first, it.second));
      if (!no_other_use) {
        is_supported = false;
        break;
      }
    }

    // Add the variable input edges only when they are supported for all
    // execute nodes.
    if (is_supported) {
      for (int64_t j = 0; j < edges.size(); ++j) {
        (*execute_node_info)[j].var_inputs.push_back(edges[j]);
      }
      new_metadata->mutable_args(i)->set_enable_xla_sharding(
          TPUCompileMetadataProto::Arg::ALLOWED);
    }
  }

  int64_t total = 0;
  for (const auto& a : new_metadata->args()) {
    if (a.enable_xla_sharding() == TPUCompileMetadataProto::Arg::ALLOWED) {
      total++;
    }
  }
  TF_RET_CHECK(total == (*execute_node_info)[0].var_inputs.size())
      << " total " << total << " var_inputs "
      << (*execute_node_info)[0].var_inputs.size();
  if (total == 0) {
    // We don't need to process anything if no input is added.
    execute_node_info->clear();
  }
  return Status::OK();
}

bool IsTPUCompileOp(const Node& n) { return n.type_string() == "TPUCompile"; }

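// Walks the while-loop frames from the innermost frame outward and records a
// HostTrainingLoopInfo for every TPUCompile node found inside a frame.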
void FindTPUCompileNodes(
    const std::string* current_function_name,
    const AttrValueMap* current_function_attr,
    const std::unordered_map<string, WhileLoopFrame>& frames,
    std::vector<HostTrainingLoopInfo>* host_training_loops_info) {
  // Adds frames with no children (i.e., the innermost frames) to a worklist.
  std::deque<const WhileLoopFrame*> worklist;

  for (auto& frame : frames) {
    if (frame.second.num_children == 0) {
      worklist.push_back(&frame.second);
    }
  }

  // Check TPUCompile nodes from the innermost while loop to the outermost
  // while loop.
  while (!worklist.empty()) {
    const WhileLoopFrame* frame = worklist.front();
    worklist.pop_front();

    for (const auto& n : frame->nodes) {
      if (!IsTPUCompileOp(*n)) continue;

      HostTrainingLoopInfo host_training_loop_info;
      host_training_loop_info.compile_node_name = n->name();
      host_training_loop_info.loop_cond_node_name = frame->loop_cond->name();
      host_training_loop_info.while_loop_name = frame->name;

      for (const auto& arg : frame->args) {
        LoopArgInfo arg_info;
        arg_info.enter_node_name = arg.enter->name();
        if (arg.exit) arg_info.exit_node_name = arg.exit->name();

        host_training_loop_info.loop_arguments.push_back(std::move(arg_info));
      }
      host_training_loop_info.loop_nodes = frame->nodes;

      if (current_function_name) {
        host_training_loop_info.encapsulating_function_name =
            *current_function_name;
      }
      if (current_function_attr) {
        host_training_loop_info.encapsulating_function_attrs =
            *current_function_attr;
      }

      host_training_loops_info->emplace_back(
          std::move(host_training_loop_info));
    }

    // If the parent frame has no remaining children, add it to the worklist.
    --frame->parent->num_children;
    if (frame->parent->num_children == 0) {
      worklist.push_back(frame->parent);
    }
  }
}

// Starting from the while loop's LoopCond node, finds all loop exit nodes by
// traversing the following pattern of nodes:
//   LoopCond -----> Switch -----> Exit
std::vector<Node*> FindLoopExitNodes(const Node& loop_cond) {
  std::vector<Node*> loop_exit_nodes;
  for (const Edge* e_cond : loop_cond.out_edges()) {
    if (e_cond->IsControlEdge() || !e_cond->dst()->IsSwitch()) continue;
    auto switch_node = e_cond->dst();

    for (const Edge* e_switch : switch_node->out_edges()) {
      if (e_switch->IsControlEdge() || !e_switch->dst()->IsExit()) continue;

      loop_exit_nodes.push_back(e_switch->dst());
    }
  }
  return loop_exit_nodes;
}

// Returns or creates a node that is executed before each iteration of the
// while loop.
// TODO(mdan): Inject this node between the Enter and Merge nodes instead.
// See AddNoOpAfterLastIteration for an example.
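// The returned node hangs off the true branch (output 1) of the loop's
// Switch node, so it runs once per iteration while the loop condition holds:
//   LoopCond ---> Switch --(true, output 1)--> Identity (per-iteration anchor)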
Status GetOrCreateBeforeEachIterationNode(const Node& loop_cond_node,
                                          Graph* graph, Node** node_out) {
  Node* loop_switch_node = nullptr;
  for (auto n : loop_cond_node.out_nodes()) {
    if (n->IsSwitch()) {
      loop_switch_node = n;
      break;
    }
  }
  TF_RET_CHECK(loop_switch_node != nullptr)
      << "Unable to find any switch nodes.";

  // If the while loop's Switch node already has an outgoing data edge on the
  // true branch of the switch op, then reuse that node.
  for (const Edge* out_edge : loop_switch_node->out_edges()) {
    if (out_edge->src_output() == 1) {
      *node_out = out_edge->dst();
      return Status::OK();
    }
  }

  // Create an Identity node that represents execution at every loop
  // iteration.
  NodeDef at_loop_iteration_nodedef;
  at_loop_iteration_nodedef.set_op("Identity");
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(loop_switch_node->def(), "T", &dtype));

  AddNodeAttr("T", dtype, &at_loop_iteration_nodedef);
  at_loop_iteration_nodedef.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/before_iteration", "/_", internal::GetNodeId())));

  Status status;
  Node* at_loop_iteration_node =
      graph->AddNode(at_loop_iteration_nodedef, &status);
  TF_RETURN_IF_ERROR(status);

  graph->AddEdge(loop_switch_node, 1, at_loop_iteration_node, 0);
  *node_out = at_loop_iteration_node;
  return Status::OK();
}

// Injects a NoOp node that is executed after the very last iteration of the
// while loop but before the while loop exit node.
// This node is positioned between the false output of all Switch nodes
// (meaning it executes after the loop has ended all its iterations) and their
// corresponding Exit nodes (meaning before the loop finally completes).
// See the white paper for details:
// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
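// Sketch of the pattern added for each Switch node that feeds an Exit node:
//   Switch --(false, output 0)--> Identity --ctrl--> NoOp --ctrl--> Exit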
Status AddNoOpAfterLastIteration(const Node& loop_cond_node, Graph* graph,
                                 Node** node_out) {
  NodeDef after_last_iteration;
  after_last_iteration.set_op("NoOp");

  after_last_iteration.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/after_last_iteration", "/_", internal::GetNodeId())));

  Status status;
  Node* after_last_iteration_node =
      graph->AddNode(after_last_iteration, &status);
  TF_RETURN_IF_ERROR(status);

  for (auto switch_node : loop_cond_node.out_nodes()) {
    if (!switch_node->IsSwitch()) {
      continue;
    }

    NodeDef switch_exit;
    switch_exit.set_op("Identity");

    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(switch_node->def(), "T", &dtype));
    AddNodeAttr("T", dtype, &switch_exit);
    auto name = strings::StrCat("TPUVariableReshard/switch_exit", "/_",
                                internal::GetNodeId());
    switch_exit.set_name(graph->NewName(name));
    // Introducing identity nodes risks a device copy, which isn't guaranteed
    // to be available for all types. Hence the colocation constraint.
    AddNodeAttr(kColocationAttrName,
                std::vector<string>{
                    absl::StrCat(kColocationGroupPrefix, switch_node->name())},
                &switch_exit);

    Status status;
    Node* after_switch_node = graph->AddNode(switch_exit, &status);
    TF_RETURN_IF_ERROR(status);

    graph->AddEdge(switch_node, 0, after_switch_node, 0);
    graph->AddControlEdge(after_switch_node, after_last_iteration_node);

    for (const auto out_node : switch_node->out_nodes()) {
      if (out_node->IsExit()) {
        graph->AddControlEdge(after_last_iteration_node, out_node);
      }
    }
  }

  *node_out = after_last_iteration_node;
  return Status::OK();
}

}  // namespace

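// Recurses into every function reachable from `graph` first, then scans
// `graph` itself: it builds the graph's control-flow frames and records a
// HostTrainingLoopInfo for each while loop that contains a TPUCompile node.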
Status DetectHostTrainingLoop(
    const std::string* current_function_name,
    const AttrValueMap* current_function_attr,
    const FunctionLibraryDefinition* library, Graph* graph,
    FunctionLibraryRuntime* flr,
    std::vector<HostTrainingLoopInfo>* host_training_loops_info) {
  std::vector<AssociatedFunctionInfo> associated_function_list;
  for (const auto* n : graph->nodes()) {
    const auto associated_functions = GetAssociatedFunctions(*n, library);
    if (associated_functions.empty()) continue;

    associated_function_list.insert(associated_function_list.end(),
                                    associated_functions.begin(),
                                    associated_functions.end());
  }

  Status ret_status = Status::OK();
  for (const auto& function : associated_function_list) {
    if (function.type() != AssociatedFunctionInfo::kFunctionAttr) continue;

    // Convert the function to a Graph.
    FunctionLibraryRuntime::Handle handle;
    TF_RETURN_IF_ERROR(flr->Instantiate(function.func_name(),
                                        AttrSlice(&function.attrs()), &handle));
    auto cleanup_handle = gtl::MakeCleanup([&]() {
      auto s = flr->ReleaseHandle(handle);
      if (!s.ok()) {
        ret_status.Update(s);
      }
    });
    const FunctionBody* body = flr->GetFunctionBody(handle);
    Graph* function_graph = body->graph;
    TF_RETURN_IF_ERROR(DetectHostTrainingLoop(
        &function.func_name(), &function.attrs(), library, function_graph, flr,
        host_training_loops_info));
  }

  // BuildControlFlowInfo() requires that the graph's source node is connected
  // to all source nodes in the graph. Many graphs violate this invariant.
  // As such, add edges to the source/sink nodes so that the invariant holds.
  FixupSourceAndSinkEdges(graph);
  std::vector<ControlFlowInfo> cf_info;
  TF_RETURN_IF_ERROR(
      BuildControlFlowInfo(graph, &cf_info, /*unreachable_nodes=*/nullptr));

  std::unordered_map<string, WhileLoopFrame> frames;
  TF_RETURN_IF_ERROR(ExtractWhileLoopFrames(cf_info, graph, &frames));
  FindTPUCompileNodes(current_function_name, current_function_attr, frames,
                      host_training_loops_info);
  return ret_status;
}

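// Rewires the host training loop described by `host_loop_info` so that model
// weight variables are resharded by a TPUReshardVariables op before each
// TPUExecute, and unsharded back to the default format once, after the last
// loop iteration.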
Status AddReshardOp(Graph* graph, const HostTrainingLoopInfo& host_loop_info) {
  const auto& compile_node_name = host_loop_info.compile_node_name;
  const auto node_name_map = graph->BuildNodeNameIndex();
  const auto node_it = node_name_map.find(compile_node_name);
  TF_RET_CHECK(node_it != node_name_map.end())
      << "Unable to find compile node : " << compile_node_name;

  const auto compile_node = node_it->second;
  std::vector<ExecuteNodeInfo> execute_nodes_info;

  Status status;
  TPUCompileMetadataProto metadata;
  status =
      ExtractExecuteNodeInfo(compile_node, *graph, host_loop_info.loop_nodes,
                             &execute_nodes_info, &metadata);
  if (!status.ok()) {
    LOG(ERROR) << "Encountered error when trying to extract execute nodes, "
                  "skipping host loop optimization. Status: "
               << status.ToString();
    return Status::OK();
  }

  if (execute_nodes_info.empty()) {
    return Status::OK();
  }

  // Update the TPUCompileMetadata such that the sharding config of the
  // sharded resource variable inputs is set to ALLOWED instead of TENTATIVE.
  string new_metadata_string;
  metadata.SerializeToString(&new_metadata_string);
  compile_node->ClearAttr("metadata");
  compile_node->AddAttr("metadata", new_metadata_string);

  // Unsharding of the model weight variables must happen only at the very
  // last loop iteration. As such, add the while loop condition predicate as
  // an input to the sharding switch node. If the loop condition is true, we
  // do not unshard.
  const auto& cond_node_name = host_loop_info.loop_cond_node_name;
  auto loop_cond_node_it = node_name_map.find(cond_node_name);
  TF_RET_CHECK(loop_cond_node_it != node_name_map.end())
      << "Cannot find loop condition node : " << cond_node_name;
  auto* loop_condition_node = loop_cond_node_it->second;

  // To make sure that shard/unshard operations are invoked at the start of
  // every loop iteration and at the end of the last iteration of the loop,
  // respectively, create a pair of guiding nodes that are guaranteed to
  // execute before each iteration and after all iterations, respectively.
  Node* after_last_iteration_node;
  TF_RETURN_IF_ERROR(AddNoOpAfterLastIteration(*loop_condition_node, graph,
                                               &after_last_iteration_node));

  Node* before_loop_iteration_node;
  TF_RETURN_IF_ERROR(GetOrCreateBeforeEachIterationNode(
      *loop_condition_node, graph, &before_loop_iteration_node));

  // Create a Const op that represents the default sharding value
  // (i.e. no-op sharding).
  NodeDef default_sharding;
  default_sharding.set_op("Const");
  default_sharding.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/default_shard_state", "/_", internal::GetNodeId())));
  AddNodeAttr("dtype", DT_STRING, &default_sharding);

  Tensor t(DT_STRING, {3});
  t.vec<tstring>()(0) = kDefaultShardingValue;
  t.vec<tstring>()(1) = kDefaultShardingValue;
  t.vec<tstring>()(2) = kDefaultShardingValue;
  t.AsProtoTensorContent(
      (*default_sharding.mutable_attr())["value"].mutable_tensor());

  Node* default_sharding_node = graph->AddNode(default_sharding, &status);
  TF_RETURN_IF_ERROR(status);
  // Add a control edge from the loop condition node to make sure that
  // `default_sharding_node` is inside the while loop frame.
  graph->AddControlEdge(loop_condition_node, default_sharding_node);

  // Build a NoOp node used to add control edges after unshard nodes.
  NodeDef after_unshard;
  after_unshard.set_op("NoOp");
  after_unshard.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/last_iteration", "/_", internal::GetNodeId())));
  auto after_unshard_node = graph->AddNode(after_unshard, &status);
  TF_RETURN_IF_ERROR(status);

  for (const auto& info : execute_nodes_info) {
    auto execute_node = info.execute_node;
    // Create a Reshard op that optionally shards model weight variables
    // prior to program execution.
    NodeDef reshard_node_def;
    reshard_node_def.set_name(graph->NewName(strings::StrCat(
        "TPUVariableReshard/reshard", "/_", internal::GetNodeId())));
    reshard_node_def.set_op("TPUReshardVariables");
    AddNodeAttr("N", static_cast<int>(info.var_inputs.size()),
                &reshard_node_def);
    Node* reshard_op_node = graph->AddNode(reshard_node_def, &status);
    if (!status.ok()) return status;

    reshard_op_node->set_assigned_device_name(
        execute_node->assigned_device_name());

    // The Reshard op must execute at every loop iteration prior to the
    // TPUExecute node.
    graph->AddControlEdge(before_loop_iteration_node, reshard_op_node);
    graph->AddControlEdge(reshard_op_node, execute_node);

    for (int i = 0; i < info.var_inputs.size(); ++i) {
      const auto variable_edge = info.var_inputs[i];
      graph->AddEdge(variable_edge->src(), variable_edge->src_output(),
                     reshard_op_node, i);
    }

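    // TPUReshardVariables consumes its inputs in this order: the N resource
    // variables, then the new format key (a string from TPUCompile, or the
    // default sharding Const for the unshard case), then the format state
    // variable that tracks the currently applied format.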
    const int new_key_input = info.var_inputs.size();
    // Add the program input edge from the compiler (i.e. the compilation
    // key).
    const auto compilation_key_edge =
        FindEdgeConnecting(compile_node, execute_node);
    graph->AddEdge(compile_node, compilation_key_edge->src_output(),
                   reshard_op_node, new_key_input);

    // Create a VarHandleOp to store the sharding state. The sharding state
    // holds the string compilation key that identifies whether the graph has
    // been re-compiled and the variables need to be sharded again.
    NodeDef var_handle_def;
    var_handle_def.set_op("VarHandleOp");
    var_handle_def.set_name(graph->NewName(strings::StrCat(
        "TPUVariableReshard/reshard_state", "/_", internal::GetNodeId())));
    AddNodeAttr("dtype", DT_STRING, &var_handle_def);
    AddNodeAttr("shape", TensorShape({}), &var_handle_def);
    Node* var_handle_node = graph->AddNode(var_handle_def, &status);
    if (!status.ok()) return status;

    // Add a control edge between the `var_handle_def` node and the while
    // loop condition so that `var_handle_def` is inside the same while loop
    // frame.
    // TODO(hongjunchoi): Consider adding the control edge from another
    // node--such as the input control node.
    graph->AddControlEdge(loop_condition_node, var_handle_node);

    // Connect a data edge between the var handle op and the reshard op.
    const int format_state_input = new_key_input + 1;
    graph->AddEdge(var_handle_node, 0, reshard_op_node, format_state_input);

    // Create a Reshard op that represents unsharding after TPUExecute.
    NodeDef unshard_node_def;
    unshard_node_def.set_name(graph->NewName(strings::StrCat(
        "TPUVariableReshard/unshard", "/_", internal::GetNodeId())));
    unshard_node_def.set_op("TPUReshardVariables");
    AddNodeAttr("N", static_cast<int>(info.var_inputs.size()),
                &unshard_node_def);
    Node* unshard_op_node = graph->AddNode(unshard_node_def, &status);
    TF_RETURN_IF_ERROR(status);

    unshard_op_node->set_assigned_device_name(
        execute_node->assigned_device_name());

    for (int i = 0; i < info.var_inputs.size(); ++i) {
      const auto variable_edge = info.var_inputs[i];
      // Connect the model weight resource variables to the unshard op. Since
      // the unshard op must be invoked only after the very last loop
      // iteration, for each while loop input we traverse backwards to find
      // the corresponding Enter node of the host training loop and connect
      // its output to the unshard op.
      TF_ASSIGN_OR_RETURN(
          Node * enter_node,
          FindEnterNodeFromTPUExecuteNodeInput(variable_edge->src()));
      graph->AddEdge(enter_node, 0, unshard_op_node, i);
    }

    // Add control dependencies before/after the unshard node and the control
    // nodes.
    graph->AddControlEdge(after_last_iteration_node, unshard_op_node);
    graph->AddControlEdge(unshard_op_node, after_unshard_node);

    graph->AddEdge(default_sharding_node, 0, unshard_op_node, new_key_input);

    // Add a data edge from the sharding state var handle op to the unshard
    // op.
    graph->AddEdge(var_handle_node, 0, unshard_op_node, format_state_input);
  }
  // Add control dependencies from after_unshard_node to all exit nodes. This
  // is to make sure that the unshard ops will be executed as long as any of
  // the exits are used.
  for (auto exit : FindLoopExitNodes(*loop_condition_node)) {
    graph->AddControlEdge(after_unshard_node, exit);
  }
  return Status::OK();
}

}  // namespace tpu
}  // namespace tensorflow